tfaip.trainer

Trainer

Definition of the Trainer

tfaip.trainer.trainer.write_scalar_summaries(logs, step)
tfaip.trainer.trainer.to_placeholder(ts)
class tfaip.trainer.trainer.Trainer(params: tfaip.trainer.trainer.TTrainerParams, scenario: tfaip.scenario.scenariobase.ScenarioBase, restore=False)

Bases: Generic[tfaip.trainer.trainer.TTrainerParams], abc.ABC

The Trainer class is typically identical for all scenarios. Its purpose is to set up the training callbacks and to handle warm-starting and restarting. The training loop itself is wrapped in ScenarioBase and amounts to a call to keras.Model.fit.

classmethod params_cls() Type[tfaip.trainer.trainer.TTrainerParams]
static parse_trainer_params(d: Union[str, dict]) Tuple[tfaip.trainer.trainer.TTrainerParams, Type[tfaip.scenario.scenariobase.ScenarioBase]]
classmethod restore_trainer(checkpoint: Union[str, dict]) tfaip.trainer.trainer.Trainer
__init__(params: tfaip.trainer.trainer.TTrainerParams, scenario: tfaip.scenario.scenariobase.ScenarioBase, restore=False)
property scenario
property data: tfaip.data.data.DataBase
property params
setup_data()
setup_model()
train(**kwargs)
create_train_params_logger_callback(store_weights, store_params)
setup_callbacks(optimizer, callbacks=None)
setup_steps_per_epoch()
fit()
create_warmstarter(**kwargs) tfaip.trainer.warmstart.warmstarter.WarmStarter
export_trainable_pb()

Save the current graph (TF1 style) as a .pb file.

Usage:
  • use the train_ema operation to train (including the ema weight update); train_op trains without ema.

  • save the model (incl. ema weights) using save_basic/Const to set the path, and use “save_basic/control_dependency” to save all weights

  • to restore all weights, load with “save_basic/restore_all”

  • to copy the ema weights into the training weights for prediction, call “save_ema/restore_all”

TrainerParams

Definition of the TrainerParams and the TrainerPipelineParamsBase

class tfaip.trainer.params.TrainerPipelines(train: tfaip.data.databaseparams.DataPipelineParams = <factory>, val: tfaip.data.databaseparams.DataPipelineParams = <factory>)

Bases: object

train: tfaip.data.databaseparams.DataPipelineParams
val: tfaip.data.databaseparams.DataPipelineParams
__init__(train: tfaip.data.databaseparams.DataPipelineParams = <factory>, val: tfaip.data.databaseparams.DataPipelineParams = <factory>) None
classmethod from_dict(kvs: Optional[Union[dict, list, str, int, float, bool]], *, infer_missing=False) dataclasses_json.api.A
classmethod from_json(s: Union[str, bytes, bytearray], *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw) dataclasses_json.api.A
classmethod schema(*, infer_missing: bool = False, only=None, exclude=(), many: bool = False, context=None, load_only=(), dump_only=(), partial: bool = False, unknown=None) dataclasses_json.mm.SchemaF[dataclasses_json.mm.A]
to_dict(encode_json=False, include_cls=True) Dict[str, Optional[Union[dict, list, str, int, float, bool]]]
to_json(*, skipkeys: bool = False, ensure_ascii: bool = True, check_circular: bool = True, allow_nan: bool = True, indent: Optional[Union[int, str]] = None, separators: Optional[Tuple[str, str]] = None, default: Optional[Callable] = None, sort_keys: bool = False, **kw) str
class tfaip.trainer.params.TrainerPipelineParamsBaseMeta(*args, **kwargs)

Bases: tfaip.util.generic_meta.CollectGenericTypes

class tfaip.trainer.params.TrainerPipelineParamsBase(setup: tfaip.trainer.params.TrainerPipelines = <factory>)

Bases: Generic[tfaip.trainer.params.TDataGeneratorTrain, tfaip.trainer.params.TDataGeneratorVal], abc.ABC

Definition of the training pipeline inputs.

Specify the DataGeneratorParams for training and validation as Generics.

