edflow.iterators.tf_trainer module

Summary

Classes:

TFBaseTrainer

Trainer base class built on TFHookedModelIterator.

TFFrequencyTrainer

TFListTrainer

TFMultiStageModel

TFMultiStageTrainer

Adds multistage training to the edflow trainer.

Reference

class edflow.iterators.tf_trainer.TFBaseTrainer(config, root, model, **kwargs)

Bases: edflow.iterators.tf_iterator.TFHookedModelIterator

Trainer base class built on TFHookedModelIterator.

__init__(config, root, model, **kwargs)

Constructor.

Parameters
  • model (object) – Model class.

  • num_epochs (int) – Number of times to iterate over the data.

  • hooks (list) – List containing Hook instances.

  • hook_freq (int) – Frequency at which hooks are evaluated.

  • bar_position (int) – Used by tqdm to place bars at the right position when using multiple Iterators in parallel.

initialize(checkpoint_path=None)

Initialize from scratch, or restore from checkpoint_path and keep the restorer around for later use.

step_ops()

Defines ops that are called at each step.

Returns

The operation run at each step.

make_feeds(batch)

Put the global step into the batch and add all extra required placeholders from the batch.

setup()

Initialize train_placeholders, log_ops and img_ops so that entries can be added to them later.

create_train_op()

Create the default optimizer and optimize each submodule.

make_loss_ops()

Return the loss of each submodule. May also add tensors to log_ops and img_ops.
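How setup() and make_loss_ops() divide the work can be illustrated with a minimal sketch. It assumes plain Python floats in place of TensorFlow tensors; the class name SketchTrainer, the submodules dictionary and the "<name>_loss" logging keys are hypothetical, and this is not edflow's implementation.

```python
from typing import Callable, Dict


class SketchTrainer:
    """Hypothetical stand-in for the TFBaseTrainer hook methods.

    Plain floats replace TensorFlow tensors; this is not edflow's code.
    """

    def __init__(self, submodules: Dict[str, Callable[[], float]]):
        # Hypothetical: maps a submodule name to a function computing its loss.
        self.submodules = submodules
        self.setup()

    def setup(self) -> None:
        # Containers that make_loss_ops (or subclasses) can add entries to.
        self.train_placeholders: Dict[str, object] = {}
        self.log_ops: Dict[str, float] = {}
        self.img_ops: Dict[str, object] = {}

    def make_loss_ops(self) -> Dict[str, float]:
        # Return one loss per submodule and register each one for logging.
        losses = {}
        for name, loss_fn in self.submodules.items():
            losses[name] = loss_fn()
            self.log_ops[f"{name}_loss"] = losses[name]
        return losses


trainer = SketchTrainer({"generator": lambda: 0.5, "discriminator": lambda: 1.25})
losses = trainer.make_loss_ops()
print(losses)  # {'generator': 0.5, 'discriminator': 1.25}
print(trainer.log_ops)  # {'generator_loss': 0.5, 'discriminator_loss': 1.25}
```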

make_run_once_op()

Return an op to be run at step zero, e.g. for custom initialization.

get_trainable_variables(submodule)

get_init_variables()

get_restore_variables()

get_checkpoint_variables()

run(fetches, feed_dict)

Runs all fetch ops and stores the results.

Parameters
  • fetches (dict) – Mapping of name: Callable pairs.

  • feed_dict (dict) – Passed as kwargs to all fetch ops.

Returns

Mapping of name: result pairs.

Return type

dict
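The run() contract (call every fetch op with feed_dict as keyword arguments and collect the results by name) can be sketched in plain Python. This sketch assumes fetch ops are ordinary callables; the add and mul functions are hypothetical examples, and the actual trainer works with TensorFlow ops.

```python
from typing import Any, Callable, Dict


def run(fetches: Dict[str, Callable[..., Any]], feed_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Call every fetch op with feed_dict as keyword arguments.

    Sketch of the documented contract only, not edflow's implementation.
    """
    results = {}
    for name, fetch in fetches.items():
        # feed_dict is unpacked as kwargs, so every op must accept all its keys.
        results[name] = fetch(**feed_dict)
    return results


def add(x, y):
    # Hypothetical fetch op.
    return x + y


def mul(x, y):
    # Hypothetical fetch op.
    return x * y


print(run({"sum": add, "product": mul}, {"x": 3, "y": 4}))  # {'sum': 7, 'product': 12}
```

Because feed_dict is unpacked for every fetch op, each op must accept all of its keys, for example by declaring **kwargs.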

class edflow.iterators.tf_trainer.TFFrequencyTrainer(config, root, model, **kwargs)

Bases: edflow.iterators.tf_trainer.TFBaseTrainer

create_train_op()

Create the default optimizer and optimize each submodule.

run(fetches, feed_dict)

Runs all fetch ops and stores the results.

Parameters
  • fetches (dict) – Mapping of name: Callable pairs.

  • feed_dict (dict) – Passed as kwargs to all fetch ops.

Returns

Mapping of name: result pairs.

Return type

dict

class edflow.iterators.tf_trainer.TFListTrainer(config, root, model, **kwargs)

Bases: edflow.iterators.tf_trainer.TFBaseTrainer

create_train_op()

Create the default optimizer and optimize each submodule.

run(fetches, feed_dict)

Runs all fetch ops and stores the results.

Parameters
  • fetches (dict) – Mapping of name: Callable pairs.

  • feed_dict (dict) – Passed as kwargs to all fetch ops.

Returns

Mapping of name: result pairs.

Return type

dict

get_learning_rate_multiplier(i)

class edflow.iterators.tf_trainer.TFMultiStageTrainer(config, root, model, **kwargs)

Bases: edflow.iterators.tf_trainer.TFBaseTrainer

Adds multistage training to the edflow trainer.

Stages are defined in the config, for example:

    stages:
      1:
        name: pretrain
        end: 10
        losses: []
      2:
        name: retrain
        end: 30
        losses: ["model"]
      3:
        name: train
        losses: ["model"]

The stages are sorted by their key; keeping the simple numeric ordering is recommended. In each stage, a set of losses can be specified with the losses: ["loss1", "loss2", …] syntax. The duration of each stage is given by its end: num_steps value. Note that the stage ends are evaluated in stage order, so a later stage must have a higher end value than the previous one.
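A minimal sketch of these stage rules, assuming the config has already been parsed into a Python dict. The helper names validate_stages and current_stage are hypothetical, and two details are assumptions not stated in this reference: a stage is treated as active while step < end, and only the last stage may omit end.

```python
from typing import Any, Dict


def validate_stages(stages: Dict[Any, Dict[str, Any]]) -> None:
    """Raise ValueError unless end values strictly increase in key order.

    Assumption: only the last stage may omit "end".
    """
    keys = sorted(stages)
    previous_end = None
    for i, key in enumerate(keys):
        end = stages[key].get("end")
        if end is None:
            if i != len(keys) - 1:
                raise ValueError(f"stage {key!r} has no end but is not the last stage")
            continue
        if previous_end is not None and end <= previous_end:
            raise ValueError(f"stage {key!r}: end {end} must be greater than {previous_end}")
        previous_end = end


def current_stage(stages: Dict[Any, Dict[str, Any]], step: int) -> Dict[str, Any]:
    """Return the first stage, in sorted key order, whose end is not yet reached.

    Assumption: a stage is active while step < end; a stage without "end"
    stays active for the rest of training.
    """
    for key in sorted(stages):
        end = stages[key].get("end")
        if end is None or step < end:
            return stages[key]
    raise ValueError(f"no stage is active at step {step}")


# The example config from above, parsed into a dict.
stages = {
    1: {"name": "pretrain", "end": 10, "losses": []},
    2: {"name": "retrain", "end": 30, "losses": ["model"]},
    3: {"name": "train", "losses": ["model"]},
}
validate_stages(stages)
for step in (0, 9, 10, 29, 30, 1000):
    print(step, current_stage(stages, step)["name"])
```

Sorting relies on Python's ordering of the keys, which is one reason to keep simple integer keys: with string keys, "10" would sort before "2".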

The model has to implement the edflow.iterators.tf_trainer.TFMultiStageModel interface. See the multistage_trainer example.

__init__(config, root, model, **kwargs)

Constructor.

Parameters
  • model (object) – Model class.

  • num_epochs (int) – Number of times to iterate over the data.

  • hooks (list) – List containing Hook instances.

  • hook_freq (int) – Frequency at which hooks are evaluated.

  • bar_position (int) – Used by tqdm to place bars at the right position when using multiple Iterators in parallel.

create_train_op()

Create the default optimizer and optimize each submodule.

run(fetches, feed_dict)

Runs all fetch ops and stores the results.

Parameters
  • fetches (dict) – Mapping of name: Callable pairs.

  • feed_dict (dict) – Passed as kwargs to all fetch ops.

Returns

Mapping of name: result pairs.

Return type

dict

get_current_train_op(current_stage)

determine_current_stage()

class edflow.iterators.tf_trainer.TFMultiStageModel

Bases: object

__init__()

Initialize self. See help(type(self)) for accurate signature.

property stage_update_op

property stage_placeholder

property stage