tfaip.trainer.warmstart

WarmStartParams

Definition of the WarmStartParams

class tfaip.trainer.warmstart.warmstart_params.WarmStartParams(model: Optional[str] = None, allow_partial: bool = False, trim_graph_name: bool = True, rename: List[str] = <factory>, add_suffix: str = '', rename_targets: List[str] = <factory>, exclude: Optional[str] = None, include: Optional[str] = None, auto_remove_numbers_for: List[str] = <factory>)

Bases: object

Parameters for warm-starting from a model.

model: Optional[str] = None
allow_partial: bool = False
trim_graph_name: bool = True
rename: List[str]
add_suffix: str = ''
rename_targets: List[str]
exclude: Optional[str] = None
include: Optional[str] = None
auto_remove_numbers_for: List[str]
create(**kwargs) → WarmStarter
__init__(model: Optional[str] = None, allow_partial: bool = False, trim_graph_name: bool = True, rename: List[str] = <factory>, add_suffix: str = '', rename_targets: List[str] = <factory>, exclude: Optional[str] = None, include: Optional[str] = None, auto_remove_numbers_for: List[str] = <factory>) → None
classmethod from_dict(kvs: Optional[Union[dict, list, str, int, float, bool]], *, infer_missing=False) → dataclasses_json.api.A
classmethod from_json(s: Union[str, bytes, bytearray], *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw) → dataclasses_json.api.A
classmethod schema(*, infer_missing: bool = False, only=None, exclude=(), many: bool = False, context=None, load_only=(), dump_only=(), partial: bool = False, unknown=None) → dataclasses_json.mm.SchemaF[dataclasses_json.mm.A]
to_dict(encode_json=False, include_cls=True) → Dict[str, Optional[Union[dict, list, str, int, float, bool]]]
to_json(*, skipkeys: bool = False, ensure_ascii: bool = True, check_circular: bool = True, allow_nan: bool = True, indent: Optional[Union[int, str]] = None, separators: Optional[Tuple[str, str]] = None, default: Optional[Callable] = None, sort_keys: bool = False, **kw) → str
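In the signature, `<factory>` marks a field whose default comes from a `default_factory`, so every instance receives its own list. The stdlib-only sketch below illustrates this. It uses a hypothetical stand-in class, `WarmStartParamsSketch`, that mirrors the documented field names and scalar defaults. It assumes each `<factory>` yields an empty list; the actual factory defaults are not shown in this reference and may differ. The model path and the rule string are placeholders.

```python
from dataclasses import asdict, dataclass, field
from typing import List, Optional


@dataclass
class WarmStartParamsSketch:
    # Hypothetical stand-in mirroring the documented WarmStartParams fields.
    # Assumption: every <factory> default is an empty list.
    model: Optional[str] = None
    allow_partial: bool = False
    trim_graph_name: bool = True
    rename: List[str] = field(default_factory=list)
    add_suffix: str = ""
    rename_targets: List[str] = field(default_factory=list)
    exclude: Optional[str] = None
    include: Optional[str] = None
    auto_remove_numbers_for: List[str] = field(default_factory=list)


# Warm-start from a placeholder model path and tolerate unmatched weights.
params = WarmStartParamsSketch(model="path/to/pretrained_model", allow_partial=True)

# Round trip through a plain dict (asdict copies the lists recursively).
as_dict = asdict(params)
restored = WarmStartParamsSketch(**as_dict)

# default_factory gives each instance its own list, so this does not touch params.
other = WarmStartParamsSketch()
other.rename.append("example_rule")  # placeholder, not a real rule format
```

In the documented class, the `to_dict`/`from_dict` (or `to_json`/`from_json`) methods take the role that `asdict` plays here. Calling `create()` on the parameters builds the corresponding `WarmStarter`.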

WarmStarter

Definition of the WarmStarter

tfaip.trainer.warmstart.warmstarter.longest_common_startstr(strs: List[str]) → str
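The name and signature of `longest_common_startstr` suggest that it returns the longest string that all inputs start with, i.e. their common prefix. This reading is an assumption and is not stated in the reference. Under that assumption, a minimal sketch looks like this:

```python
from typing import List


def longest_common_startstr(strs: List[str]) -> str:
    """Return the longest prefix shared by all strings (sketch, assumed semantics)."""
    if not strs:
        return ""
    # The common prefix cannot be longer than the shortest string.
    shortest = min(strs, key=len)
    for i, ch in enumerate(shortest):
        # Stop at the first position where any string differs.
        if any(s[i] != ch for s in strs):
            return shortest[:i]
    return shortest


names = ["model/encoder/kernel:0", "model/encoder/bias:0", "model/decoder/kernel:0"]
print(longest_common_startstr(names))  # model/
```

The standard library's `os.path.commonprefix` computes the same character-wise prefix. One plausible use, in line with the `trim_graph_name` option, is stripping a shared model-name prefix from weight names before they are matched.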
class tfaip.trainer.warmstart.warmstarter.WarmStarter(params: tfaip.trainer.warmstart.warmstarter.TWSP, **kwargs)

Bases: Generic[tfaip.trainer.warmstart.warmstarter.TWSP]

The WarmStarter loads a pretrained model and applies its weights to the current model.

See WarmStartParams for configuration. Both SavedModels and Checkpoints are supported.
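One way to picture the loading-and-applying step is as a name-based lookup between the weights of the loaded model and those of the target model. The sketch below is hypothetical and is not tfaip's implementation. It represents weights as plain `name -> value` dicts instead of Keras variables. It also assumes two things: that `include` and `exclude` are regular expressions matched against weight names, and that `allow_partial=False` requires every target weight to be found.

```python
import re
from typing import Dict, List, Optional


def match_weights(
    loaded: Dict[str, List[float]],
    target_names: List[str],
    include: Optional[str] = None,
    exclude: Optional[str] = None,
    allow_partial: bool = False,
) -> Dict[str, List[float]]:
    """Map target weight names to loaded values (hypothetical sketch)."""
    # Keep only loaded weights that pass the (assumed regex) include/exclude filters.
    candidates = {
        name: value
        for name, value in loaded.items()
        if (include is None or re.search(include, name))
        and (exclude is None or not re.search(exclude, name))
    }
    matched = {name: candidates[name] for name in target_names if name in candidates}
    missing = [name for name in target_names if name not in matched]
    if missing and not allow_partial:
        raise ValueError(f"No pretrained weights for: {missing}")
    return matched


loaded = {"encoder/kernel": [0.1, 0.2], "encoder/bias": [0.0], "head/kernel": [0.5]}
targets = ["encoder/kernel", "encoder/bias", "head/kernel"]

# Skip the old output head and accept that it stays unmatched.
result = match_weights(loaded, targets, exclude=r"^head/", allow_partial=True)
print(sorted(result))  # ['encoder/bias', 'encoder/kernel']
```

Without `allow_partial=True`, the excluded head would leave a target weight unmatched and the sketch would raise an error. The sketch treats partial matching as an explicit opt-in.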

__init__(params: tfaip.trainer.warmstart.warmstarter.TWSP, **kwargs)
property params: tfaip.trainer.warmstart.warmstarter.TWSP
warmstart(target_model: keras.engine.training.Model, custom_objects=None)
apply_weights(target_model, new_weights) → NoReturn

By default, all weights of the target model are set to the new weights. Override this function to customize how the weights are applied. For example, in ATR the codec may be adapted, so that only a subset of one weight matrix should be selected.

Parameters
  • target_model – Target model

  • new_weights – New weights of the model
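For the codec-adaptation case, an override could pass most weights through unchanged and keep only the parts of the output matrix that belong to characters in the new codec. The sketch below is built from hypothetical stand-ins so that it runs on its own: `FakeModel` imitates a Keras model's `set_weights`, `WarmStarterSketch` mimics the documented default, and `CodecWarmStarterSketch` is an illustrative override. It also assumes that weights are nested lists, that the output matrix comes last, and that it has one row per character; `codec_map` is an illustrative argument. None of these names come from tfaip.

```python
from typing import List

Matrix = List[List[float]]


class FakeModel:
    """Hypothetical stand-in for a Keras model: only stores a list of weights."""

    def __init__(self, weights: List[Matrix]):
        self.weights = weights

    def set_weights(self, weights: List[Matrix]) -> None:
        # Store copies so later changes to the inputs do not leak into the model.
        self.weights = [[row[:] for row in w] for w in weights]


class WarmStarterSketch:
    """Hypothetical stand-in mimicking the documented default behavior."""

    def apply_weights(self, target_model: FakeModel, new_weights: List[Matrix]) -> None:
        # Default: set all weights of the target model to the new weights.
        target_model.set_weights(new_weights)


class CodecWarmStarterSketch(WarmStarterSketch):
    def __init__(self, codec_map: List[int]):
        # codec_map[i]: row of the pretrained output matrix for new character i.
        self.codec_map = codec_map

    def apply_weights(self, target_model: FakeModel, new_weights: List[Matrix]) -> None:
        *shared, old_output = new_weights  # assumption: output matrix comes last
        # Keep only the rows of characters that survive in the new codec.
        new_output = [old_output[i] for i in self.codec_map]
        super().apply_weights(target_model, [*shared, new_output])


model = FakeModel([[[0.0]], [[0.0], [0.0]]])
pretrained = [[[1.0]], [[10.0], [20.0], [30.0]]]

# The new codec keeps pretrained characters 2 and 0, in that order.
CodecWarmStarterSketch(codec_map=[2, 0]).apply_weights(model, pretrained)
print(model.weights)  # [[[1.0]], [[30.0], [10.0]]]
```

Delegating to the base implementation after the selection keeps the override small. Only the weight that actually changes shape needs special handling.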