Source code for edflow.iterators.tf_iterator

import tensorflow as tf
from edflow.iterators.model_iterator import PyHookedModelIterator


[docs]class TFHookedModelIterator(PyHookedModelIterator):
[docs] def make_feeds(self, batch): feeds = { pl: batch[name] for name, pl in self.model.inputs.items() if name in batch } return feeds
[docs] def run(self, fetches, feed_dict): get_global_step = fetches.pop("global_step") results = self.session.run(fetches, feed_dict=feed_dict) results["global_step"] = get_global_step() return results
[docs] def iterate(self, batch_iterator, validation_batch_iterator=None): with self.session.as_default(): super().iterate(batch_iterator)
@property def session(self): # session that is initialized the first time it is needed if hasattr(self, "_session"): return self._session sess_config = tf.ConfigProto() if self.config.get("nogpu", False): self.logger.info("Hiding GPUs.") sess_config.device_count["GPU"] = 0 sess_config.gpu_options.allow_growth = self.config.get("gpu_mem_growth", False) gpu_mem_fraction = self.config.get("gpu_mem_fraction", None) if gpu_mem_fraction is not None: self.logger.info("Setting GPU MEM Fraction to {}".format(gpu_mem_fraction)) sess_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_fraction self._session = tf.Session(config=sess_config) return self._session