tfaip.trainer.callbacks

Early Stopping

Definition of the EarlyStoppingParams

class tfaip.trainer.callbacks.earlystopping.params.EarlyStoppingParams(best_model_output_dir: Optional[str] = None, best_model_name: str = 'best', frequency: int = 1, n_to_go: int = - 1, lower_threshold: float = - 1e+100, upper_threshold: float = 1e+100, mode: Optional[str] = None, current: Optional[float] = None, monitor: Optional[str] = None, n: int = 1)

Bases: object

EarlyStoppingParameters

best_model_output_dir: Optional[str] = None
best_model_name: str = 'best'
frequency: int = 1
n_to_go: int = -1
lower_threshold: float = -1e+100
upper_threshold: float = 1e+100
mode: Optional[str] = None
current: Optional[float] = None
monitor: Optional[str] = None
n: int = 1
__init__(best_model_output_dir: Optional[str] = None, best_model_name: str = 'best', frequency: int = 1, n_to_go: int = - 1, lower_threshold: float = - 1e+100, upper_threshold: float = 1e+100, mode: Optional[str] = None, current: Optional[float] = None, monitor: Optional[str] = None, n: int = 1) → None

Initialize self. See help(type(self)) for accurate signature.

classmethod from_dict(kvs: Optional[Union[dict, list, str, int, float, bool]], *, infer_missing=False) → A
classmethod from_json(s: Union[str, bytes, bytearray], *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw) → A
classmethod schema(*, infer_missing: bool = False, only=None, exclude=(), many: bool = False, context=None, load_only=(), dump_only=(), partial: bool = False, unknown=None) → dataclasses_json.mm.SchemaF[A]
to_dict(encode_json=False) → Dict[str, Optional[Union[dict, list, str, int, float, bool]]]
to_json(*, skipkeys: bool = False, ensure_ascii: bool = True, check_circular: bool = True, allow_nan: bool = True, indent: Optional[Union[int, str]] = None, separators: Optional[Tuple[str, str]] = None, default: Optional[Callable] = None, sort_keys: bool = False, **kw) → str
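
The serialization helpers listed above follow a dict/JSON round-trip pattern. As a minimal sketch of that pattern, the following uses a hypothetical stand-in dataclass, ParamsSketch, which copies only some of the documented fields and defaults and relies on the standard library rather than on tfaip or dataclasses_json. The values chosen for monitor, mode, and n are assumed examples, not recommended settings.

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class ParamsSketch:
    """Hypothetical stand-in, NOT the tfaip class: mirrors a few documented fields."""
    best_model_name: str = "best"
    frequency: int = 1
    n_to_go: int = -1
    lower_threshold: float = -1e100
    upper_threshold: float = 1e100
    mode: Optional[str] = None
    monitor: Optional[str] = None
    n: int = 1


# Assumed example values; the real semantics of these fields are defined by tfaip.
params = ParamsSketch(monitor="val_loss", mode="min", n=5)

encoded = json.dumps(asdict(params))            # comparable in spirit to to_json()
restored = ParamsSketch(**json.loads(encoded))  # comparable in spirit to from_json()
```

Fields that are not passed explicitly keep their defaults through the round trip, so a stored configuration can be restored to an equal object.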

Definition of the EarlyStoppingCallback

class tfaip.trainer.callbacks.earlystopping.callback.EarlyStoppingCallback(scenario: ScenarioBase, trainer_params: TrainerParams)

Bases: tensorflow.python.keras.callbacks.Callback

Callback that implements early stopping and also (always) tracks the best model.

__init__(scenario: ScenarioBase, trainer_params: TrainerParams)

Initialize self. See help(type(self)) for accurate signature.

on_epoch_end(epoch, logs=None)

Called at the end of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For the training epoch, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'acc': 0.7}.
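
The idea of stopping early while always tracking the best model can be sketched with a Keras-style on_epoch_end hook. This is a plain-Python stand-in under stated assumptions, not the tfaip implementation: the class name EarlyStoppingSketch, the patience argument, and the interpretation of monitor and mode as "metric key" and "min/max" are all hypothetical.

```python
from typing import Dict, Optional


class EarlyStoppingSketch:
    """Plain-Python stand-in with a Keras-style hook (not the tfaip class)."""

    def __init__(self, monitor: str, mode: str = "min", patience: int = 3):
        self.monitor = monitor      # assumed: key to read from the logs
        self.mode = mode            # assumed: "min" or "max"
        self.patience = patience    # hypothetical: allowed epochs without improvement
        self.best: Optional[float] = None
        self.best_epoch: Optional[int] = None
        self.epochs_without_improvement = 0
        self.stop_training = False

    def _improved(self, value: float) -> bool:
        if self.best is None:
            return True
        return value < self.best if self.mode == "min" else value > self.best

    def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> None:
        value = (logs or {}).get(self.monitor)
        if value is None:
            return  # monitored metric missing this epoch, nothing to compare
        if self._improved(value):
            # A real callback would export the best model here.
            self.best, self.best_epoch = value, epoch
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
            if self.epochs_without_improvement >= self.patience:
                self.stop_training = True


cb = EarlyStoppingSketch(monitor="val_loss", mode="min", patience=2)
for epoch, loss in enumerate([0.9, 0.5, 0.6, 0.7, 0.4]):
    cb.on_epoch_end(epoch, {"val_loss": loss})
    if cb.stop_training:
        break
```

In this run the loop stops after two epochs without improvement, so the final value 0.4 is never seen; the best value is tracked regardless of whether stopping is triggered.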

BenchmarkCallback

Definition of the BenchmarkCallback

class tfaip.trainer.callbacks.benchmark_callback.RunningAverage

Bases: object

Compute the running average value

__init__()

Initialize self. See help(type(self)) for accurate signature.

add(v)
reset()
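
A running average with add and reset can be sketched as follows. This is a hypothetical re-implementation for illustration only; the value property and the behaviour for an empty average are assumptions, since only add(v) and reset() are documented here.

```python
class RunningAverageSketch:
    """Hypothetical re-implementation: mean of all values added since the last reset."""

    def __init__(self):
        self.reset()

    def add(self, v: float) -> None:
        self.total += v
        self.count += 1

    def reset(self) -> None:
        self.total = 0.0
        self.count = 0

    @property
    def value(self) -> float:
        # Hypothetical accessor; returns 0.0 when nothing has been added yet.
        return self.total / self.count if self.count else 0.0


avg = RunningAverageSketch()
for seconds in (0.25, 0.5, 0.75):
    avg.add(seconds)
mean_before_reset = avg.value
avg.reset()
```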
class tfaip.trainer.callbacks.benchmark_callback.BenchmarkCallback

Bases: tensorflow.python.keras.callbacks.Callback

The BenchmarkCallback traces the training and validation times per batch, per epoch, and in total.

__init__()

Initialize self. See help(type(self)) for accurate signature.

print()
on_train_begin(logs=None)

Called at the beginning of training.

Subclasses should override for any actions to run.

Parameters

logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.

on_train_end(logs=None)

Called at the end of training.

Subclasses should override for any actions to run.

Parameters

logs – Dict. Currently the output of the last call to on_epoch_end() is passed to this argument for this method but that may change in the future.

on_epoch_begin(epoch, logs=None)

Called at the start of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.

on_epoch_end(epoch, logs=None)

Called at the end of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For the training epoch, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'acc': 0.7}.

on_train_batch_begin(batch, logs=None)

Called at the beginning of a training batch in fit methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Parameters
  • batch – Integer, index of batch within the current epoch.

  • logs – Dict, contains the return value of model.train_step. Typically, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'accuracy': 0.7}.

on_train_batch_end(batch, logs=None)

Called at the end of a training batch in fit methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Parameters
  • batch – Integer, index of batch within the current epoch.

  • logs – Dict. Aggregated metric results up until this batch.

on_test_begin(logs=None)

Called at the beginning of evaluation or validation.

Subclasses should override for any actions to run.

Parameters

logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.

on_test_end(logs=None)

Called at the end of evaluation or validation.

Subclasses should override for any actions to run.

Parameters

logs – Dict. Currently the output of the last call to on_test_batch_end() is passed to this argument for this method but that may change in the future.
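
The timing performed through the begin/end hooks of the BenchmarkCallback can be sketched in plain Python. The class BenchmarkSketch and its injectable clock argument are hypothetical; the clock is replaced by fake timestamps so the example is deterministic, whereas a real run would use the default time.perf_counter.

```python
import time
from typing import Callable, List


class BenchmarkSketch:
    """Plain-Python stand-in that times batches and epochs via Keras-style hooks."""

    def __init__(self, clock: Callable[[], float] = time.perf_counter):
        self.clock = clock  # injectable so the example below is deterministic
        self.batch_times: List[float] = []
        self.epoch_times: List[float] = []
        self._batch_start = 0.0
        self._epoch_start = 0.0

    def on_epoch_begin(self, epoch, logs=None):
        self._epoch_start = self.clock()

    def on_epoch_end(self, epoch, logs=None):
        self.epoch_times.append(self.clock() - self._epoch_start)

    def on_train_batch_begin(self, batch, logs=None):
        self._batch_start = self.clock()

    def on_train_batch_end(self, batch, logs=None):
        self.batch_times.append(self.clock() - self._batch_start)


ticks = iter([0.0, 1.0, 3.0, 3.0, 6.0, 7.0])  # fake timestamps in seconds
bench = BenchmarkSketch(clock=lambda: next(ticks))
bench.on_epoch_begin(0)                  # reads t = 0
for batch in range(2):
    bench.on_train_batch_begin(batch)    # reads t = 1, then t = 3
    bench.on_train_batch_end(batch)      # reads t = 3, then t = 6
bench.on_epoch_end(0)                    # reads t = 7
```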

EMACallback

Definition of the EMACallback

class tfaip.trainer.callbacks.ema_callback.EMACallback(optimizer: tfaip.trainer.optimizer.weights_moving_average.WeightsMovingAverage)

Bases: tensorflow.python.keras.callbacks.Callback

The EMACallback swaps the model weights between their EMA and non-EMA versions, which is required for validation and export.

For example, at the beginning of testing the EMA weights are loaded, and at the end the original weights are restored. Similarly, at the end of an epoch the EMA weights are loaded to export the prediction model, and at the end of each epoch the weights are reset to the actual weights.
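
The swap around testing can be sketched with a plain list standing in for the model weights. This is a minimal illustration under stated assumptions, not the tfaip or WeightsMovingAverage implementation: the class EmaSwapSketch, its update method, and the decay value are hypothetical.

```python
from typing import List


class EmaSwapSketch:
    """Hypothetical stand-in: keeps EMA shadows and swaps them with the live weights."""

    def __init__(self, weights: List[float], decay: float = 0.5):
        self.weights = weights          # the "model" weights, mutated in place
        self.decay = decay              # assumed EMA decay factor
        self.shadow = list(weights)     # EMA copy of the weights
        self._backup: List[float] = []

    def update(self) -> None:
        # Standard EMA update: shadow = decay * shadow + (1 - decay) * weight
        self.shadow = [self.decay * s + (1 - self.decay) * w
                       for s, w in zip(self.shadow, self.weights)]

    def on_test_begin(self, logs=None) -> None:
        self._backup = list(self.weights)
        self.weights[:] = self.shadow   # evaluate with the EMA weights

    def on_test_end(self, logs=None) -> None:
        self.weights[:] = self._backup  # restore the original training weights


weights = [1.0]
ema = EmaSwapSketch(weights, decay=0.5)
weights[0] = 3.0        # pretend an optimizer step changed the weight
ema.update()            # shadow becomes 0.5 * 1.0 + 0.5 * 3.0 = 2.0
ema.on_test_begin()
during_test = list(weights)
ema.on_test_end()
```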

__init__(optimizer: tfaip.trainer.optimizer.weights_moving_average.WeightsMovingAverage)

Initialize self. See help(type(self)) for accurate signature.

on_test_begin(logs=None)

Called at the beginning of evaluation or validation.

Subclasses should override for any actions to run.

Parameters

logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.

on_test_end(logs=None)

Called at the end of evaluation or validation.

Subclasses should override for any actions to run.

Parameters

logs – Dict. Currently the output of the last call to on_test_batch_end() is passed to this argument for this method but that may change in the future.

on_epoch_begin(epoch, logs=None)

Called at the start of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.

on_epoch_end(epoch, logs=None)

Called at the end of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For the training epoch, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'acc': 0.7}.

ExtractLogsCallback

Definition of the ExtractLogsCallback

class tfaip.trainer.callbacks.extract_logs.ExtractLogsCallback(tensorboard_data_handler: tfaip.trainer.callbacks.tensor_board_data_handler.TensorBoardDataHandler)

Bases: tensorflow.python.keras.callbacks.Callback

This callback is a utility to extract variables from the logs that cannot be handled by all callbacks.

The actual use case is logging custom data (e.g., bytes or images) to TensorBoard. The values are added to the logs because they are “Metrics” (this is a bit hacky…), but they must be removed from the logs immediately, since they are not “real” logs to be displayed and are only used by the TensorBoardCallback.

To this end, all such logs are extracted and stored in a separate data structure, which is cleared at the beginning of training. The TensorBoardCallback then has access to the extracted logs.
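
The extraction step can be sketched as moving selected entries out of the logs dict into a separate store. This is a plain-Python stand-in, not the tfaip class: the rule that TensorBoard-only data is recognised by its bytes type, the way prefix is applied, and the attribute name extracted are assumptions.

```python
from typing import Any, Dict


class ExtractLogsSketch:
    """Hypothetical stand-in: moves non-scalar entries out of Keras-style logs."""

    def __init__(self):
        self.extracted: Dict[str, Any] = {}

    def on_train_begin(self, logs=None):
        self.extracted.clear()  # start every training run with an empty store

    def extract(self, logs: Dict[str, Any], prefix: str = "") -> None:
        # Assumption: TensorBoard-only data is recognised by its bytes type.
        keys = [k for k, v in logs.items() if isinstance(v, bytes)]
        for key in keys:
            self.extracted[prefix + key] = logs.pop(key)


extractor = ExtractLogsSketch()
extractor.on_train_begin()
logs = {"loss": 0.2, "pr_curve": b"\x00\x01"}
extractor.extract(logs, prefix="val_")
```

After the call, logs only contains entries that every callback can display, while the removed data remains available to whichever component holds a reference to the extractor.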

__init__(tensorboard_data_handler: tfaip.trainer.callbacks.tensor_board_data_handler.TensorBoardDataHandler)

Initialize self. See help(type(self)) for accurate signature.

on_train_begin(logs=None)

Called at the beginning of training.

Subclasses should override for any actions to run.

Parameters

logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.

on_epoch_begin(epoch, logs=None)

Called at the start of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.

on_epoch_end(epoch, logs=None)

Called at the end of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For the training epoch, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'acc': 0.7}.

on_train_batch_end(batch, logs=None)

Called at the end of a training batch in fit methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Parameters
  • batch – Integer, index of batch within the current epoch.

  • logs – Dict. Aggregated metric results up until this batch.

on_predict_batch_end(batch, logs=None)

Called at the end of a batch in predict methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Parameters
  • batch – Integer, index of batch within the current epoch.

  • logs – Dict. Aggregated metric results up until this batch.

on_test_batch_end(batch, logs=None)

Called at the end of a batch in evaluate methods.

Also called at the end of a validation batch in the fit methods, if validation data is provided.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Parameters
  • batch – Integer, index of batch within the current epoch.

  • logs – Dict. Aggregated metric results up until this batch.

extract(logs, prefix='')

FixLogLabels

Definition of the FixLogLabelsCallback

class tfaip.trainer.callbacks.fix_logs_labels.FixLogLabelsCallback

Bases: tensorflow.python.keras.callbacks.Callback

By default, TensorFlow labels metrics (metric.name) by a function's name, even if they were given a correct metric name. The same holds for the names of the losses. This callback stores the original, correct names and renames the keys of the logs accordingly by calling fix.
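
Renaming log keys from a stored mapping can be sketched as follows. The class FixLogLabelsSketch, the mapping passed to its constructor, and the metric names are hypothetical; the real callback collects the original names itself, which is not reproduced here.

```python
from typing import Dict


class FixLogLabelsSketch:
    """Hypothetical stand-in: renames log keys using a stored name mapping."""

    def __init__(self, original_names: Dict[str, str]):
        # Maps the label found in the logs to the intended name.
        self.original_names = original_names

    def fix(self, logs: dict) -> None:
        for wrong, right in self.original_names.items():
            for prefix in ("", "val_"):  # validation keys carry the val_ prefix
                if prefix + wrong in logs:
                    logs[prefix + right] = logs.pop(prefix + wrong)


fixer = FixLogLabelsSketch({"my_metric_fn": "accuracy"})  # assumed example names
logs = {"loss": 0.2, "my_metric_fn": 0.7, "val_my_metric_fn": 0.6}
fixer.fix(logs)
```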

__init__()

Initialize self. See help(type(self)) for accurate signature.

on_train_begin(logs=None)

Called at the beginning of training.

Subclasses should override for any actions to run.

Parameters

logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.

on_epoch_end(epoch, logs=None)

Called at the end of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For the training epoch, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'acc': 0.7}.

on_train_batch_end(batch, logs=None)

Called at the end of a training batch in fit methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Parameters
  • batch – Integer, index of batch within the current epoch.

  • logs – Dict. Aggregated metric results up until this batch.

on_predict_batch_end(batch, logs=None)

Called at the end of a batch in predict methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Parameters
  • batch – Integer, index of batch within the current epoch.

  • logs – Dict. Aggregated metric results up until this batch.

on_test_batch_end(batch, logs=None)

Called at the end of a batch in evaluate methods.

Also called at the end of a validation batch in the fit methods, if validation data is provided.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Parameters
  • batch – Integer, index of batch within the current epoch.

  • logs – Dict. Aggregated metric results up until this batch.

fix(logs: dict)

LAVCallback

Definition of the LAVCallback

class tfaip.trainer.callbacks.lav_callback.LAVCallback(trainer_params: TrainerParams, scenario: ScenarioBase, extract_logs_cb: ExtractLogsCallback)

Bases: tensorflow.python.keras.callbacks.Callback

This callback runs LAV at the end of an epoch.

All output metrics of LAV are added to the logs (with the prefix lav_) and can thus be accessed in other callbacks. LAV results are therefore also added to TensorBoard (with a custom LAV handler).
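
Adding prefixed results to the logs can be sketched with a small helper. The function add_lav_results and the run_lav callable are hypothetical placeholders for whatever produces the LAV metrics; only the lav_ prefix is taken from the description above.

```python
from typing import Callable, Dict


def add_lav_results(logs: Dict[str, float],
                    run_lav: Callable[[], Dict[str, float]]) -> None:
    """Merge the metrics returned by run_lav (hypothetical callable) into logs in place."""
    for name, value in run_lav().items():
        logs["lav_" + name] = value


logs = {"loss": 0.2}
add_lav_results(logs, run_lav=lambda: {"accuracy": 0.9})  # fake evaluation result
```

Because the logs dict is modified in place, callbacks that run later in the same epoch see the lav_ entries next to the regular metrics.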

__init__(trainer_params: TrainerParams, scenario: ScenarioBase, extract_logs_cb: ExtractLogsCallback)

Initialize self. See help(type(self)) for accurate signature.

on_epoch_begin(epoch, logs=None)

Called at the start of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.

on_epoch_end(epoch, logs=None)

Called at the end of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For the training epoch, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'acc': 0.7}.

LoggerCallback

Definition of the LoggerCallback

class tfaip.trainer.callbacks.logger_callback.LoggerCallback

Bases: tensorflow.python.keras.callbacks.Callback

The logger callback prints useful information about the training process:

  • Log at the end of an epoch

  • Write the current epoch

This is required for train.log, to which the progress bar, and thus the metrics, are not written.

on_epoch_begin(epoch, logs=None)

Called at the start of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.

on_epoch_end(epoch, logs=None)

Called at the end of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For the training epoch, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'acc': 0.7}.

ProgbarCallback

Definition of the TFAIPProgbarLogger which extends the default keras ProgbarLogger

class tfaip.trainer.callbacks.progbar.TFAIPProgbarLogger(delta_time=5, **kwargs)

Bases: tensorflow.python.keras.callbacks.ProgbarLogger

Callback to render the progress bar during training. This implementation of the default ProgbarLogger adds an additional mode (self.verbose == 2), in which, instead of a progress bar, the output is logged every delta_time seconds (default 5).
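
Time-throttled logging of this kind can be sketched in plain Python. ThrottledLoggerSketch is a hypothetical stand-in, not the TFAIPProgbarLogger: it collects lines instead of printing them, and its clock argument is an assumption that lets the example run with fake timestamps.

```python
import time
from typing import Callable, List, Optional


class ThrottledLoggerSketch:
    """Hypothetical stand-in: emits at most one message per delta_time seconds."""

    def __init__(self, delta_time: float = 5,
                 clock: Callable[[], float] = time.monotonic):
        self.delta_time = delta_time
        self.clock = clock
        self.lines: List[str] = []          # collected instead of printed
        self._last: Optional[float] = None  # time of the last emitted message

    def on_train_batch_end(self, batch, logs=None):
        now = self.clock()
        if self._last is None or now - self._last >= self.delta_time:
            self._last = now
            self.lines.append(f"batch {batch}: {logs}")


fake_time = iter([0, 2, 4, 6, 11])  # seconds at which five batches finish
logger = ThrottledLoggerSketch(delta_time=5, clock=lambda: next(fake_time))
for batch in range(5):
    logger.on_train_batch_end(batch, {"loss": 0.1})
```

With these timestamps, only the batches finishing at 0, 6, and 11 seconds produce a line; the others fall inside the 5-second window.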

__init__(delta_time=5, **kwargs)

Initialize self. See help(type(self)) for accurate signature.

on_epoch_begin(epoch, logs=None)

Called at the start of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.

TensorBoardCallback

Definition of the TensorBoardCallback

class tfaip.trainer.callbacks.tensor_board_callback.TensorBoardCallback(*args, **kwargs)

Bases: tensorflow.python.keras.callbacks.TensorBoard

Custom implementation of the Keras TensorBoard callback that provides additional functionality:

  • Logging of LAV (custom LAV writer)

  • Adding of the learning rate

  • Applying a TensorBoardDataHandler on every log output to handle custom tensorboard data (e.g. PR-Curves or images)

__init__(log_dir, steps_per_epoch, extracted_logs_cb: ExtractLogsCallback, data_handler: tfaip.trainer.callbacks.tensor_board_data_handler.TensorBoardDataHandler, reset=False, profile=0, **kwargs)

Initialize self. See help(type(self)) for accurate signature.

on_epoch_begin(epoch, logs=None)

Called at the start of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.

on_epoch_end(epoch, logs=None)

Runs metrics and histogram summaries at epoch end.

TensorBoardDataHandlerCallback

Definition of the TensorBoardDataHandler

class tfaip.trainer.callbacks.tensor_board_data_handler.TensorBoardDataHandler

Bases: object

The TensorBoardDataHandler allows customizing how arbitrary data is written to TensorBoard.

Use case: Writing image (see tfaip.scenario.tutorial.full)
  • Add the raw image data (e.g. weights) as output of the model by implementing _outputs_for_tensorboard

  • Overwrite handle to write the adapted data to TensorBoard (e.g. tf.summary.image())

Use case: PR Curve
  • The PR-Curve is a Metric of bytes.

  • Add the name of the metric to _tensorboard_only_metrics (to mark this metric as TensorBoard-only). Note that this step is optional if the type of the metric is bytes.

  • Overwrite handle to write the actual data with the raw tensorboard writer
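
The two use cases above share one routing decision: is a value TensorBoard-only, and if so, how is it written? A minimal sketch of that decision follows. DataHandlerSketch is a hypothetical stand-in that records what it would write instead of calling a TensorBoard writer; the attribute name tensorboard_only_metrics, the metric names, and the "train/" tag prefix are assumptions.

```python
from typing import Any, List, Tuple


class DataHandlerSketch:
    """Hypothetical stand-in for a data handler; records instead of writing."""

    def __init__(self):
        self.tensorboard_only_metrics: List[str] = []  # names routed only to TensorBoard
        self.written: List[Tuple[str, Any, int]] = []

    def is_tensorboard_only(self, key: str, value: Any) -> bool:
        # bytes values are always TensorBoard-only; other names must be listed.
        return isinstance(value, bytes) or key in self.tensorboard_only_metrics

    def handle(self, name: str, name_for_tb: str, value: Any, step: int) -> None:
        # A real handler would call a TensorBoard writer here.
        self.written.append((name_for_tb, value, step))


handler = DataHandlerSketch()
handler.tensorboard_only_metrics.append("attention_image")  # assumed metric name
logs = {"loss": 0.2, "pr_curve": b"\x01", "attention_image": [[0.0, 1.0]]}
for key, value in logs.items():
    if handler.is_tensorboard_only(key, value):
        handler.handle(key, "train/" + key, value, step=3)
```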

__init__()

Initialize self. See help(type(self)) for accurate signature.

setup(inputs, outputs)Dict[str, Union[tensorflow.python.framework.ops.Tensor, tensorflow.python.keras.engine.keras_tensor.KerasTensor]]
handle(name, name_for_tb, value, step)
is_tensorboard_only(key: str, value: Union[tensorflow.python.framework.ops.Tensor, tensorflow.python.keras.engine.keras_tensor.KerasTensor])

TensorflowFixCallback

Implementation of a TensorflowFix

class tfaip.trainer.callbacks.tensorflow_fix.TensorflowFix

Bases: tensorflow.python.keras.callbacks.Callback

Fix for a weird TensorFlow bug. Remove this once the issue is closed or fixed…

See https://github.com/tensorflow/tensorflow/issues/42872

__init__()

Initialize self. See help(type(self)) for accurate signature.

on_train_begin(logs=None)

Called at the beginning of training.

Subclasses should override for any actions to run.

Parameters

logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.

on_train_batch_end(batch, logs=None)

Called at the end of a training batch in fit methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Parameters
  • batch – Integer, index of batch within the current epoch.

  • logs – Dict. Aggregated metric results up until this batch.

TrainParamsLoggerCallback

Definition of the TrainParamsLogger

class tfaip.trainer.callbacks.train_params_logger.TrainerCheckpointsCallback(train_params, save_freq=None, store_weights=True, store_params=True)

Bases: tensorflow.python.keras.callbacks.ModelCheckpoint

Callback to store the current state of the trainer params and the current training model with all of its weights which is required for resuming the training.

This is realized by reimplementing some of the methods of the Keras ModelCheckpoint callback, which is the base class.
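
Storing the trainer parameters at the end of each epoch, so that training can be resumed, can be sketched as follows. ParamsCheckpointSketch is a plain-Python stand-in, not a ModelCheckpoint subclass; the dict-based parameters, the current_epoch field, and the file name trainer_params.json are assumptions made for illustration.

```python
import json
import os
import tempfile


class ParamsCheckpointSketch:
    """Hypothetical stand-in: writes the trainer parameters to disk after each epoch."""

    def __init__(self, train_params: dict, checkpoint_dir: str):
        self.train_params = train_params
        self.checkpoint_dir = checkpoint_dir

    def on_epoch_end(self, epoch, logs=None):
        self.train_params["current_epoch"] = epoch + 1  # assumed bookkeeping field
        path = os.path.join(self.checkpoint_dir, "trainer_params.json")  # assumed name
        with open(path, "w") as f:
            json.dump(self.train_params, f)
        # A real callback would also save the model weights here.


with tempfile.TemporaryDirectory() as tmp:
    cb = ParamsCheckpointSketch({"epochs": 10, "current_epoch": 0}, tmp)
    cb.on_epoch_end(0)
    with open(os.path.join(tmp, "trainer_params.json")) as f:
        stored = json.load(f)
```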

__init__(train_params, save_freq=None, store_weights=True, store_params=True)

Initialize self. See help(type(self)) for accurate signature.

on_epoch_end(epoch, logs=None)

Called at the end of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For the training epoch, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'acc': 0.7}.