setup: tfaip.trainer.params.TrainerPipelines
classmethod train_cls()
classmethod val_cls()
abstract train_gen() tfaip.trainer.params.TDataGeneratorTrain
abstract val_gen() Optional[tfaip.trainer.params.TDataGeneratorVal]
lav_gen() Iterable[tfaip.trainer.params.TDataGeneratorVal]
train_data(data: DataBase) DataPipeline
val_data(data: DataBase) DataPipeline
lav_data(data: DataBase) Iterable[DataPipeline]
__init__(setup: tfaip.trainer.params.TrainerPipelines = <factory>) None
classmethod from_dict(kvs: Optional[Union[dict, list, str, int, float, bool]], *, infer_missing=False) dataclasses_json.api.A
classmethod from_json(s: Union[str, bytes, bytearray], *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw) dataclasses_json.api.A
classmethod schema(*, infer_missing: bool = False, only=None, exclude=(), many: bool = False, context=None, load_only=(), dump_only=(), partial: bool = False, unknown=None) dataclasses_json.mm.SchemaF[dataclasses_json.mm.A]
to_dict(encode_json=False, include_cls=True) Dict[str, Optional[Union[dict, list, str, int, float, bool]]]
to_json(*, skipkeys: bool = False, ensure_ascii: bool = True, check_circular: bool = True, allow_nan: bool = True, indent: Optional[Union[int, str]] = None, separators: Optional[Tuple[str, str]] = None, default: Optional[Callable] = None, sort_keys: bool = False, **kw) str
class tfaip.trainer.params.TrainerPipelineParamsMeta(*args, **kwargs)

Bases: tfaip.util.generic_meta.ReplaceDefaultDataClassFieldsMeta, tfaip.trainer.params.TrainerPipelineParamsBaseMeta

class tfaip.trainer.params.TrainerPipelineParams(*args, **kwds)

Bases: tfaip.trainer.params.TrainerPipelineParamsBase[tfaip.trainer.params.TDataGeneratorTrain, tfaip.trainer.params.TDataGeneratorVal]

train: tfaip.trainer.params.TDataGeneratorTrain
val: tfaip.trainer.params.TDataGeneratorVal
train_gen() tfaip.trainer.params.TDataGeneratorTrain
val_gen() tfaip.trainer.params.TDataGeneratorVal
__init__(setup: tfaip.trainer.params.TrainerPipelines = <factory>, train: tfaip.trainer.params.TDataGeneratorTrain = <factory>, val: tfaip.trainer.params.TDataGeneratorVal = <factory>) None
classmethod from_dict(kvs: Optional[Union[dict, list, str, int, float, bool]], *, infer_missing=False) dataclasses_json.api.A
classmethod from_json(s: Union[str, bytes, bytearray], *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw) dataclasses_json.api.A
classmethod schema(*, infer_missing: bool = False, only=None, exclude=(), many: bool = False, context=None, load_only=(), dump_only=(), partial: bool = False, unknown=None) dataclasses_json.mm.SchemaF[dataclasses_json.mm.A]
to_dict(encode_json=False, include_cls=True) Dict[str, Optional[Union[dict, list, str, int, float, bool]]]
to_json(*, skipkeys: bool = False, ensure_ascii: bool = True, check_circular: bool = True, allow_nan: bool = True, indent: Optional[Union[int, str]] = None, separators: Optional[Tuple[str, str]] = None, default: Optional[Callable] = None, sort_keys: bool = False, **kw) str
class tfaip.trainer.params.TrainerParamsMeta(*args, **kwargs)

Bases: tfaip.util.generic_meta.ReplaceDefaultDataClassFieldsMeta

Meta class for the trainer params

The class automatically replaces the defaults of the scenario and gen fields with the actual classes passed to the Generic.

class tfaip.trainer.params.TrainerParams(epochs: int = 100, current_epoch: int = 0, samples_per_epoch: int = -1, scale_epoch_size: float = 1, train_accum_steps: int = 1, progress_bar_mode: int = 1, progbar_delta_time: float = 5, tf_cpp_min_log_level: int = 2, force_eager: bool = False, skip_model_load_test: bool = False, export_training_graph_path: Optional[str] = None, val_every_n: int = 1, lav_every_n: int = 0, lav_silent: bool = True, lav_min_epoch: int = 0, output_dir: Optional[str] = None, write_checkpoints: bool = True, export_best: Optional[bool] = None, export_final: bool = True, no_train_scope: Optional[str] = None, ema_decay: float = 0.0, random_seed: Optional[int] = None, profile: bool = False, profile_steps: List[int] = <factory>, device: tfaip.device.device_config.DeviceConfigParams = <factory>, optimizer: tfaip.trainer.optimizer.optimizers.OptimizerParams = <factory>, learning_rate: tfaip.trainer.scheduler.learningrate_params.LearningRateParams = <factory>, scenario: tfaip.trainer.params.TScenarioParams = <factory>, warmstart: tfaip.trainer.warmstart.warmstart_params.WarmStartParams = <factory>, early_stopping: tfaip.trainer.callbacks.earlystopping.params.EarlyStoppingParams = <factory>, gen: tfaip.trainer.params.TTrainerPipelineParams = <factory>, preload_data: bool = False, preload_data_progress_bar: bool = True, saved_checkpoint_sub_dir: str = '', checkpoint_sub_dir: str = '', checkpoint_save_freq: Union[str, int] = 'epoch')

