Source code for edflow.main

import argparse
import importlib
import os
import yaml
import math
import datetime

from edflow.custom_logging import log, run
from edflow.util import get_obj_from_str, retrieve, set_value


def _save_config(config, prefix="config"):
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    fname = prefix + "_" + now + ".yaml"
    path = os.path.join(run.configs, fname)
    with open(path, "w") as f:
        f.write(yaml.dump(config))
    return path
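
# Illustrative example (not part of the source): assuming the run directory
# layout puts configs under run.configs == "logs/my_run/configs", a call like
#
#     _save_config({"batch_size": 16}, prefix="train")
#
# would write "logs/my_run/configs/train_<timestamp>.yaml" and return that path.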


# TODO: DRY --- train and test are almost the same


def train(config, root, checkpoint=None, retrain=False, debug=False):
    """Run training. Loads model, iterator and dataset according to config."""
    from edflow.iterators.batches import make_batches

    # disable integrations in debug mode and cap the number of steps
    if debug:
        if retrieve(config, "debug/disable_integrations", default=True):
            integrations = retrieve(config, "integrations", default=dict())
            for k in integrations:
                config["integrations"][k]["active"] = False
        max_steps = retrieve(config, "debug/max_steps", default=5 * 2)
        if max_steps > 0:
            config["num_steps"] = max_steps

    # backwards compatibility: fold legacy keys into the "datasets" dict
    if "datasets" not in config:
        config["datasets"] = {"train": config["dataset"]}
        if "validation_dataset" in config:
            config["datasets"]["validation"] = config["validation_dataset"]

    log.set_log_target("train")
    logger = log.get_logger("train")
    logger.info("Starting Training.")

    model = get_obj_from_str(config["model"])
    iterator = get_obj_from_str(config["iterator"])
    datasets = dict(
        (split, get_obj_from_str(config["datasets"][split]))
        for split in config["datasets"]
    )

    logger.info("Instantiating datasets.")
    for split in datasets:
        datasets[split] = datasets[split](config=config)
        datasets[split].expand = True
        logger.info("{} dataset size: {}".format(split, len(datasets[split])))
        if debug:
            max_examples = retrieve(
                config, "debug/max_examples", default=5 * config["batch_size"]
            )
            if max_examples > 0:
                logger.info(
                    "Monkey patching {} dataset __len__ to {} examples".format(
                        split, max_examples
                    )
                )
                type(datasets[split]).__len__ = lambda self: max_examples

    n_processes = config.get("n_data_processes", min(16, config["batch_size"]))
    n_prefetch = config.get("n_prefetch", 1)

    logger.info("Building batches.")
    batches = dict()
    for split in datasets:
        batches[split] = make_batches(
            datasets[split],
            batch_size=config["batch_size"],
            shuffle=True,
            n_processes=n_processes,
            n_prefetch=n_prefetch,
            error_on_timeout=config.get("error_on_timeout", False),
        )
    main_split = "train"
    try:
        if "num_steps" in config:
            # set number of epochs to perform at least num_steps steps
            steps_per_epoch = len(datasets[main_split]) / config["batch_size"]
            num_epochs = config["num_steps"] / steps_per_epoch
            config["num_epochs"] = math.ceil(num_epochs)
        else:
            steps_per_epoch = len(datasets[main_split]) / config["batch_size"]
            num_steps = config["num_epochs"] * steps_per_epoch
            config["num_steps"] = math.ceil(num_steps)

        logger.info("Instantiating model.")
        model = model(config)

        if "hook_freq" not in config:
            config["hook_freq"] = 1
        compat_kwargs = dict(
            hook_freq=config["hook_freq"], num_epochs=config["num_epochs"]
        )
        logger.info("Instantiating iterator.")
        iterator = iterator(config, root, model, datasets=datasets, **compat_kwargs)

        logger.info("Initializing model.")
        if checkpoint is not None:
            iterator.initialize(checkpoint_path=checkpoint)
        else:
            iterator.initialize()

        if retrain:
            iterator.reset_global_step()

        # save current config
        logger.info("Starting Training with config:\n{}".format(yaml.dump(config)))
        cpath = _save_config(config, prefix="train")
        logger.info("Saved config at {}".format(cpath))

        logger.info("Iterating.")
        iterator.iterate(batches)
    finally:
        for split in batches:
            batches[split].finalize()
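
# Illustrative sketch (an assumption, not shipped with edflow) of a minimal
# config that train() above can consume; all "mypkg.*" import paths are
# hypothetical placeholders resolved via get_obj_from_str:
#
#     model: mypkg.models.Model
#     iterator: mypkg.iterators.Iterator
#     datasets:
#       train: mypkg.data.TrainDataset
#       validation: mypkg.data.ValDataset
#     batch_size: 16
#     num_steps: 100000      # num_epochs is derived: with 1000 train examples,
#                            # steps_per_epoch == 62.5 and num_epochs == 1600
#     n_data_processes: 8    # defaults to min(16, batch_size)
#     n_prefetch: 1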


def test(config, root, checkpoint=None, nogpu=False, bar_position=0, debug=False):
    """Run tests. Loads model, iterator and dataset from config."""
    from edflow.iterators.batches import make_batches

    # backwards compatibility: fold legacy keys into the "datasets" dict
    if "datasets" not in config:
        config["datasets"] = {"train": config["dataset"]}
        if "validation_dataset" in config:
            config["datasets"]["validation"] = config["validation_dataset"]

    log.set_log_target("latest_eval")
    logger = log.get_logger("test")
    logger.info("Starting Evaluation.")

    if "test_batch_size" in config:
        config["batch_size"] = config["test_batch_size"]
    if "test_mode" not in config:
        config["test_mode"] = True

    model = get_obj_from_str(config["model"])
    iterator = get_obj_from_str(config["iterator"])
    datasets = dict(
        (split, get_obj_from_str(config["datasets"][split]))
        for split in config["datasets"]
    )

    logger.info("Instantiating datasets.")
    for split in datasets:
        datasets[split] = datasets[split](config=config)
        datasets[split].expand = True
        logger.info("{} dataset size: {}".format(split, len(datasets[split])))
        if debug:
            max_examples = retrieve(
                config, "debug/max_examples", default=5 * config["batch_size"]
            )
            if max_examples > 0:
                logger.info(
                    "Monkey patching {} dataset __len__ to {} examples".format(
                        split, max_examples
                    )
                )
                type(datasets[split]).__len__ = lambda self: max_examples

    n_processes = config.get("n_data_processes", min(16, config["batch_size"]))
    n_prefetch = config.get("n_prefetch", 1)

    logger.info("Building batches.")
    batches = dict()
    for split in datasets:
        batches[split] = make_batches(
            datasets[split],
            batch_size=config["batch_size"],
            shuffle=False,
            n_processes=n_processes,
            n_prefetch=n_prefetch,
            error_on_timeout=config.get("error_on_timeout", False),
        )
    try:
        logger.info("Instantiating model.")
        model = model(config)

        config["hook_freq"] = 1
        config["num_epochs"] = 1
        config["nogpu"] = nogpu
        compat_kwargs = dict(
            hook_freq=config["hook_freq"],
            bar_position=bar_position,
            nogpu=config["nogpu"],
            num_epochs=config["num_epochs"],
        )
        iterator = iterator(config, root, model, datasets=datasets, **compat_kwargs)

        logger.info("Initializing model.")
        if checkpoint is not None:
            iterator.initialize(checkpoint_path=checkpoint)
        else:
            iterator.initialize()

        # save current config
        logger.info("Starting Evaluation with config:\n{}".format(yaml.dump(config)))
        prefix = "eval"
        if bar_position > 0:
            prefix = prefix + str(bar_position)
        cpath = _save_config(config, prefix=prefix)
        logger.info("Saved config at {}".format(cpath))

        logger.info("Iterating.")
        while True:
            iterator.iterate(batches)
            if not config.get("eval_forever", False):
                break
    finally:
        for split in batches:
            batches[split].finalize()
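
# Illustrative usage sketch (assumptions: "config.yaml", the "logs/my_run" root
# and the checkpoint path are hypothetical, and edflow's run/log state must
# already be set up, as the edflow CLI does, before calling these directly):
#
#     with open("config.yaml") as f:
#         config = yaml.safe_load(f)
#
#     train(config, root="logs/my_run")                  # fresh training run
#     test(config, root="logs/my_run",
#          checkpoint="logs/my_run/train/checkpoints/model.ckpt")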