Source code for dmme.datasets.lsun

from typing import Any, cast, Callable, List, Optional, Tuple, Union

import os.path as osp
import io
import pickle
import string

import PIL
from PIL import Image

from torchvision.datasets import VisionDataset


[docs]class LSUNClass(VisionDataset): """LSUNClass from torchvision Loads lmdb dataset with empty value checks. Args: root (str): directory containing mdb files transform (Transform): transforms to apply on data, Optional target_transform (Transform): transforms to apply on labels, Optional ignore_keys (List): list of keys to ignore in lmdb database, some keys have empty values """ def __init__( self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, ignore_keys: Optional[List] = None, ) -> None: import lmdb super().__init__(root, transform=transform, target_transform=target_transform) self.env = lmdb.open( root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False, ) with self.env.begin(write=False) as txn: self.length = txn.stat()["entries"] cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters) if osp.isfile(cache_file): self.keys = pickle.load(open(cache_file, "rb")) else: print("creating cache, this may take a while...") with self.env.begin(write=False) as txn: if ignore_keys is not None: self.keys = [ key for key in txn.cursor().iternext(keys=True, values=False) ] for key in ignore_keys: self.keys.remove(key) else: self.keys = [] for key, value in txn.cursor(): self.keys.append(key) try: buf = io.BytesIO() buf.write(value) Image.open(buf) except PIL.UnidentifiedImageError: # skip invalid values print(f"skipped {key}") self.keys.pop() pickle.dump(self.keys, open(cache_file, "wb")) def __getitem__(self, index: int) -> Tuple[Any, Any]: img, target = None, None env = self.env with env.begin(write=False) as txn: imgbuf = txn.get(self.keys[index]) buf = io.BytesIO() buf.write(imgbuf) buf.seek(0) img = Image.open(buf).convert("RGB") if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target def __len__(self) -> int: return len(self.keys)
[docs]class LSUN(VisionDataset): """`LSUN <https://www.yf.io/p/lsun>`_ dataset. You will need to install the ``lmdb`` package to use this dataset: run ``pip install lmdb`` Args: root (string): Root directory for the database files. classes (string or list): One of {'train', 'val', 'test'} or a list of categories to load. e,g. ['bedroom_train', 'church_outdoor_train']. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. """ scenes = [ "bedroom", "bridge", "church_outdoor", "classroom", "conference_room", "dining_room", "kitchen", "living_room", "restaurant", "test", "tower", ] objects = [ "airplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "dining_table", "dog", "horse", "motorbike", "person", "potted_plant", "sheep", "sofa", "train", "tv-monitor", ] ignore_keys = { "cat": [ b"05c509a12295c0725be85566680c58c81965ea63", b"0ec91d487375c2663a43d463f9e5b4e34b8527aa", ] } def __init__( self, root: str, classes: Union[str, List[str]] = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) if classes in ["train", "val", "test"]: classes = cast(str, classes) if classes == "test": classes = [classes] else: classes = [c + "_" + classes for c in self.scenes] elif isinstance(classes, str): classes = [classes] self.classes = classes # for each class, create an LSUNClassDataset self.dbs = [] for c in self.classes: ignore_keys = None if c in self.objects: root = osp.join(root, c) ignore_keys = self.ignore_keys[c] else: root = osp.join(root, f"{c}_lmdb") self.dbs.append( LSUNClass(root=root, transform=transform, ignore_keys=ignore_keys) ) self.indices = [] count = 0 for db in self.dbs: count += len(db) self.indices.append(count) self.length = count def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Args: index (int): Index Returns: tuple: Tuple (image, target) where target is the index of the target category. """ target = 0 sub = 0 for ind in self.indices: if index < ind: break target += 1 sub = ind db = self.dbs[target] index = index - sub if self.target_transform is not None: target = self.target_transform(target) img, _ = db[index] return img, target def __len__(self) -> int: return self.length def extra_repr(self) -> str: return "Classes: {classes}".format(**self.__dict__)