Source code for edflow.datasets.cifar

import sys, os, tarfile, pickle
from pathlib import Path
import numpy as np
from tqdm import tqdm, trange
import urllib

import edflow.datasets.utils as edu


[docs]class CIFAR10(edu.DatasetMixin): NAME = "CIFAR10" URL = "https://www.cs.toronto.edu/~kriz/" FILES = dict(DATA="cifar-10-python.tar.gz")
[docs] def __init__(self, config=None): self.config = config or dict() self.logger = edu.get_logger(self) self._prepare() self._load()
def _prepare(self): self.root = edu.get_root(self.NAME) self._data_path = Path(self.root).joinpath("data.p") if not edu.is_prepared(self.root): # prep self.logger.info("Preparing dataset {} in {}".format(self.NAME, self.root)) root = Path(self.root) urls = dict( (v, urllib.parse.urljoin(self.URL, v)) for k, v in self.FILES.items() ) local_files = edu.download_urls(urls, target_dir=root) edu.unpack(local_files["cifar-10-python.tar.gz"]) base = os.path.join(self.root, "cifar-10-batches-py") labels = list() filenames = list() datas = list() for batch_file in ["data_batch_{}".format(i) for i in range(1, 6)]: with open(os.path.join(base, batch_file), "rb") as f: batch_data = pickle.load(f, encoding="bytes") labels += batch_data["labels".encode()] filenames += [ fname.decode() for fname in batch_data["filenames".encode()] ] datas.append(batch_data["data".encode()]) with open(os.path.join(base, "test_batch"), "rb") as f: test_data = pickle.load(f, encoding="bytes") test_labels = test_data["labels".encode()] test_filenames = [ fname.decode() for fname in test_data["filenames".encode()] ] test_datas = test_data["data".encode()] with open(os.path.join(base, "batches.meta"), "rb") as f: _meta = pickle.load(f, encoding="bytes") meta = { "label_names": [ name.decode() for name in _meta["label_names".encode()] ], "num_vis": _meta["num_vis".encode()], "num_cases_per_batch": _meta["num_cases_per_batch".encode()], } # convert to (32,32,3) RGB uint8 images = np.concatenate(datas, axis=0) images = np.reshape(images, [-1, 3, 32, 32]) images = np.transpose(images, [0, 2, 3, 1]) test_images = test_datas test_images = np.reshape(test_images, [-1, 3, 32, 32]) test_images = np.transpose(test_images, [0, 2, 3, 1]) filenames = np.array(filenames) test_filenames = np.array(test_filenames) labels = np.array(labels) test_labels = np.array(test_labels) data = { "train": dict(images=images, filenames=filenames, labels=labels), "test": dict( images=test_images, filenames=test_filenames, labels=test_labels ), "meta": meta, } with open(self._data_path, "wb") as f: pickle.dump(data, f) edu.mark_prepared(self.root)
[docs] def get_split(self): split = ( "test" if self.config.get("test_mode", False) else "train" ) # default split if self.NAME in self.config: split = self.config[self.NAME].get("split", split) return split
def _load(self): with open(self._data_path, "rb") as f: self._data = pickle.load(f) split = self._get_split() assert split in ["train", "test"] self.logger.info("Using split: {}".format(split)) self.labels = { "class": self._data[split]["labels"], "image": self._data[split]["images"], "filename": self._data[split]["filenames"], } self._length = self.labels["class"].shape[0] def _load_example(self, i): example = dict() for k in self.labels: example[k] = self.labels[k][i] return example def _preprocess_example(self, example): example["image"] = example["image"] / 127.5 - 1.0 example["image"] = example["image"].astype(np.float32)
[docs] def get_example(self, i): example = self._load_example(i) self._preprocess_example(example) return example
def __len__(self): return self._length
[docs]class CIFAR10Train(CIFAR10): def _get_split(self): return "train"
[docs]class CIFAR10Test(CIFAR10): def _get_split(self): return "test"
if __name__ == "__main__": from edflow.util import pp2mkdtable print("train") d = CIFAR10() print(len(d)) e = d[0] print(pp2mkdtable(e)) x, y = e["image"], e["class"] print(x.dtype, x.shape, x.min(), x.max(), y) print("test") d = CIFAR10({"CIFAR10": {"split": "test"}}) print(len(d)) from PIL import Image Image.fromarray(((x + 1.0) * 127.5).astype(np.uint8)).save("cifar10_example.png")