Source code for edflow.data.processing.labels

from edflow.data.dataset_mixin import DatasetMixin
import numpy as np


class LabelDataset(DatasetMixin):
    """Dataset view that yields only the labels of a wrapped dataset.

    ``get_example`` reads exclusively from ``data.labels``, so the wrapped
    dataset's own ``get_example`` is never invoked by this method.
    """

    def __init__(self, data):
        """
        Parameters
        ----------
        data : DatasetMixin
            Some dataset where we are only interested in the labels.
        """
        self.data = data
        # Fix a deterministic (sorted) order of label keys once, up front.
        label_store = self.data.labels
        self.keys = sorted(label_store.keys())

    def get_example(self, i):
        """Return a dict with the value of every label key at index ``i``.

        The returned dict additionally maps ``"base_index_"`` to ``i``.
        """
        all_labels = self.data.labels
        example = {key: all_labels[key][i] for key in self.keys}
        example["base_index_"] = i
        return example
class ExtraLabelsDataset(DatasetMixin):
    """A dataset with extra labels added.

    All extra labels are computed eagerly at construction time by calling
    ``labeler(data, i)`` for every index ``i`` of ``data``. They are then
    merged on top of ``data.labels``.
    """

    def __init__(self, data, labeler):
        """
        Parameters
        ----------
        data : DatasetMixin
            Some base dataset you want to add labels to.
        labeler : Callable
            Must accept two arguments: a ``Dataset`` and an index ``i`` and
            return a dictionary of labels to add or overwrite. For all indices
            the keys in the returned ``dict`` must be the same and the type
            and shape of the values at those keys must be the same per key.

        Raises
        ------
        KeyError
            If ``labeler`` returns, for some index, a dict that lacks one of
            the keys it returned for index ``0``.

        Notes
        -----
        The keys returned for index ``0`` define the set of new labels. Keys
        that only appear at other indices are ignored. If ``data`` is empty,
        ``labeler`` is never called and no new labels are added.
        """
        self.data = data
        self._labeler = labeler

        n_examples = len(self.data)
        # Probe index 0 once to fix the set of new label keys. The result is
        # reused in the loop below so the labeler is not evaluated twice for
        # index 0. An empty dataset has no index 0 to probe.
        first = self._labeler(self.data, 0) if n_examples > 0 else {}
        self._new_keys = sorted(first.keys())
        self._new_labels = {k: [None] * n_examples for k in self._new_keys}
        for i in range(n_examples):
            new_labels = first if i == 0 else self._labeler(self.data, i)
            for k in self._new_keys:
                self._new_labels[k][i] = new_labels[k]

        # New labels overwrite base labels of the same name. Every resulting
        # label (base and new) is converted with ``np.array``.
        merged = dict(self.data.labels)
        merged.update(self._new_labels)
        self._labels = {k: np.array(v) for k, v in merged.items()}
        # NOTE(review): presumably tells DatasetMixin to attach labels to
        # returned examples -- confirm in DatasetMixin.
        self.append_labels = True

    @property
    def labels(self):
        """Merged base and extra labels, each stored as a numpy array."""
        return self._labels