Source code for dmme.data_modules.lsun

import os.path as osp
import subprocess
import zipfile
from typing import Callable, List, Optional, Union

import torchvision.transforms as TF
from torch.utils.data import Dataset

from dmme.common import norm

from .data_module import DataModule
from .. import datasets


class LSUN(DataModule):
    """LSUN datamodule

    Args:
        data_dir: directory where the LSUN archives are downloaded and extracted
        batch_size: batch size
        class_name: lsun zipfile name to load (e.g. :code:`"bedroom_train"`),
            specify as list to load multiple classes.
            Set to :code:`"train"` to load all train scenes in lsun.
            Unsupported for objects.
        imgsize: dataset image size
        augs (List[Transform]): augmentations on PIL images only from torchvision.
            Set to :code:`None` to disable augmentations
    """

    # Names of the scene archives. "test" is a pseudo-category: the test set is
    # a single archive (``test_lmdb.zip``) that is not split per scene.
    scenes = {
        "bedroom",
        "bridge",
        "church_outdoor",
        "classroom",
        "conference_room",
        "dining_room",
        "kitchen",
        "living_room",
        "restaurant",
        "test",
        "tower",
    }

    # Values of ``class_name`` that select a whole split instead of one archive.
    _SPLITS = ("train", "val", "test")

    def __init__(
        self,
        data_dir: str = ".",
        batch_size: int = 128,
        class_name: Union[str, List[str]] = "train",
        imgsize: int = 256,
        # NOTE: the list default is kept so the signature stays unchanged; it
        # is copied below so instances never share (or mutate) the same list.
        augs: Optional[List[Callable]] = [TF.RandomHorizontalFlip()],
    ):
        super().__init__(batch_size)

        self.data_dir = data_dir
        self.class_name = class_name
        self.imgsize = imgsize
        self.augs = [] if augs is None else list(augs)

    def _resolve_classes(self):
        """Return ``class_name`` in the form passed to ``datasets.LSUN``.

        A split name (``"train"``, ``"val"``, ``"test"``) is passed through as a
        string, any other single name is wrapped in a list, and a sequence of
        names is returned as a new list.
        """
        if isinstance(self.class_name, str):
            if self.class_name in self._SPLITS:
                return self.class_name
            return [self.class_name]
        return list(self.class_name)

    def _download_class(self, name: str):
        """Download the archive(s) selected by a single class name."""
        if name == "test":
            # The test set is one archive; download it once rather than once
            # per scene.
            self.download_scenes(self.data_dir, "test", set_name="test")
            return

        if name in self._SPLITS:
            # Skip the "test" pseudo-category: "test_train_lmdb.zip" and
            # "test_val_lmdb.zip" do not exist. Sorted for a stable order.
            for category in sorted(self.scenes - {"test"}):
                self.download_scenes(self.data_dir, category, set_name=name)
            return

        # "<category>_<set>" -> scene archive, anything else -> object archive.
        category, _, set_name = name.rpartition("_")
        if category in self.scenes:
            self.download_scenes(self.data_dir, category, set_name)
        else:
            self.download_objects(self.data_dir, category=name)

    def prepare_data(self):
        r"""Download and extract `LSUN <https://www.yf.io/p/lsun>`_ dataset

        LSUN is downloaded using `aria2 <https://aria2.github.io>`_ to speedup downloads.
        Every name in ``class_name`` (a single string or a list of strings) is
        downloaded, and ``self.classes`` is set to the resolved class selection.
        """
        if isinstance(self.class_name, str):
            names = [self.class_name]
        else:
            names = list(self.class_name)

        for name in names:
            self._download_class(name)

        self.classes = self._resolve_classes()

    def download_scenes(self, out_dir, category: str, set_name: str):
        """Download lsun scenes data

        Args:
            out_dir: output directory
            category: category name to download. should match the category names
                in `scenes <http://dl.yf.io/lsun/scenes/>`_
            set_name: either "train", "val", or "test". For "test" the single
                ``test_lmdb.zip`` archive is used and ``category`` is ignored
        """
        if set_name == "test":
            out_name = "test_lmdb.zip"
        else:
            out_name = f"{category}_{set_name}_lmdb.zip"

        url = f"http://dl.yf.io/lsun/scenes/{out_name}"
        self.download_url(url, out_dir, out_name)

    def download_objects(self, out_dir, category: str):
        """Download lsun objects data

        Args:
            out_dir: output directory
            category: category name to download; used as the zip file name
                under ``http://dl.yf.io/lsun/objects/``
        """
        out_name = f"{category}.zip"
        url = f"http://dl.yf.io/lsun/objects/{out_name}"
        self.download_url(url, out_dir, out_name)

    def download_url(self, url: str, out_dir, out_name: str):
        """Download url and extract zip, will skip if file exists

        Args:
            url: url to download
            out_dir: output directory
            out_name: file name to download as

        Raises:
            subprocess.CalledProcessError: if ``aria2c`` exits with a non-zero status
        """
        # The extracted data is expected at the archive name without extension.
        lmdb_path = osp.join(out_dir, out_name.split(".")[0])
        if osp.exists(lmdb_path):
            print("File exists skipping download")
            return

        out_path = osp.join(out_dir, out_name)
        if not osp.exists(out_path):
            # aria2 resolves --out relative to --dir, so pass them separately
            # instead of a joined path (which misplaces absolute directories).
            cmd = [
                "aria2c",
                "-x",
                "16",
                "-s",
                "16",
                "-d",
                out_dir,
                "-o",
                out_name,
                url,
            ]
            print(f"Downloading {out_name}...")
            # Fail here on a bad download rather than later inside zipfile.
            subprocess.run(cmd, check=True)

        print(f"Extracting {out_name}...")
        with zipfile.ZipFile(out_path) as f:
            f.extractall(out_dir)

    def dataset(self, augs: Optional[List[Callable]] = None) -> Dataset:
        """Build the LSUN dataset with the standard preprocessing pipeline.

        Args:
            augs: transforms applied before resize, center crop, tensor
                conversion and ``norm``. ``None`` means no augmentations

        Returns:
            a ``datasets.LSUN`` instance rooted at ``data_dir``
        """
        transform = TF.Compose(
            [
                *(augs or []),
                TF.Resize(size=self.imgsize),
                TF.CenterCrop(size=self.imgsize),
                TF.ToTensor(),
                norm,
            ]
        )

        return datasets.LSUN(
            root=self.data_dir,
            classes=self._resolve_classes(),
            transform=transform,
        )

    def setup_train(self):
        """Return the dataset with the configured augmentations applied."""
        return self.dataset(self.augs)

    def setup_test(self):
        """Return the dataset without augmentations."""
        return self.dataset()