# Source code for draugr.torch_utilities.datasets.supervised.vision_datasets.dict_dataset

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 21/07/2020
           """

import random
from pathlib import Path
from typing import Callable, Iterable, Optional, Tuple

from torch.utils.data import DataLoader
from torchvision.datasets import VisionDataset  # TODO: Do not need to be images

__all__ = ["DictDatasetFolder", "SplitDictDatasetFolder"]

from torchvision.datasets.folder import has_file_allowed_extension

from draugr.numpy_utilities import (
    SplitEnum,
    build_flat_dataset,
    build_shallow_categorical_dataset,
    select_split,
)


class SplitDictDatasetFolder(VisionDataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    The data is partitioned with ``build_shallow_categorical_dataset`` and only
    the partition selected by ``split`` (via ``select_split``) is exposed.

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (iterable[string], optional): Allowed file extensions.
        transform (callable, optional): A function/transform that takes in a
            sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        split (SplitEnum): Which partition of the data to expose.
        valid_percentage (float): Forwarded to
            ``build_shallow_categorical_dataset`` as ``validation_percentage``.
        test_percentage (float): Forwarded to
            ``build_shallow_categorical_dataset`` as ``testing_percentage``.
        is_valid_file (callable, optional): Accepted for signature parity with
            ``DictDatasetFolder``; currently NOT used by this class.

    Attributes:
        _data_categories (mapping): Category key -> sequence of entries
            (presumably file paths) that are handed to ``loader``, restricted
            to the selected split.
        category_sizes (dict): Category key -> number of entries.
        category_names (tuple): All category keys, including empty ones.

    Raises:
        RuntimeError: If no categories are found under ``root``.
    """

    def __init__(
        self,
        root: Path,
        loader: Callable,
        extensions: Optional[Iterable[str]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        split: SplitEnum = SplitEnum.training,
        valid_percentage: float = 15,
        test_percentage: float = 0,
        is_valid_file: Callable = has_file_allowed_extension,
    ):
        super().__init__(
            str(root), transform=transform, target_transform=target_transform
        )

        # TODO: merge Split and non split common in a base class
        # NOTE(review): ``is_valid_file`` is not forwarded here (DictDatasetFolder
        # forwards it to build_flat_dataset); confirm whether
        # build_shallow_categorical_dataset accepts it.
        self._data_categories = select_split(
            build_shallow_categorical_dataset(
                self.root,
                extensions=extensions,
                testing_percentage=test_percentage,
                validation_percentage=valid_percentage,
            ),
            split,
        )
        if len(self._data_categories) == 0:
            msg = f"Found 0 categories in sub-folders of: {self.root}\n"
            if extensions is not None:
                msg += f"Supported extensions are: {','.join(extensions)}"
            raise RuntimeError(msg)

        self.loader = loader
        self.extensions = extensions

        self.category_sizes = {k: len(v) for k, v in self._data_categories.items()}
        for cat, cl in self.category_sizes.items():
            if cl == 0:
                print(f"Warning category {cat} has {cl} samples")
        self.category_names = (*self.category_sizes.keys(),)
        # Only categories with samples may be drawn at random: an empty one
        # would make ``index % size`` in ``sample`` divide by zero.
        self._non_empty_category_names = tuple(
            name for name, size in self.category_sizes.items() if size > 0
        )

    def __getitem__(self, index: int) -> Tuple:
        """Return a sample from a randomly chosen, non-empty category.

        Non-pure implementation! The category is drawn at random on every call,
        so the same ``index`` may map to different samples.

        Args:
            index (int): Position within the chosen category; wraps around
                modulo that category's size.

        Returns:
            tuple: (sample, target) where target is the chosen category key,
            passed through ``target_transform`` when one is set.

        Raises:
            IndexError: If no category contains any samples.
        """
        if not self._non_empty_category_names:
            raise IndexError(f"{type(self).__name__} has no samples to draw from")
        target = random.choice(self._non_empty_category_names)
        return self.sample(target, index)

    def sample(self, target, index: int) -> Tuple:
        """Load the entry of category ``target`` at position ``index``.

        Args:
            target: A category key, i.e. one of ``category_names``.
            index (int): Position in that category; wraps around modulo its size.

        Returns:
            tuple: (sample, target) after applying ``transform`` to the sample
            and ``target_transform`` to the target, when those are set.

        Raises:
            KeyError: If ``target`` is not a known category.
            ValueError: If the category has no samples.
        """
        entries = self._data_categories[target]
        size = self.category_sizes[target]
        if size == 0:
            raise ValueError(f"Category {target!r} has no samples")
        sample = self.loader(entries[index % size])
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self) -> int:
        """Total number of entries across all categories of the selected split."""
        return sum(self.category_sizes.values())
