edflow.iterators.tf_trainer module
Summary
Classes:
| TFBaseTrainer | Base trainer built on TFHookedModelIterator. |
| TFListTrainer | |
| TFMultiStageTrainer | Adds multistage training to the edflow trainer. |
Functions:
| make_linear_var | Linear from \((a, \alpha)\) to \((b, \beta)\). |
Reference
edflow.iterators.tf_trainer.make_linear_var(step, start, end, start_value, end_value, clip_min=None, clip_max=None, **kwargs)
Linear from \((a, \alpha)\) to \((b, \beta)\), i.e. \(y = \frac{\beta - \alpha}{b - a} (x - a) + \alpha\).
- Parameters
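step (tf.Tensor) – Current step \(x\).
start (int) – Step \(a\) at which the line starts.
end (int) – Step \(b\) at which the line ends.
start_value (float) – Value \(\alpha\) at step \(a\).
end_value (float) – Value \(\beta\) at step \(b\).
clip_min (int) – Minimum value returned.
clip_max (int) – Maximum value returned.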
- Returns
\(y\)
- Return type
tf.Tensor
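To check the formula numerically, the schedule can be sketched in plain Python. This is a minimal sketch, not edflow's implementation: it works on floats instead of tf.Tensor, the name linear_var is hypothetical, and it assumes that clip_min=None or clip_max=None means "no bound on that side" (the parameter descriptions above do not say what None does).

```python
from typing import Optional


def linear_var(step: float, start: int, end: int,
               start_value: float, end_value: float,
               clip_min: Optional[float] = None,
               clip_max: Optional[float] = None) -> float:
    """Hypothetical float-only analogue of make_linear_var.

    Evaluates y = (beta - alpha) / (b - a) * (x - a) + alpha with
    (a, alpha) = (start, start_value) and (b, beta) = (end, end_value).
    """
    slope = (end_value - start_value) / (end - start)
    y = slope * (step - start) + start_value
    # Assumption of this sketch: None disables clipping on that side.
    if clip_min is not None:
        y = max(y, clip_min)
    if clip_max is not None:
        y = min(y, clip_max)
    return y


# Ramp a weight from 1.0 at step 0 down to 0.0 at step 100.
print(linear_var(25, 0, 100, 1.0, 0.0))                               # 0.75
print(linear_var(150, 0, 100, 1.0, 0.0))                              # -0.5 (extrapolated, unclipped)
print(linear_var(150, 0, 100, 1.0, 0.0, clip_min=0.0, clip_max=1.0))  # 0.0
```

Without clipping, the line keeps extrapolating past \(b\); the clip bounds are what turn it into a ramp that stays between the two values.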
class edflow.iterators.tf_trainer.TFBaseTrainer(config, root, model, **kwargs)
Bases: edflow.iterators.tf_iterator.TFHookedModelIterator
Base trainer built on TFHookedModelIterator.
__init__(config, root, model, **kwargs)
Constructor.
- Parameters
model (object) – Model class.
num_epochs (int) – Number of times to iterate over the data.
hooks (list) – List containing Hook instances.
hook_freq (int) – Frequency at which hooks are evaluated.
bar_position (int) – Used by tqdm to place bars at the right position when using multiple iterators in parallel.
initialize(checkpoint_path=None)
Initialize from scratch, or restore from checkpoint_path, and keep the restorer around.
step_ops()
Defines the ops that are run at each step.
- Returns
The operation run at each step.
class edflow.iterators.tf_trainer.TFListTrainer(config, root, model, **kwargs)
class edflow.iterators.tf_trainer.TFMultiStageTrainer(config, root, model, **kwargs)
Bases: edflow.iterators.tf_trainer.TFBaseTrainer
Adds multistage training to the edflow trainer.
Stages are defined through the config. For example:
    stages:
      1:
        name: pretrain
        end: 10
        losses: []
      2:
        name: retrain
        end: 30
        losses: ["model"]
      3:
        name: train
        losses: ["model"]
The stages are sorted by their key; keeping the simple numeric ordering (1, 2, 3, …) is recommended. In each stage, a set of losses can be specified with the losses: ["loss1", "loss2", …] syntax. The end: num_steps value gives the step at which a stage ends. Because stage ends are evaluated in stage order, a later stage must have a higher end value than the previous one.
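The stage rules above can be illustrated with a small plain-Python sketch. It is not edflow's implementation: the helpers sorted_stages and active_stage are hypothetical, the config is assumed to be already parsed into a dict, and the sketch assumes that a stage covers the steps below its end value, that only the last stage may omit end (like train in the example), and that the final stage stays active once all end values are passed.

```python
from typing import Any, Dict, List

# Hypothetical alias: one stage as parsed from the YAML config.
Stage = Dict[str, Any]


def sorted_stages(stages: Dict[Any, Stage]) -> List[Stage]:
    """Order stages by key and check that their end values increase.

    Assumption of this sketch: only the last stage may omit 'end'.
    """
    ordered = [stages[key] for key in sorted(stages)]
    previous_end = None
    for index, stage in enumerate(ordered):
        end = stage.get("end")
        if end is None:
            if index != len(ordered) - 1:
                raise ValueError(f"only the last stage may omit 'end': {stage['name']!r}")
            continue
        if previous_end is not None and end <= previous_end:
            raise ValueError(
                f"stage {stage['name']!r} ends at {end}, "
                f"not after the previous end {previous_end}"
            )
        previous_end = end
    return ordered


def active_stage(ordered: List[Stage], step: int) -> Stage:
    """Return the first stage whose end lies beyond step (no end = unbounded)."""
    for stage in ordered:
        end = stage.get("end")
        if end is None or step < end:
            return stage
    # Assumption: after the last end value, the final stage stays active.
    return ordered[-1]


# The example config from above, already parsed into a dict.
config = {
    "stages": {
        1: {"name": "pretrain", "end": 10, "losses": []},
        2: {"name": "retrain", "end": 30, "losses": ["model"]},
        3: {"name": "train", "losses": ["model"]},
    }
}

stages = sorted_stages(config["stages"])
for step in (0, 9, 10, 29, 30, 1000):
    stage = active_stage(stages, step)
    print(step, stage["name"], stage["losses"])
# 0 pretrain []
# 9 pretrain []
# 10 retrain ['model']
# 29 retrain ['model']
# 30 train ['model']
# 1000 train ['model']
```

In this sketch, keys are ordered with Python's sorted(), so string keys such as "10" and "2" would sort lexicographically, which is one reason to prefer plain numeric keys.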
The model has to implement the edflow.iterators.tf_trainer.TFMultiStageModel interface. See the multistage_trainer example.
__init__(config, root, model, **kwargs)
Constructor.
- Parameters
model (object) – Model class.
num_epochs (int) – Number of times to iterate over the data.
hooks (list) – List containing Hook instances.
hook_freq (int) – Frequency at which hooks are evaluated.
bar_position (int) – Used by tqdm to place bars at the right position when using multiple iterators in parallel.