Source code for edflow.data.util.util_dsets

import numpy as np
from edflow.data.dataset_mixin import DatasetMixin
from edflow.util import PRNGMixin
from edflow.util import retrieve
from edflow.main import get_obj_from_str


def JoinedDataset(dataset, key, n_joins):
    """Join every example with randomly drawn examples of the same label.

    For each example ``n_joins - 1`` partner indices are drawn from the
    examples whose ``dataset.labels[key]`` value equals that of the example.
    The original indices together with the partner indices are wrapped in
    one :class:`SubDataset` each and handed to
    :class:`ExampleConcatenatedDataset`.

    Parameters
    ----------
    dataset
        Dataset providing ``labels[key]`` and ``len()``.
    key
        Name of the label to join on. Must be in the labels of ``dataset``.
    n_joins : int
        Total number of examples per joined example, including the
        original one.

    Returns
    -------
    :class:`ExampleConcatenatedDataset`
        Built from ``n_joins`` sub datasets; the first one contains the
        original examples in their original order.
    """
    join_labels = np.asarray(dataset.labels[key])

    # For every distinct label value, the positions at which it occurs.
    positions_by_label = {
        label: np.nonzero(join_labels == label)[0]
        for label in np.unique(join_labels)
    }

    # Fixed seed: the drawn partners are the same on every call.
    rng = np.random.RandomState(1)

    # The first member of every joined example is the original example.
    per_join_indices = [list(range(len(dataset)))]
    for _ in range(n_joins - 1):
        partners = []
        for label in join_labels:
            partners.append(rng.choice(positions_by_label[label]))
        per_join_indices.append(partners)

    views = [SubDataset(dataset, idxs) for idxs in per_join_indices]
    return ExampleConcatenatedDataset(*views)
def getDebugDataset(config):
    """Load the dataset named in the config and make it reasonably small.

    The config syntax works as in :func:`getSeqDataset`. See there for more
    extensive documentation.

    Parameters
    ----------
    config : dict
        An edflow config with at least the key ``debugdataset`` and, nested
        inside it, ``dataset`` and ``debug_length``, defining the base
        dataset and the size of the debug dataset.

    Returns
    -------
    :class:`SubDataset`
        The base dataset restricted to the indices
        ``0, ..., debug_length - 1``.
    """
    debug_config = config["debugdataset"]

    implementations = get_implementations_from_config(debug_config, ["dataset"])
    dataset_factory = implementations["dataset"]
    # The factory receives the complete config, not only the debug section.
    full_dataset = dataset_factory(config=config)

    first_indices = np.arange(debug_config["debug_length"])
    return SubDataset(full_dataset, first_indices)
class RandomlyJoinedDataset(DatasetMixin, PRNGMixin):
    """Load multiple examples which share the same label.

    Required config parameters:
        :RandomlyJoinedDataset/dataset: The dataset from which to load
            examples.
        :RandomlyJoinedDataset/key: The key of the label to join on.

    Optional config parameters:
        :test_mode=False: If True, behaves deterministically.
        :RandomlyJoinedDataset/n_joins=2: How many examples to load.
        :RandomlyJoinedDataset/balance=False: If True and not in test_mode,
            sample join labels uniformly.
        :RandomlyJoinedDataset/avoid_identity=True: If True and not in
            test_mode, never return a pair containing the same image if
            possible.

    The i-th example returns:
        :'examples': A list of examples, where each example has the same
            label as specified by key. Unless ``balance`` is active, the
            first element of the list is the i-th example of the dataset.

    The dataset's labels are the same as those of the wrapped dataset. Be
    careful: ``examples[j]`` of the i-th example does not correspond to the
    i-th entry of the labels but to the ``examples[j]["index_"]``-th entry.
    """

    def __init__(self, config):
        """Build the wrapped dataset and the partner lookup from ``config``."""
        dataset_path = retrieve(config, "RandomlyJoinedDataset/dataset")
        dataset_factory = get_obj_from_str(dataset_path)
        self.dataset = dataset_factory(config)

        self.key = retrieve(config, "RandomlyJoinedDataset/key")
        self.n_joins = retrieve(config, "RandomlyJoinedDataset/n_joins", default=2)
        self.test_mode = retrieve(config, "test_mode", default=False)
        self.avoid_identity = retrieve(
            config, "RandomlyJoinedDataset/avoid_identity", default=True
        )
        self.balance = retrieve(config, "RandomlyJoinedDataset/balance", default=False)

        # self.index_map is used to select partners for each example.
        # Outside of test_mode it is a dict mapping each join label to all
        # indices carrying that label. In test_mode it is a list holding,
        # for each example, a fixed array of n_joins - 1 partner indices.
        self.join_labels = np.asarray(self.dataset.labels[self.key])
        indices_by_label = {
            label: np.nonzero(self.join_labels == label)[0]
            for label in np.unique(self.join_labels)
        }

        if not self.test_mode:
            self.index_map = indices_by_label
        else:
            # Fixed seed, so the partners are identical on every run.
            fixed_rng = np.random.RandomState(0)
            fixed_partners = []
            for idx in range(len(self.dataset)):
                candidates = indices_by_label[self.join_labels[idx]]
                fixed_partners.append(fixed_rng.choice(candidates, self.n_joins - 1))
            self.index_map = fixed_partners

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    @property
    def labels(self):
        """Labels of the wrapped dataset.

        Careful: these are only the labels of the original item, not of the
        joined ones. Use ``examples[j]["index_"]`` to get the correct label
        index.
        """
        return self.dataset.labels

    def get_example(self, i):
        """Return example ``i`` together with ``n_joins - 1`` partners.

        Returns a dict with the single key ``"examples"``, holding the
        result of indexing the wrapped dataset with the joined indices.
        """
        n_partners = self.n_joins - 1

        if self.test_mode:
            partners = self.index_map[i]
        else:
            if self.balance:
                # Draw a label first and then an example of that label; the
                # requested index is replaced by the drawn one.
                label_values = list(self.index_map.keys())
                drawn_label = self.prng.choice(label_values)
                i = self.prng.choice(self.index_map[drawn_label])

            candidates = self.index_map[self.join_labels[i]]

            if self.avoid_identity and len(candidates) > 1:
                candidates = [idx for idx in candidates if idx != i]
            # Draw distinct partners only if enough candidates are left.
            distinct = self.avoid_identity and len(candidates) >= n_partners

            partners = self.prng.choice(candidates, n_partners, replace=not distinct)

        all_indices = np.concatenate([[i], partners])
        return {"examples": self.dataset[all_indices]}