class DictDatasetFolder(VisionDataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (iterable[string], optional): Allowed file extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in a
            sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): Forwarded to ``build_flat_dataset``
            to decide which files are valid samples.
            both extensions and is_valid_file should not be passed.

    Attributes:
        _data (mapping): Category key -> sequence of entries (presumably file
            paths) that are handed to ``loader``.
        category_sizes (dict): Category key -> number of entries.
        category_names (tuple): All category keys, including empty ones.

    Raises:
        RuntimeError: If ``build_flat_dataset`` returns no categories.
    """

    def __init__(
        self,
        root: Path,
        loader: Callable,
        extensions: Optional[Iterable[str]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Callable = has_file_allowed_extension,
    ):
        super().__init__(
            str(root), transform=transform, target_transform=target_transform
        )

        # TODO: merge Split and non split common in a base class
        self._data = build_flat_dataset(
            self.root, extensions=extensions, is_valid_file=is_valid_file
        )
        # NOTE(review): this counts categories, not files, so categories that
        # are all empty do not trigger the "Found 0 files" error.
        if len(self._data) == 0:
            msg = f"Found 0 files in sub-folders of: {self.root}\n"
            if extensions is not None:
                msg += f"Supported extensions are: {','.join(extensions)}"
            raise RuntimeError(msg)

        self.loader = loader
        self.extensions = extensions

        self.category_sizes = {k: len(v) for k, v in self._data.items()}
        self.category_names = (*self.category_sizes.keys(),)
        # Only categories with samples may be drawn at random: an empty one
        # would make ``index % size`` in ``sample`` divide by zero.
        self._non_empty_category_names = tuple(
            name for name, size in self.category_sizes.items() if size > 0
        )

    def __getitem__(self, index: int) -> Tuple:
        """Return a sample from a randomly chosen, non-empty category.

        Non-pure implementation! The category is drawn at random on every call,
        so the same ``index`` may map to different samples.

        Args:
            index (int): Position within the chosen category; wraps around
                modulo that category's size.

        Returns:
            tuple: (sample, target) where target is the chosen category key,
            passed through ``target_transform`` when one is set.

        Raises:
            IndexError: If no category contains any samples.
        """
        if not self._non_empty_category_names:
            raise IndexError(f"{type(self).__name__} has no samples to draw from")
        target = random.choice(self._non_empty_category_names)
        return self.sample(target, index)

    def sample(self, target, index: int) -> Tuple:
        """Load the entry of category ``target`` at position ``index``.

        Args:
            target: A category key, i.e. one of ``category_names``.
            index (int): Position in that category; wraps around modulo its size.

        Returns:
            tuple: (sample, target) after applying ``transform`` to the sample
            and ``target_transform`` to the target, when those are set.

        Raises:
            KeyError: If ``target`` is not a known category.
            ValueError: If the category has no samples.
        """
        entries = self._data[target]
        size = self.category_sizes[target]
        if size == 0:
            raise ValueError(f"Category {target!r} has no samples")
        sample = self.loader(entries[index % size])
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self) -> int:
        """Total number of entries across all categories."""
        return sum(self.category_sizes.values())
if __name__ == "__main__":
    # Running this module directly does nothing; it only defines dataset classes.
    ...