tfaip.trainer.warmstart

WarmStartParams

Definition of the WarmStartParams

class tfaip.trainer.warmstart.warmstart_params.WarmStartParams(model: Optional[str] = None, allow_partial: bool = False, trim_graph_name: bool = True, rename: List[str] = <factory>, add_suffix: str = '', rename_targets: List[str] = <factory>, exclude: Optional[str] = None, include: Optional[str] = None, auto_remove_numbers_for: List[str] = <factory>)

Bases: object

Parameters for warm-starting from a model.

model: Optional[str] = None
allow_partial: bool = False
trim_graph_name: bool = True
rename: List[str]
add_suffix: str = ''
rename_targets: List[str]
exclude: Optional[str] = None
include: Optional[str] = None
auto_remove_numbers_for: List[str]
__init__(model: Optional[str] = None, allow_partial: bool = False, trim_graph_name: bool = True, rename: List[str] = <factory>, add_suffix: str = '', rename_targets: List[str] = <factory>, exclude: Optional[str] = None, include: Optional[str] = None, auto_remove_numbers_for: List[str] = <factory>) → None

Initialize self. See help(type(self)) for accurate signature.

classmethod from_dict(kvs: Optional[Union[dict, list, str, int, float, bool]], *, infer_missing=False) → A
classmethod from_json(s: Union[str, bytes, bytearray], *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw) → A
classmethod schema(*, infer_missing: bool = False, only=None, exclude=(), many: bool = False, context=None, load_only=(), dump_only=(), partial: bool = False, unknown=None) → dataclasses_json.mm.SchemaF[A]
to_dict(encode_json=False) → Dict[str, Optional[Union[dict, list, str, int, float, bool]]]
to_json(*, skipkeys: bool = False, ensure_ascii: bool = True, check_circular: bool = True, allow_nan: bool = True, indent: Optional[Union[int, str]] = None, separators: Optional[Tuple[str, str]] = None, default: Optional[Callable] = None, sort_keys: bool = False, **kw) → str
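The serialization methods above match the `dataclasses_json` mixin, so with tfaip installed a configuration should round-trip via `WarmStartParams.from_json(params.to_json())`. The sketch below illustrates the same idea without tfaip or `dataclasses_json`: `WarmStartParamsMirror` is a hypothetical stdlib stand-in that copies the documented fields and defaults, and the model path is a placeholder.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import List, Optional


@dataclass
class WarmStartParamsMirror:
    """Hypothetical stdlib stand-in mirroring the WarmStartParams fields and defaults."""
    model: Optional[str] = None
    allow_partial: bool = False
    trim_graph_name: bool = True
    rename: List[str] = field(default_factory=list)
    add_suffix: str = ''
    rename_targets: List[str] = field(default_factory=list)
    exclude: Optional[str] = None
    include: Optional[str] = None
    auto_remove_numbers_for: List[str] = field(default_factory=list)


def to_json(p: WarmStartParamsMirror) -> str:
    # Serialize every field to a JSON object, analogous to to_json()
    return json.dumps(asdict(p))


def from_json(s: str) -> WarmStartParamsMirror:
    # Rebuild the dataclass from its JSON form, analogous to from_json()
    return WarmStartParamsMirror(**json.loads(s))


params = WarmStartParamsMirror(model="path/to/pretrained", allow_partial=True)
restored = from_json(to_json(params))
```

Because every field is a plain JSON type (strings, booleans, lists of strings, or `None`), the restored object compares equal to the original.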

WarmStarter

Definition of the WarmStarter

tfaip.trainer.warmstart.warmstarter.longest_common_startstr(strs: List[str]) → str
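Judging by its name and signature, `longest_common_startstr` returns the longest prefix shared by all strings in `strs`. That would be useful, for example, for detecting a common graph-name prefix in variable names, though any link to `trim_graph_name` is an assumption. A minimal sketch of such a function (not the tfaip source):

```python
from typing import List


def longest_common_startstr(strs: List[str]) -> str:
    """Sketch: return the longest prefix shared by all strings ('' for empty input)."""
    if not strs:
        return ''
    prefix = []
    # zip stops at the shortest string, so only positions present in all strings are compared
    for chars in zip(*strs):
        if all(c == chars[0] for c in chars):
            prefix.append(chars[0])
        else:
            break
    return ''.join(prefix)
```

For instance, `longest_common_startstr(["model/dense/kernel:0", "model/dense/bias:0"])` yields `"model/dense/"`.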
class tfaip.trainer.warmstart.warmstarter.WarmStarter(params: tfaip.trainer.warmstart.warmstart_params.WarmStartParams)

Bases: object

The WarmStarter handles loading a pretrained model and applies its weights to the current one.

See WarmStartParams for configuration. Both SavedModels and Checkpoints are supported.
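Several WarmStartParams fields (`rename`, `include`, `exclude`) suggest a name-mapping step between the pretrained weights and the target model. The sketch below shows one plausible form of that step. The hypothetical `map_weight_names` function assumes that each `rename` rule is a string `'FROM->TO'` applied as a substring replacement, and that `include`/`exclude` are regular expressions matched against the source name. These formats are assumptions, not the documented tfaip behavior.

```python
import re
from typing import Dict, List, Optional


def map_weight_names(
    names: List[str],
    rename: List[str],
    include: Optional[str] = None,
    exclude: Optional[str] = None,
) -> Dict[str, str]:
    """Hypothetical sketch: map source weight names to target weight names.

    Assumes rename rules of the form 'FROM->TO' and regex include/exclude filters.
    """
    rules = [tuple(r.split('->', 1)) for r in rename]
    mapping = {}
    for name in names:
        if include is not None and not re.search(include, name):
            continue  # skip names that do not match `include`
        if exclude is not None and re.search(exclude, name):
            continue  # skip names that match `exclude`
        new_name = name
        for src, dst in rules:
            new_name = new_name.replace(src, dst)  # apply each rule in order
        mapping[name] = new_name
    return mapping
```

With `rename=["enc/->encoder/"]` and `exclude="^head/"`, the name `"enc/dense/kernel:0"` maps to `"encoder/dense/kernel:0"`, and `"head/out/kernel:0"` is dropped.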

__init__(params: tfaip.trainer.warmstart.warmstart_params.WarmStartParams)

Initialize self. See help(type(self)) for accurate signature.

warmstart(target_model: tensorflow.python.keras.engine.training.Model, custom_objects=None)
apply_weights(target_model, new_weights) → NoReturn

By default, all weights of the target model are set to the new weights. Override this method to customize how the parameters are set. For example, in ATR, a codec adaption selects only a subset of one weight matrix.

Parameters
  • target_model – Target model

  • new_weights – New weights of the model
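A codec adaption override could look roughly like the sketch below. It assumes that `new_weights` is a list of NumPy arrays in the same order as Keras `Model.get_weights()`. It also assumes that the adapted matrix keeps only the output units (last axis) for the characters of the new codec. `adapt_codec_weights`, `matrix_index`, and `keep_indices` are hypothetical names, not part of tfaip.

```python
from typing import List, Sequence

import numpy as np


def adapt_codec_weights(
    new_weights: List[np.ndarray],
    matrix_index: int,
    keep_indices: Sequence[int],
) -> List[np.ndarray]:
    """Hypothetical helper: keep only selected output units of one weight matrix.

    The pretrained matrix covers the old codec; the target keeps only `keep_indices`.
    """
    adapted = list(new_weights)  # shallow copy; all other arrays are reused unchanged
    w = new_weights[matrix_index]
    # Select along the last axis, i.e. the output units of a Dense kernel
    adapted[matrix_index] = np.take(w, keep_indices, axis=-1)
    return adapted
```

Inside an overridden `apply_weights(self, target_model, new_weights)`, the adapted list could then be passed to Keras `target_model.set_weights(...)`, provided the remaining shapes already match the target model.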