class DataFolder(DatasetMixin):
    """Labeled dataset built from a (possibly nested) folder of data files.

    Given the root of a possibly nested folder containing datafiles and a
    Callable that generates the labels to the datafile from its full name,
    this class creates a labeled dataset.

    A filtering of unwanted data can be achieved by having the ``label_fn``
    return ``None`` for those specific files. The actual files are only
    read when an example is requested.

    If for example ``label_fn`` returns a dict with the keys
    ``['a', 'b', 'c']`` and ``read_fn`` returns one with keys ``['d', 'e']``
    then the dict returned by ``__getitem__`` will contain the keys
    ``['a', 'b', 'c', 'd', 'e', 'file_path_', 'index_']``.
    """

    def __init__(
        self,
        image_root,
        read_fn,
        label_fn,
        sort_keys=None,
        in_memory_keys=None,
        legacy=True,
        show_bar=False,
    ):
        """
        Parameters
        ----------
        image_root : str
            Root containing the files of interest.
        read_fn : Callable
            Given the path to a file, returns the datum as a dict.
        label_fn : Callable
            Given the path to a file, returns a dict of labels. If
            ``label_fn`` returns ``None``, this file is ignored.
        sort_keys : list
            A hierarchy of keys by which the data in this Dataset are
            sorted. ``None`` or an empty list leaves the data in the order
            in which the files were found.
        in_memory_keys : list
            Keys which will be collected from examples when the dataset is
            cached.
        legacy : bool
            Use the old read method, where only the path to the current
            file is passed to the reader. The new version will see all
            labels that have been previously collected.
        show_bar : bool
            Show a loading bar when loading labels.
        """
        self.root = image_root
        self.read = read_fn
        self.label_fn = label_fn
        self.sort_keys = sort_keys

        self.legacy = legacy
        self.show_bar = show_bar

        if in_memory_keys is not None:
            assert isinstance(in_memory_keys, list)
            self.in_memory_keys = in_memory_keys

        self._read_labels()

    def _read_labels(self):
        """Walk ``self.root`` and collect the labels of all accepted files.

        Fills ``self.data`` with one dict per accepted file (its labels plus
        ``"file_path_"``) and ``self.labels`` with one list per label key.
        """
        import operator
        import os  # local import: the block must not rely on a module-level one

        if self.show_bar:
            # os.walk yields one entry per directory, so this is the number
            # of steps the progress bar has to count.
            n_dirs = sum(1 for _ in os.walk(self.root))
            # NOTE(review): tqdm is not imported in the visible part of this
            # module -- confirm it is available at module scope.
            walker = tqdm(os.walk(self.root), total=n_dirs, desc="Labels")
        else:
            # No progress bar unless it was asked for.
            walker = os.walk(self.root)

        self.data = []
        for dirpath, _dirnames, filenames in walker:
            for filename in filenames:
                path = os.path.join(dirpath, filename)
                file_labels = self.label_fn(path)
                # label_fn signals "skip this file" by returning None.
                if file_labels is None:
                    continue
                datum = {"file_path_": path}
                datum.update(file_labels)
                self.data.append(datum)

        # itemgetter() needs at least one key, so an empty list of sort keys
        # is treated like "do not sort".
        if self.sort_keys is not None and len(self.sort_keys) > 0:
            self.data.sort(key=operator.itemgetter(*self.sort_keys))

        self.labels = {}
        for datum in self.data:
            for key, value in datum.items():
                self.labels.setdefault(key, []).append(value)

    def get_example(self, i):
        """Load the files specified in example ``i``.

        Returns a new dict containing the collected labels of the example,
        updated with whatever ``read_fn`` returns for it.
        """
        datum = self.data[i]

        if self.legacy:
            # Old interface: the reader only sees the file path.
            file_content = self.read(datum["file_path_"])
        else:
            # New interface: the reader sees all collected labels as kwargs.
            file_content = self.read(**datum)

        example = dict(datum)
        example.update(file_content)
        return example