Source code for edflow.data.util.cached_dset

from tqdm import tqdm, trange
from zipfile import ZipFile, ZIP_DEFLATED  # , ZIP_BZIP2, ZIP_LZMA
import numpy as np
import os
import pickle

from multiprocessing.managers import BaseManager
import queue

from edflow.data.dataset_mixin import DatasetMixin


[docs]def make_server_manager(port=63127, authkey=b"edcache"): inqueue = queue.Queue() outqueue = queue.Queue() class InOutManager(BaseManager): pass InOutManager.register("get_inqueue", lambda: inqueue) InOutManager.register("get_outqueue", lambda: outqueue) manager = InOutManager(address=("", port), authkey=authkey) manager.start() print("Started manager server at {}".format(manager.address)) return manager
[docs]def make_client_manager(ip, port=63127, authkey=b"edcache"): class InOutManager(BaseManager): pass InOutManager.register("get_inqueue") InOutManager.register("get_outqueue") manager = InOutManager(address=(ip, port), authkey=authkey) manager.connect() print("Connected to server at {}".format(manager.address)) return manager
[docs]def pickle_and_queue( dataset_factory, inqueue, outqueue, naming_template="example_{}.p" ): """Parallelizable function to retrieve and queue examples from a Dataset. Parameters ---------- dataset_factory : chainer.DatasetMixin A dataset factory, with methods described in :class:`CachedDataset`. indices : list List of indices, used to retrieve samples from dataset. queue : mp.Queue Queue to put the samples in. naming_template : str Formatable string, which defines the name of the stored file given its index. """ pbar = tqdm() dataset = dataset_factory() while True: try: indices = inqueue.get_nowait() except queue.Empty: return for idx in indices: try: example = dataset[idx] except BaseException: print("Error getting example {}".format(idx)) raise pickle_name = naming_template.format(idx) pickle_bytes = pickle.dumps(example) outqueue.put([pickle_name, pickle_bytes]) pbar.update(1)
[docs]class ExamplesFolder(object): """Contains all examples and labels of a cached dataset."""
[docs] def __init__(self, root): self.root = root
[docs] def read(self, name): with open(os.path.join(self.root, name), "rb") as example: return example.read()
class _CacheDataset(DatasetMixin): """Only used to avoid initializing the original dataset.""" def __init__(self, root, name, _legacy=True): self.root = root self.name = name filespath = os.path.join(root, "cached", name) if _legacy: zippath = filespath + ".zip" # naming_template = 'example_{}.p' with ZipFile(zippath, "r") as zip_f: filenames = zip_f.namelist() else: filenames = os.listdir(filespath) def is_example(name): return name.startswith("example_") and name.endswith(".p") examplefilenames = [n for n in filenames if is_example(n)] self.n = len(examplefilenames) def __len__(self): return self.n
[docs]class CachedDataset(DatasetMixin): """Using a Dataset of single examples creates a cached (saved to memory) version, which can be accessed way faster at runtime. To avoid creating the dataset multiple times, it is checked if the cached version already exists. Calling `__getitem__` on this class will try to retrieve the samples from the cached dataset to reduce the preprocessing overhead. The cached dataset will be stored in the root directory of the base dataset in the subfolder `cached` with name `name.zip`. Besides the usual DatasetMixin interface, datasets to be cached must also implement root # (str) root folder to cache into name # (str) unqiue name Optionally but highly recommended, they should provide in_memory_keys # list(str) keys which will be collected from examples The collected values are stored in a dict of list, mapping an in_memory_key to a list containing the i-ths value at the i-ths place. This data structure is then exposed via the attribute `labels` and enables rapid iteration over useful labels without loading each example seperately. That way, downstream datasets can filter the indices of the cached dataset efficiently, e.g. filtering based on train/eval splits. Caching proceeds as follows: Expose a method which returns the dataset to be cached, e.g. def DataToCache(): path = "/path/to/data" return MyCachableDataset(path) Start caching server on host <server_ip_or_hostname>: edcache --server --dataset import.path.to.DataToCache Wake up a worker bee on same or different hosts: edcache --address <server_ip_or_hostname> --dataset import.path.to.DataCache # noqa Start a cacherhive! """ _legacy = True
[docs] def __init__( self, dataset, force_cache=False, keep_existing=True, _legacy=True, chunk_size=64, ): """Given a dataset class, stores all examples in the dataset, if this has not yet happened. Parameters ---------- dataset : object Dataset class which defines the following methods: \n - `root`: returns the path to the raw data \n - `name`: returns the name of the dataset -> best be unique \n - `__len__`: number of examples in the dataset \n - `__getitem__`: returns a sindle datum \n - `in_memory_keys`: returns all keys, that are stored \n alongside the dataset, in a `labels.p` file. This allows to retrive labels more quickly and can be used to filter the data more easily. force_cache : bool If True the dataset is cached even if an existing, cached version is overwritten. keep_existing : bool If True, existing entries in cache will not be recomputed and only non existing examples are appended to the cache. Useful if caching was interrupted. _legacy : bool Read from the cached Zip file. Deprecated mode. Future Datasets should not write into zips as read times are very long. chunksize : int Length of the index list that is sent to the worker. """ self.force_cache = force_cache self.keep_existing = keep_existing self._legacy = _legacy self.base_dataset = dataset self._root = root = dataset.root name = dataset.name self.chunk_size = chunk_size self.store_dir = os.path.join(root, "cached") self.store_path = os.path.join(self.store_dir, name) if _legacy: self.store_path += ".zip" # leading_zeroes = str(len(str(len(self)))) # self.naming_template = 'example_{:0>' + leading_zeroes + '}.p' # above might be better, but for compatibility we need this right # now, because pickle_and_queue did not receive the updated template self.naming_template = "example_{}.p" self._labels_name = "labels.p" os.makedirs(self.store_dir, exist_ok=True) if self.force_cache: self.cache_dataset()
[docs] @classmethod def from_cache(cls, root, name, _legacy=True): """Use this constructor to avoid initialization of original dataset which can be useful if only the cached zip file is available or to avoid expensive constructors of datasets.""" dataset = _CacheDataset(root, name, _legacy) return cls(dataset, _legacy=_legacy)
def __getstate__(self): """Close file before pickling.""" if hasattr(self, "zip"): self.zip.close() self.zip = None self.currentpid = None return self.__dict__ @property def fork_safe_zip(self): if self._legacy: currentpid = os.getpid() if getattr(self, "_initpid", None) != currentpid: self._initpid = currentpid self.zip = ZipFile(self.store_path, "r") return self.zip return ExamplesFolder(self.store_path)
[docs] def cache_dataset(self): """Checks if a dataset is stored. If not iterates over all possible indices and stores the examples in a file, as well as the labels.""" if not os.path.isfile(self.store_path) or self.force_cache: print("Caching {}".format(self.store_path)) manager = make_server_manager() inqueue = manager.get_inqueue() outqueue = manager.get_outqueue() N_examples = len(self.base_dataset) indices = np.arange(N_examples) if self.keep_existing and os.path.isfile(self.store_path): with ZipFile(self.store_path, "r") as zip_f: zipfilenames = zip_f.namelist() zipfilenames = set(zipfilenames) indices = [ i for i in indices if not self.naming_template.format(i) in zipfilenames ] print("Keeping {} cached examples.".format(N_examples - len(indices))) N_examples = len(indices) print("Caching {} examples.".format(N_examples)) index_chunks = [ indices[i : i + self.chunk_size] for i in range(0, len(indices), self.chunk_size) ] for chunk in index_chunks: inqueue.put(chunk) print("Waiting for results.") pbar = tqdm(total=N_examples) mode = "a" if self.keep_existing else "w" with ZipFile(self.store_path, mode, ZIP_DEFLATED) as self.zip: done_count = 0 while True: if done_count == N_examples: break pickle_name, pickle_bytes = outqueue.get() self.zip.writestr(pickle_name, pickle_bytes) pbar.update(1) done_count += 1 # after everything is done, we store memory keys seperately for # more efficient access # Note that this is always called, in case one wants to add labels # after caching has finished. This will add a new file with the # same name to the zip and it is currently not possible to delete # the old one. Preliminary tests have shown that the read method # returns the newest file if multiple ones are available but this # is _not_ documented or guaranteed in the API. If you experience # problems, try to write a new zip file with desired contents or # delete cached zip and cache again. memory_dict = dict() if hasattr(self.base_dataset, "in_memory_keys"): print("Caching Labels.") memory_keys = self.base_dataset.in_memory_keys for key in memory_keys: memory_dict[key] = list() for idx in trange(len(self.base_dataset)): example = self[idx] # load cached version # extract keys for key in memory_keys: memory_dict[key].append(example[key]) with ZipFile(self.store_path, "a", ZIP_DEFLATED) as zipfile: zipfile.writestr(self._labels_name, pickle.dumps(memory_dict)) print("Finished caching.")
def __len__(self): """Number of examples in this Dataset.""" return len(self.base_dataset) @property def labels(self): """Returns the labels associated with the base dataset, but from the cached source.""" if not hasattr(self, "_labels"): labels = self.fork_safe_zip.read(self._labels_name) labels = pickle.loads(labels) self._labels = labels return self._labels @property def root(self): """Returns the root to the base dataset.""" return self._root
[docs] def get_example(self, i): """Given an index i, returns a example.""" example_name = self.naming_template.format(i) example_file = self.fork_safe_zip.read(example_name) example = pickle.loads(example_file) return example
[docs]class PathCachedDataset(CachedDataset): """Used for simplified decorator interface to dataset caching."""
[docs] def __init__(self, dataset, path): self.force_cache = False self.keep_existing = True self.base_dataset = dataset self.store_dir = os.path.split(path)[0] self.store_path = path self.naming_template = "example_{}.p" self._labels_name = "labels.p" os.makedirs(self.store_dir, exist_ok=True) self.lenfile = self.store_path + ".p" if not os.path.exists(self.lenfile): self.force_cache = True self.cache_dataset() if not os.path.exists(self.lenfile): with open(self.lenfile, "wb") as f: pickle.dump(len(self.base_dataset), f)
def __len__(self): if not (self.base_dataset is None or os.path.exists(self.lenfile)): return len(self.base_dataset) if not hasattr(self, "_len"): with open(self.lenfile, "rb") as f: self._len = pickle.load(f) return self._len
[docs]def cachable(path): """Decorator to cache datasets. If not cached, will start a caching server, subsequent calls will just load from cache. Currently all worker must be able to see the path. Be careful, function parameters are ignored on furture calls. Can be used on any callable that returns a dataset. Currently the path should be the path to a zip file to cache into - i.e. it should end in zip. """ def decorator(fn): def wrapped(*args, **kwargs): if os.path.exists(path + ".p"): # cached version ready return PathCachedDataset(None, path) elif os.path.exists(path + "parameters.p"): # zip exists but not pickle with length - caching server # started and we are a worker bee with open(path + "parameters.p", "rb") as f: args, kwargs = pickle.load(f) return fn(*args, **kwargs) else: # start caching server dataset = fn(*args, **kwargs) os.makedirs(os.path.split(path)[0], exist_ok=True) with open(path + "parameters.p", "wb") as f: pickle.dump((args, kwargs), f) return PathCachedDataset(dataset, path) return wrapped return decorator