Source code for agnapprox.datamodules.mnist

"""
Wrapper for MNIST dataset
"""
import torch.utils.data as td
from torchvision import datasets, transforms

from .approx_datamodule import ApproxDataModule


class MNIST(ApproxDataModule):
    """
    Data module wrapping the torchvision MNIST dataset.

    Downloads the data into ``self.data_dir`` and exposes it through the
    ``df_train``, ``df_val`` and ``df_test`` attributes populated by
    :meth:`setup`.
    """

    def __init__(self, **kwargs):
        # All configuration is forwarded untouched to the parent data module.
        super().__init__(**kwargs)

    @property
    def normalize(self):
        """
        Default preprocessing steps for MNIST images.

        Returns:
            List of transformations to apply to an input image: conversion
            to a tensor, followed by normalization with mean 0.1307 and
            standard deviation 0.3081.
        """
        to_tensor = transforms.ToTensor()
        standardize = transforms.Normalize((0.1307,), (0.3081,))
        return [to_tensor, standardize]

    def prepare_data(self):
        """
        Download the MNIST training and test sets into ``self.data_dir``.

        The dataset objects are created only for their download side effect
        and are discarded afterwards.
        """
        # Training split first, then the test split.
        for is_train in (True, False):
            datasets.MNIST(root=self.data_dir, train=is_train, download=True)

    def setup(self, stage=None):
        """
        Build the datasets required for the requested stage.

        Args:
            stage: ``"fit"`` builds the training/validation datasets,
                ``"test"`` builds the test dataset, and ``None`` builds
                both. Any other value leaves the module unchanged.
        """
        build_all = stage is None

        if build_all or stage == "fit":
            image_pipeline = transforms.Compose(self.normalize)
            full_train_set = datasets.MNIST(
                root=self.data_dir,
                train=True,
                transform=image_pipeline,
            )
            # Random 55000/5000 partition of the training set into the
            # training and validation subsets.
            train_part, val_part = td.random_split(full_train_set, [55000, 5000])
            self.df_train = train_part
            self.df_val = val_part

        if build_all or stage == "test":
            image_pipeline = transforms.Compose(self.normalize)
            self.df_test = datasets.MNIST(
                root=self.data_dir,
                train=False,
                transform=image_pipeline,
            )