edflow.iterators.tf_trainer module
Summary

Classes:
| TFBaseTrainer | Trainer built on TFHookedModelIterator. |
| TFListTrainer | |
| TFMultiStageTrainer | Adds multistage training to the edflow trainer. |
Reference

class edflow.iterators.tf_trainer.TFBaseTrainer(config, root, model, **kwargs)
Bases: edflow.iterators.tf_iterator.TFHookedModelIterator
Trainer built on TFHookedModelIterator.
__init__(config, root, model, **kwargs)
Constructor.
Parameters:
  model (object) – Model class.
  num_epochs (int) – Number of times to iterate over the data.
  hooks (list) – List containing Hook instances.
  hook_freq (int) – Frequency at which hooks are evaluated.
  bar_position (int) – Used by tqdm to place the progress bars at the right position when using multiple iterators in parallel.
initialize(checkpoint_path=None)
Initialize the model from scratch, or restore it from checkpoint_path if one is given, and keep the restorer around.
step_ops()
Defines the ops that are run at each step.
Returns: The operation run at each step.
class edflow.iterators.tf_trainer.TFListTrainer(config, root, model, **kwargs)
class edflow.iterators.tf_trainer.TFMultiStageTrainer(config, root, model, **kwargs)
Bases: edflow.iterators.tf_trainer.TFBaseTrainer
Adds multistage training to the edflow trainer.
Stages are defined through the config, for example:

    stages:
      1:
        name: pretrain
        end: 10
        losses: []
      2:
        name: retrain
        end: 30
        losses: ["model"]
      3:
        name: train
        losses: ["model"]
The stages are sorted by their key, so it is recommended to keep the simple numeric ordering. In each stage, a set of losses can be specified with the losses: ["loss1", "loss2", …] syntax. Each stage lasts until the step given by its end: num_steps value. Stage ends are evaluated in the order of the stages, so a later stage must have a higher end value than the previous one.
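The stage rules above (sort by key, strictly increasing end values) can be illustrated with a short plain-Python sketch. This is not edflow's implementation: the helpers ordered_stages and active_stage are hypothetical, the config is assumed to be already loaded into a dict, and the sketch assumes that a stage is active while the global step is below its end and that only the last stage may omit end.

```python
from typing import Any, Dict, List, Optional

Stage = Dict[str, Any]

# Hypothetical: the YAML stage config from above, already loaded into a dict.
STAGES: Dict[int, Stage] = {
    1: {"name": "pretrain", "end": 10, "losses": []},
    2: {"name": "retrain", "end": 30, "losses": ["model"]},
    3: {"name": "train", "losses": ["model"]},
}


def ordered_stages(stages: Dict[int, Stage]) -> List[Stage]:
    """Sort stages by key and check that their end values strictly increase."""
    ordered = [stages[key] for key in sorted(stages)]
    previous_end: Optional[int] = None
    for index, stage in enumerate(ordered):
        end = stage.get("end")
        if end is None:
            # Assumption: only the final stage may omit "end"; it then runs
            # until training stops.
            if index != len(ordered) - 1:
                raise ValueError(f"only the last stage may omit 'end': {stage['name']!r}")
            continue
        if previous_end is not None and end <= previous_end:
            raise ValueError(
                f"stage {stage['name']!r} ends at {end}, "
                f"not after the previous end {previous_end}"
            )
        previous_end = end
    return ordered


def active_stage(ordered: List[Stage], global_step: int) -> Stage:
    """Return the first stage whose end lies beyond global_step."""
    # Assumption: a stage is active while global_step < end.
    for stage in ordered:
        end = stage.get("end")
        if end is None or global_step < end:
            return stage
    raise ValueError(f"no stage covers step {global_step}")


if __name__ == "__main__":
    stages = ordered_stages(STAGES)
    for step in (0, 10, 30):
        stage = active_stage(stages, step)
        # Prints the step, the active stage name and its losses.
        print(step, stage["name"], stage["losses"])
```

Because the stages are selected by sorting their keys, integer keys (which YAML produces for 1:, 2:, 3:) keep the intended order; string keys such as "10" and "2" would sort lexicographically, which is one reason to stick with simple numeric keys.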
The model has to implement the edflow.iterators.tf_trainer.TFMultiStageModel interface. See the multistage_trainer example.
__init__(config, root, model, **kwargs)
Constructor.
Parameters:
  model (object) – Model class.
  num_epochs (int) – Number of times to iterate over the data.
  hooks (list) – List containing Hook instances.
  hook_freq (int) – Frequency at which hooks are evaluated.
  bar_position (int) – Used by tqdm to place the progress bars at the right position when using multiple iterators in parallel.