GraphΒΆ
The GraphBase class must be implemented to define the actual network architecture of a scenario.
The sole parameter are the ModelBaseParams, a graph is constructed in ModelBase.create_graph.
GraphBase extends keras.layers.Layer, hence an implementation of call is obligatory to join the layers declared in __init__.
An example Multi-Layer-Perceptron (MLP) graph for MNIST is shown in the following code example (excerpt of a Tutorial):
class TutorialGraph(GraphBase[TutorialModelParams]):
def __init__(self, params: TutorialModelParams, name='conv', **kwargs):
super(TutorialGraph, self).__init__(params, name=name, **kwargs)
# Create all layers
self.flatten = Flatten()
self.ff = Dense(128, name='f_ff', activation='relu')
self.logits = Dense(self._params.n_classes, activation=None, name='classify')
def call(self, inputs, **kwargs):
# Connect all layers and return a dict of the outputs
rescaled_img = K.cast(inputs['img'], dtype='float32') / 255
logits = self.logits(self.ff(self.flatten(rescaled_img)))
pred = K.softmax(logits, axis=-1)
cls = K.argmax(pred, axis=-1)
out = {'pred': pred, 'logits': logits, 'class': cls}
return out