Bases: Generic[tfaip.trainer.params.TScenarioParams, tfaip.trainer.params.TTrainerPipelineParams], abc.ABC

TrainerParams store the training hyper-parameters, the ScenarioBaseParams, and the TrainerPipelineParams.

epochs: int = 100
current_epoch: int = 0
samples_per_epoch: int = -1
scale_epoch_size: float = 1
train_accum_steps: int = 1
progress_bar_mode: int = 1
progbar_delta_time: float = 5
tf_cpp_min_log_level: int = 2
force_eager: bool = False
skip_model_load_test: bool = False
export_training_graph_path: Optional[str] = None
val_every_n: int = 1
lav_every_n: int = 0
lav_silent: bool = True
lav_min_epoch: int = 0
output_dir: Optional[str] = None
write_checkpoints: bool = True
export_best: Optional[bool] = None
export_final: bool = True
no_train_scope: Optional[str] = None
ema_decay: float = 0.0
random_seed: Optional[int] = None
profile: bool = False
profile_steps: List[int]
device: tfaip.device.device_config.DeviceConfigParams
optimizer: tfaip.trainer.optimizer.optimizers.OptimizerParams
learning_rate: tfaip.trainer.scheduler.learningrate_params.LearningRateParams
scenario: tfaip.trainer.params.TScenarioParams
warmstart: tfaip.trainer.warmstart.warmstart_params.WarmStartParams
early_stopping: tfaip.trainer.callbacks.earlystopping.params.EarlyStoppingParams
gen: tfaip.trainer.params.TTrainerPipelineParams
preload_data: bool = False
preload_data_progress_bar: bool = True
saved_checkpoint_sub_dir: str = ''
checkpoint_sub_dir: str = ''
checkpoint_save_freq: Union[str, int] = 'epoch'
__init__(epochs: int = 100, current_epoch: int = 0, samples_per_epoch: int = -1, scale_epoch_size: float = 1, train_accum_steps: int = 1, progress_bar_mode: int = 1, progbar_delta_time: float = 5, tf_cpp_min_log_level: int = 2, force_eager: bool = False, skip_model_load_test: bool = False, export_training_graph_path: Optional[str] = None, val_every_n: int = 1, lav_every_n: int = 0, lav_silent: bool = True, lav_min_epoch: int = 0, output_dir: Optional[str] = None, write_checkpoints: bool = True, export_best: Optional[bool] = None, export_final: bool = True, no_train_scope: Optional[str] = None, ema_decay: float = 0.0, random_seed: Optional[int] = None, profile: bool = False, profile_steps: List[int] = <factory>, device: tfaip.device.device_config.DeviceConfigParams = <factory>, optimizer: tfaip.trainer.optimizer.optimizers.OptimizerParams = <factory>, learning_rate: tfaip.trainer.scheduler.learningrate_params.LearningRateParams = <factory>, scenario: tfaip.trainer.params.TScenarioParams = <factory>, warmstart: tfaip.trainer.warmstart.warmstart_params.WarmStartParams = <factory>, early_stopping: tfaip.trainer.callbacks.earlystopping.params.EarlyStoppingParams = <factory>, gen: tfaip.trainer.params.TTrainerPipelineParams = <factory>, preload_data: bool = False, preload_data_progress_bar: bool = True, saved_checkpoint_sub_dir: str = '', checkpoint_sub_dir: str = '', checkpoint_save_freq: Union[str, int] = 'epoch') None
classmethod from_dict(kvs: Optional[Union[dict, list, str, int, float, bool]], *, infer_missing=False) dataclasses_json.api.A
classmethod from_json(s: Union[str, bytes, bytearray], *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw) dataclasses_json.api.A
classmethod schema(*, infer_missing: bool = False, only=None, exclude=(), many: bool = False, context=None, load_only=(), dump_only=(), partial: bool = False, unknown=None) dataclasses_json.mm.SchemaF[dataclasses_json.mm.A]
to_dict(encode_json=False, include_cls=True) Dict[str, Optional[Union[dict, list, str, int, float, bool]]]
to_json(*, skipkeys: bool = False, ensure_ascii: bool = True, check_circular: bool = True, allow_nan: bool = True, indent: Optional[Union[int, str]] = None, separators: Optional[Tuple[str, str]] = None, default: Optional[Callable] = None, sort_keys: bool = False, **kw) str