Graph
Each tfaip graph must inherit GenericGraphBase, or its subclass GraphBase which already encapsulates some optional methods for the graph construction.
GraphBase
An example Multi-Layer-Perceptron (MLP) graph for MNIST is shown in the following code example (excerpt of a Tutorial):
class TutorialGraph(GraphBase[TutorialModelParams]):
def __init__(self, params: TutorialModelParams, name='conv', **kwargs):
super(TutorialGraph, self).__init__(params, name=name, **kwargs)
# Create all layers
self.flatten = Flatten()
self.ff = Dense(128, name='f_ff', activation='relu')
self.logits = Dense(self._params.n_classes, activation=None, name='classify')
def build_graph(self, inputs, training=None):
# Connect all layers and return a dict of the outputs
rescaled_img = K.cast(inputs['img'], dtype='float32') / 255
logits = self.logits(self.ff(self.flatten(rescaled_img)))
pred = K.softmax(logits, axis=-1)
cls = K.argmax(pred, axis=-1)
out = {'pred': pred, 'logits': logits, 'class': cls}
return out
Layers are instantiated in the __init__ function and applied in the build_graph function which is the sole abstract method of GraphBase
GenericGraphBase
The GenericGraphBase class provides more flexibility when creating a graph, including different graphs for training (build_train_graph) and prediction (build_prediction_graph).
This is for example required to implement a sequence-to-sequence model with a decoder that depends on the model (teacher-forcing during training, decoding during prediction).
Furthermode, the method pre_proc_targets can be overwritten to apply some preprocessing on the dataset targets that are then fed into the metrics.