Source code for edflow.datasets.mnist

import sys, os, tarfile, pickle
from pathlib import Path
import numpy as np
from tqdm import tqdm, trange
import urllib
import gzip
import struct

import edflow.datasets.utils as edu


def read_mnist_file(path):
    """Read a gzip-compressed IDX file (the MNIST container format).

    The file starts with a 4-byte header: two zero bytes, one byte encoding
    the element type and one byte giving the number of dimensions. It is
    followed by one big-endian unsigned 32-bit size per dimension and then
    the raw element data.

    Adapted from
    https://gist.github.com/tylerneylon/ce60e8a06e7506ac45788443f7269e40

    Parameters
    ----------
    path : str or path-like
        Location of the ``.gz`` file.

    Returns
    -------
    numpy.ndarray
        Writable array with the shape stored in the header. MNIST files
        (type code ``0x08``) yield ``uint8`` data; multi-byte element types
        keep the big-endian byte order of the file.

    Raises
    ------
    ValueError
        If the header does not start with two zero bytes or uses an unknown
        element type code.
    """
    # Element type codes of the IDX format; multi-byte values are big-endian.
    idx_dtypes = {
        0x08: ">u1",
        0x09: ">i1",
        0x0B: ">i2",
        0x0C: ">i4",
        0x0D: ">f4",
        0x0E: ">f8",
    }
    with gzip.open(path, "rb") as f:
        zero, data_type, dims = struct.unpack(">HBB", f.read(4))
        if zero != 0 or data_type not in idx_dtypes:
            raise ValueError(
                "{} is not a valid IDX file (magic bytes: {}, type code: {})".format(
                    path, zero, data_type
                )
            )
        shape = struct.unpack(">{}I".format(dims), f.read(4 * dims))
        payload = f.read()
    # np.fromstring is deprecated for binary input; np.frombuffer is the
    # replacement but returns a read-only view onto ``payload``, so copy to
    # hand back a writable array as before.
    return np.frombuffer(payload, dtype=idx_dtypes[data_type]).reshape(shape).copy()
class MNIST(edu.DatasetMixin):
    """MNIST dataset that is downloaded and cached on first use.

    On construction the four archive files listed in ``FILES`` are fetched
    (unless the dataset root is already marked as prepared), decoded and
    stored in a single pickle; afterwards the requested split is loaded into
    memory.

    The split is ``"train"`` by default, or ``"test"`` if
    ``config["test_mode"]`` is truthy. It can be overridden with
    ``config[NAME]["split"]``.

    Each example is a dict with keys ``"image"`` (``float32`` array with a
    trailing channel axis of size one) and ``"class"`` (the label entry).
    """

    NAME = "MNIST"
    # CVDF mirror of http://yann.lecun.com/exdb/mnist/
    URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"
    FILES = dict(
        TRAIN_DATA="train-images-idx3-ubyte.gz",
        TRAIN_LABELS="train-labels-idx1-ubyte.gz",
        TEST_DATA="t10k-images-idx3-ubyte.gz",
        TEST_LABELS="t10k-labels-idx1-ubyte.gz",
    )

    def __init__(self, config=None):
        """Prepare the dataset on disk if necessary and load the split.

        Parameters
        ----------
        config : dict, optional
            Configuration; the keys ``"test_mode"`` and ``NAME`` are read
            when selecting the split. ``None`` is treated as an empty dict.
        """
        self.config = config or dict()
        self.logger = edu.get_logger(self)
        self._prepare()
        self._load()

    def _prepare(self):
        """Download and decode the raw files once, caching them as a pickle."""
        # ``import urllib`` alone does not guarantee that the ``parse``
        # submodule is loaded, so import the function explicitly.
        from urllib.parse import urljoin

        self.root = edu.get_root(self.NAME)
        self._data_path = Path(self.root).joinpath("data.p")
        if not edu.is_prepared(self.root):
            # prep
            self.logger.info("Preparing dataset {} in {}".format(self.NAME, self.root))
            root = Path(self.root)
            # Map every file name to its download URL.
            urls = {name: urljoin(self.URL, name) for name in self.FILES.values()}
            # NOTE(review): download_urls presumably returns a mapping from
            # the given keys to local file paths -- confirm in edu.
            local_files = edu.download_urls(urls, target_dir=root)
            data = {
                name: read_mnist_file(local_path)
                for name, local_path in local_files.items()
            }
            with open(self._data_path, "wb") as f:
                pickle.dump(data, f)
            edu.mark_prepared(self.root)

    def _get_split(self):
        """Return the name of the split to load, as selected by the config."""
        split = (
            "test" if self.config.get("test_mode", False) else "train"
        )  # default split
        if self.NAME in self.config:
            split = self.config[self.NAME].get("split", split)
        return split

    def _load(self):
        """Load the cached arrays and select images/labels of the split."""
        # The pickle read here is the cache file written by _prepare; do not
        # point _data_path at files from untrusted sources.
        with open(self._data_path, "rb") as f:
            self._data = pickle.load(f)
        split = self._get_split()
        assert split in ["train", "test"], "unknown split: {}".format(split)
        self.logger.info("Using split: {}".format(split))
        if split == "test":
            self._images = self._data[self.FILES["TEST_DATA"]]
            self._data_labels = self._data[self.FILES["TEST_LABELS"]]
        else:
            self._images = self._data[self.FILES["TRAIN_DATA"]]
            self._data_labels = self._data[self.FILES["TRAIN_LABELS"]]
        self.labels = {"class": self._data_labels, "image": self._images}
        self._length = self._data_labels.shape[0]

    def _load_example(self, i):
        """Return a dict holding entry ``i`` of every array in ``labels``."""
        return {key: values[i] for key, values in self.labels.items()}

    def _preprocess_example(self, example):
        """Rescale and reshape ``example["image"]`` in place.

        The image is divided by 127.5 and shifted by -1 (mapping 0..255 to
        -1..1), then gets a trailing axis and is cast to ``float32``.
        """
        example["image"] = example["image"] / 127.5 - 1.0
        example["image"] = example["image"][:, :, None].astype(np.float32)

    def get_example(self, i):
        """Return the preprocessed example at index ``i``."""
        example = self._load_example(i)
        self._preprocess_example(example)
        return example

    def __len__(self):
        """Return the number of examples in the loaded split."""
        return self._length
class MNISTTrain(MNIST):
    """MNIST variant that always uses the training split."""

    def _get_split(self):
        """Return the fixed split name without consulting the config."""
        split = "train"
        return split
class MNISTTest(MNIST):
    """MNIST variant that always uses the test split."""

    def _get_split(self):
        """Return the fixed split name without consulting the config."""
        split = "test"
        return split
if __name__ == "__main__":
    # Smoke test: load each split and print its size plus dtype, shape,
    # value range and label of the first example.
    for title, config in (
        ("train", None),
        # Per-dataset options are looked up under MNIST.NAME; any other key
        # is ignored and would silently fall back to the train split.
        ("test", {MNIST.NAME: {"split": "test"}}),
    ):
        print(title)
        d = MNIST(config)
        print(len(d))
        e = d[0]
        x, y = e["image"], e["class"]
        print(x.dtype, x.shape, x.min(), x.max(), y)