# Source code for autogluon.mxnet.task.dataset
import logging
import math
import os
import platform
import sys
import warnings
import numpy as np
from PIL import Image
from mxnet import gluon, nd
from mxnet import recordio
from mxnet.gluon.data import RecordFileDataset
from mxnet.gluon.data.vision import ImageFolderDataset as MXImageFolderDataset
from mxnet.gluon.data.vision import ImageRecordDataset, transforms
from autogluon.core import *
from ..utils import get_data_rec
from ..utils.pil_transforms import *
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# True when platform.system() reports "Darwin" (i.e. macOS); the dataset
# factory below uses this flag to pick the PIL-based loading path.
_is_osx = platform.system() == "Darwin"

# Names exported by a star-import of this module.
__all__ = [
    'get_dataset',
    'get_built_in_dataset',
    'ImageFolderDataset',
    'RecordDataset',
    'NativeImageFolderDataset',
]

# Dataset names that `get_dataset` hands over to `get_built_in_dataset`.
built_in_datasets = [
    'mnist',
    'cifar',
    'cifar10',
    'cifar100',
    'imagenet',
    'fashionmnist',
]
class _TransformFirstClosure(object):
"""Use callable object instead of nested function, it can be pickled."""
def __init__(self, fn):
self._fn = fn
def __call__(self, x, *args):
if args:
return (self._fn(x),) + args
return self._fn(x)
def generate_transform(train, resize, _is_osx, input_size, jitter_param):
    """Build the image preprocessing pipeline for training or evaluation.

    Parameters
    ----------
    train : bool
        If True, build the augmenting (random crop / flip / color jitter)
        pipeline; otherwise build the resize + center-crop pipeline.
    resize : int
        Size passed to the resize step of the evaluation pipeline.
    _is_osx : bool
        If True, compose the PIL-based transforms (slower image loading);
        otherwise compose the ``mxnet.gluon.data.vision.transforms`` ones.
    input_size : int
        Size of the final crop fed to the network.
    jitter_param : float
        Strength used for brightness, contrast and saturation jitter in the
        training pipeline (both back-ends).

    Returns
    -------
    A composed transform (``Compose`` on the PIL path,
    ``transforms.Compose`` on the MXNet path).
    """
    # Same normalization for every pipeline (presumably the ImageNet channel
    # mean/std -- TODO confirm).
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    if _is_osx:
        # using PIL to load image (slow)
        if train:
            steps = [
                RandomResizedCrop(input_size),
                RandomHorizontalFlip(),
                # Honor the caller's jitter strength; this used to be a
                # hard-coded 0.4 that ignored `jitter_param`.
                ColorJitter(jitter_param, jitter_param, jitter_param),
                ToTensor(),
                normalize,
            ]
        else:
            steps = [
                Resize(resize),
                CenterCrop(input_size),
                ToTensor(),
                normalize,
            ]
        return Compose(steps)

    if train:
        steps = [
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(
                brightness=jitter_param,
                contrast=jitter_param,
                saturation=jitter_param,
            ),
            transforms.RandomLighting(0.1),
            transforms.ToTensor(),
            normalize,
        ]
    else:
        steps = [
            transforms.Resize(resize),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize,
        ]
    return transforms.Compose(steps)
@func()
def get_dataset(path=None, train=True, name=None,
                input_size=224, crop_ratio=0.875, jitter_param=0.4, scale_ratio_choice=[],
                *args, **kwargs):
    """Produce an image classification dataset for AutoGluon.

    The result is either a popular built-in dataset ('mnist', 'fashionmnist',
    'cifar', 'cifar10', 'cifar100', 'imagenet'), a :class:`RecordDataset`,
    or one of the image-folder datasets defined in this module.

    Parameters
    ----------
    path : str
        The data location. For a folder dataset, `path/to/the/folder`
        should be provided; for a :class:`RecordDataset`, `path/to/*.rec`.
        Required unless `name` selects a built-in dataset.
    train : bool, default = True
        Whether this dataset should be used for training or validation.
    name : str, optional
        Which built-in dataset to use; overrides all other options if it is
        one of the built-in names (case-insensitive).
    input_size : int
        The input image size.
    crop_ratio : float
        Center crop ratio (for evaluation only); the evaluation resize is
        ``ceil(input_size / crop_ratio)``.
    jitter_param : float
        Color jitter strength used by the training transform.
    scale_ratio_choice : list
        List of crop ratios, used for the non-training folder dataset only:
        one dataset is built per ratio (resize = ``ceil(input_size / ratio)``
        followed by a fixed `input_size` crop) so predictions can be
        averaged over scales. The default list is never mutated.
    *args, **kwargs
        Forwarded to the selected dataset class (a `label_file` keyword
        selects :class:`IndexImageDataset` for training folders).

    Returns
    -------
    The dataset on which `.init()` has already been called, or -- in the
    multi-scale evaluation case -- a list of such initialized datasets. For
    a built-in name, whatever `get_built_in_dataset` returns.

    Raises
    ------
    ValueError
        If no built-in dataset is selected and `path` is None.
    """
    # Built-in datasets ignore the folder/record options, so decide this
    # before doing any other work.
    if isinstance(name, str) and name.lower() in built_in_datasets:
        return get_built_in_dataset(name, train=train, input_size=input_size, *args, **kwargs)

    if path is None:
        raise ValueError(
            "`path` must be provided when `name` is not one of the built-in datasets: "
            + ", ".join(built_in_datasets)
        )

    resize = int(math.ceil(input_size / crop_ratio))
    transform = generate_transform(train, resize, _is_osx, input_size, jitter_param)

    if '.rec' in path:
        dataset = RecordDataset(
            path,
            *args,
            transform=_TransformFirstClosure(transform),
            **kwargs
        )
    elif _is_osx:
        dataset = ImageFolderDataset(path, transform=transform, *args, **kwargs)
    elif not train:
        if not scale_ratio_choice:
            dataset = TestImageFolderDataset(
                path,
                *args,
                transform=_TransformFirstClosure(transform),
                **kwargs
            )
        else:
            # One initialized dataset per requested scale ratio.
            dataset = []
            for ratio in scale_ratio_choice:
                scaled_resize = int(math.ceil(input_size / ratio))
                dataset_item = TestImageFolderDataset(
                    path,
                    *args,
                    transform=_TransformFirstClosure(
                        generate_transform(train, scaled_resize, _is_osx, input_size, jitter_param)
                    ),
                    **kwargs
                )
                dataset.append(dataset_item.init())
    elif 'label_file' in kwargs:
        dataset = IndexImageDataset(
            path,
            transform=_TransformFirstClosure(transform),
            *args,
            **kwargs
        )
    else:
        dataset = NativeImageFolderDataset(
            path,
            *args,
            transform=_TransformFirstClosure(transform),
            **kwargs
        )

    # Only the multi-scale branch yields a list, and its members are already
    # initialized. Every other branch holds a single, not-yet-initialized
    # object; previously it was left un-initialized whenever
    # `scale_ratio_choice` was non-empty (e.g. for a training dataset).
    if not isinstance(dataset, list):
        dataset = dataset.init()
    return dataset
@obj()
class IndexImageDataset(MXImageFolderDataset):
    """An image classification dataset described by a CSV label file.

    Each sample is an image and its corresponding label.

    Parameters
    ----------
    root : str
        Path to the image folder (``~`` is expanded).
    label_file : str
        Local path to the CSV index file. The CSV should have two columns:

        1. image name (e.g. xxxx or xxxx.jpg)
        2. label name or index (e.g. aaa or 1)
    gray_scale : bool, default False
        If True, always convert images to greyscale.
        If False, always convert images to colored (RGB).
    transform : function, default None
        A user defined callback that transforms each sample.
    extension : str, default '.jpg'
        File extension appended to image names that do not already end
        with it.
    """

    def __init__(self, root, label_file, gray_scale=False, transform=None,
                 extension='.jpg'):
        # The parent constructor is not invoked; the attributes are set
        # directly (presumably the ones the parent's item access reads --
        # verify against MXNet's ImageFolderDataset).
        self._root = os.path.expanduser(root)
        # Use the expanded root so that '~'-style paths resolve to real files.
        self.items, self.synsets = self.read_csv(label_file, self._root, extension)
        self._flag = 0 if gray_scale else 1
        self._transform = transform

    @staticmethod
    def read_csv(filename, root, extension):
        """Parse the label CSV into ``(samples, labels)``.

        The CSV should have two columns:

        1. image name (e.g. xxxx or xxxx.jpg)
        2. label name or index (e.g. aaa or 1)

        A row whose first column is ``id`` is treated as a header and
        dropped; blank lines are skipped. If an image name appears more
        than once, the last row wins.

        Returns
        -------
        samples : list of (str, int)
            ``(image path, label index)`` pairs.
        labels : list of str
            Sorted distinct label names; a label's index is its position
            in this list.

        Raises
        ------
        ValueError
            If a non-empty row does not have exactly two columns.
        """
        import csv

        label_dict = {}
        with open(filename, newline='') as f:
            for line_no, row in enumerate(csv.reader(f), start=1):
                if not row:
                    continue
                if len(row) != 2:
                    raise ValueError(
                        f'{filename}:{line_no}: expected 2 columns, got {len(row)}'
                    )
                label_dict[row[0]] = row[1]
        label_dict.pop('id', None)

        # Sort so label indices are reproducible across runs and processes
        # (plain set ordering of strings depends on hash randomization).
        labels = sorted(set(label_dict.values()))
        label_to_index = {label: index for index, label in enumerate(labels)}

        def image_file(name):
            # Names may already carry the extension ("xxxx.jpg"); do not
            # append it a second time.
            if extension and name.lower().endswith(extension.lower()):
                return name
            return f"{name}{extension}"

        samples = [
            (os.path.join(root, image_file(name)), label_to_index[label])
            for name, label in label_dict.items()
        ]
        return samples, labels

    @property
    def num_classes(self):
        """Number of distinct labels found in the label file."""
        return len(self.synsets)

    @property
    def classes(self):
        """Label names, ordered by label index."""
        return self.synsets
@obj()
class RecordDataset:
    """A dataset wrapping over a RecordIO file containing images.

    Each sample is an image and its corresponding label.

    Parameters
    ----------
    filename : str
        Local path to the .rec file.
    gray_scale : bool, default False
        If True, always convert images to greyscale.
        If False, always convert images to colored (RGB).
    transform : function, default None
        A user defined callback that transforms each sample.
    classes : iterable of str, default is None
        User provided class names. If `None` (or empty), a list of
        increasing natural numbers ['0', '1', ..., 'N-1'] is used, where N is
        the number of distinct labels found in the file. If fewer names than
        N are given, the missing ones are filled with their index.
    """

    def __init__(self, filename, gray_scale=False, transform=None, classes=None):
        flag = 0 if gray_scale else 1
        # retrieve number of classes without decoding images
        td = RecordFileDataset(filename)
        # np.ravel accepts both array labels and scalar labels; the header
        # label is presumably a plain float for single-label records
        # (flag == 0) -- see the MXNet recordio documentation.
        distinct_labels = {
            float(np.ravel(recordio.unpack(td[i])[0].label)[0])
            for i in range(len(td))
        }
        self._num_classes = len(distinct_labels)

        provided = list(classes) if classes else []
        if not provided:
            self._classes = [str(i) for i in range(self._num_classes)]
        else:
            # `_num_classes` is an int; the previous `len(self._num_classes)`
            # raised TypeError whenever class names were supplied.
            if self._num_classes != len(provided):
                warnings.warn('Provided class names do not match data, expected "num_class" is {} '
                              'vs. provided: {}'.format(self._num_classes, len(provided)))
            self._classes = provided + \
                [str(i) for i in range(len(provided), self._num_classes)]

        self._dataset = ImageRecordDataset(filename, flag=flag)
        if transform:
            self._dataset = self._dataset.transform_first(transform)

    @property
    def num_classes(self):
        """Number of distinct labels found in the record file."""
        return self._num_classes

    @property
    def classes(self):
        """Class names (user provided and/or index strings)."""
        return self._classes

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        return self._dataset[idx]
@obj()
class NativeImageFolderDataset(MXImageFolderDataset):
    """MXNet image-folder dataset exposing ``num_classes`` and ``classes``.

    Parameters
    ----------
    root : str
        Path to the image folder, passed to the parent constructor.
    gray_scale : bool, default False
        Selects the parent's ``flag``: 0 when true, 1 otherwise.
    transform : function, default None
        Passed to the parent constructor unchanged.
    """

    def __init__(self, root, gray_scale=False, transform=None):
        if gray_scale:
            image_flag = 0
        else:
            image_flag = 1
        super().__init__(root, flag=image_flag, transform=transform)

    @property
    def classes(self):
        """The ``synsets`` list held by the dataset."""
        return self.synsets

    @property
    def num_classes(self):
        """Length of the ``synsets`` list."""
        return len(self.synsets)
@obj()
class TestImageFolderDataset(MXImageFolderDataset):
    """Image-folder dataset for evaluation data.

    Unlike the parent, `_list_images` accepts image files placed directly
    under `root` as well as images grouped in sub-folders.

    Parameters
    ----------
    root : str
        Path to the image folder.
    gray_scale : bool, default False
        Selects the parent's ``flag``: 0 when true, 1 otherwise.
    transform : function, default None
        Passed to the parent constructor unchanged.
    """

    def __init__(self, root, gray_scale=False, transform=None):
        flag = 0 if gray_scale else 1
        super().__init__(root, flag=flag, transform=transform)

    def _has_supported_ext(self, filename):
        """Return True if `filename` has a supported extension, else warn."""
        # `self._exts` is presumably set by the parent constructor before it
        # calls `_list_images` -- verify against MXNet's ImageFolderDataset.
        ext = os.path.splitext(filename)[1]
        if ext.lower() in self._exts:
            return True
        warnings.warn(
            f'Ignoring {filename} of type {ext}.'
            f' Only support {", ".join(self._exts)}'
        )
        return False

    def _list_images(self, root):
        """Populate ``self.items`` and ``self.synsets`` from `root`.

        Files directly under `root` get the label equal to the number of
        sub-folders seen so far; each sub-folder gets its own integer
        label, which is appended to ``self.synsets``.

        Raises
        ------
        ValueError
            If `root`, or an entry inside it that is not a regular file, is
            not a directory.
        """
        self.synsets = []
        self.items = []
        path = os.path.expanduser(root)
        if not os.path.isdir(path):
            # Exceptions take no `stacklevel` keyword; passing one raised
            # TypeError instead of this ValueError.
            raise ValueError('Ignoring %s, which is not a directory.' % path)
        for entry in sorted(os.listdir(path)):
            entry_path = os.path.join(path, entry)
            label = len(self.synsets)
            if os.path.isfile(entry_path):  # loose image directly under root
                if self._has_supported_ext(entry_path):
                    self.items.append((entry_path, label))
                continue
            if not os.path.isdir(entry_path):
                # Report the offending entry, not the (valid) root folder.
                raise ValueError(f'Ignoring {entry_path}, which is not a directory.')
            for sub_entry in sorted(os.listdir(entry_path)):
                sub_path = os.path.join(entry_path, sub_entry)
                if self._has_supported_ext(sub_path):
                    self.items.append((sub_path, label))
            self.synsets.append(label)

    @property
    def num_classes(self):
        """Number of sub-folder labels recorded in ``synsets``."""
        return len(self.synsets)

    @property
    def classes(self):
        """The integer labels recorded in ``synsets``."""
        return self.synsets
@obj()
class ImageFolderDataset(object):
    """A generic data loader where the images are arranged in this way on your local filesystem: ::

        root/dog/a.png
        root/dog/b.png
        root/dog/c.png
        root/cat/x.png
        root/cat/y.png
        root/cat/z.png

    Here, folder-names `dog` and `cat` are the class labels and the images with
    file-names `a`, `b`, `c` belong to the `dog` class while the others are `cat` images.
    If `root` has no sub-folders, every valid file under it gets class index 0.

    Parameters
    ----------
    root : string
        Root directory path to the folder containing all of the data.
    extensions : str or iterable of str (optional)
        Allowed file extensions. Defaults to `IMG_EXTENSIONS` when neither
        `extensions` nor `is_valid_file` is given. Must not be combined
        with `is_valid_file`.
    transform : callable (optional)
        A function/transform that takes in a PIL image
        and returns a transformed version. E.g, ``transforms.RandomCrop``
    is_valid_file : callable (optional)
        A function that takes the path of a file and returns whether it
        should be included (e.g. to filter out corrupt files).

    Attributes
    ----------
    classes : list
        List of the class names.
    class_to_idx : dict
        Dict with items (class_name, class_index).
    imgs : list
        List of (image path, class_index) tuples
    """
    _repr_indent = 4
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

    def __init__(self, root, extensions=None, transform=None, is_valid_file=None):
        root = os.path.expanduser(root)
        self.root = root
        if extensions:
            # str.endswith needs a str or a tuple, so normalize lists etc.
            extensions = (extensions,) if isinstance(extensions, str) else tuple(extensions)
        elif is_valid_file is None:
            extensions = self.IMG_EXTENSIONS
        else:
            # A custom validator replaces extension filtering. Defaulting
            # `extensions` here made `make_dataset` reject every call that
            # supplied `is_valid_file`.
            extensions = None
        self._transform = transform
        classes, class_to_idx = self._find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
        if len(samples) == 0:
            message = f"Found 0 files in subfolders of: {self.root} "
            if extensions:
                message += f"\nSupported extensions are: {','.join(extensions)}"
            raise RuntimeError(message)
        self.extensions = extensions
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]
        self.imgs = self.samples

    @staticmethod
    def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
        """Collect ``(absolute path, class index)`` pairs under `dir`.

        Exactly one of `extensions` / `is_valid_file` must be given. With
        `extensions`, a file is kept when its lower-cased name ends with one
        of them and PIL can open it.

        Raises
        ------
        ValueError
            If both or neither of `extensions` and `is_valid_file` are given.
        """
        images = []
        dir = os.path.expanduser(dir)
        if not ((extensions is None) ^ (is_valid_file is None)):
            raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
        if extensions is not None:
            suffixes = (extensions,) if isinstance(extensions, str) else tuple(extensions)

            def is_valid_file(x):
                if not x.lower().endswith(suffixes):
                    return False
                try:
                    with open(x, 'rb') as f:
                        Image.open(f)
                except OSError:
                    return False
                return True

        for target in sorted(class_to_idx.keys()):
            d = os.path.join(dir, target)
            if not os.path.isdir(d):
                continue
            for root, _, fnames in sorted(os.walk(d)):
                for fname in sorted(fnames):
                    path = os.path.abspath(os.path.join(root, fname))
                    if is_valid_file(path):
                        images.append((path, class_to_idx[target]))

        if not class_to_idx:
            # No class sub-folders: treat everything under `dir` as class 0.
            for root, _, fnames in sorted(os.walk(dir)):
                for fname in sorted(fnames):
                    path = os.path.abspath(os.path.join(root, fname))
                    if is_valid_file(path):
                        images.append((path, 0))
        return images

    @staticmethod
    def loader(path):
        """Open the image at `path` with PIL and return it converted to RGB."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    @staticmethod
    def _find_classes(dir):
        """Finds the class folders in a dataset.

        Parameters
        ----------
        dir : string
            Root directory path.

        Returns
        -------
        tuple: (classes, class_to_idx)
            where classes are the sorted sub-folder names of `dir`, and
            class_to_idx maps each name to its position in that list.
        """
        # The file already requires Python 3.6+ (f-strings), so the old
        # pre-3.5 fallback was dead; close the scandir iterator explicitly.
        with os.scandir(dir) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        class_to_idx = {name: index for index, name in enumerate(classes)}
        return classes, class_to_idx

    @property
    def num_classes(self):
        """Number of class folders found under the root."""
        return len(self.classes)

    def __getitem__(self, index):
        """
        Parameters
        ----------
        index : int
            Index

        Returns
        ----------
        tuple : (sample, target)
            where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self._transform is not None:
            sample = self._transform(sample)
        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        head = "Dataset " + self.__class__.__name__
        body = [f"Number of datapoints: {len(self)}"]
        if self.root is not None:
            body.append(f"Root location: {self.root}")
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)
def _cifar_transform(train):
    """Build the CIFAR transform: pad-crop + flip when training, then normalize."""
    import gluoncv.data.transforms as gcv_transforms

    steps = []
    if train:
        steps += [
            gcv_transforms.RandomCrop(32, pad=4),
            transforms.RandomFlipLeftRight(),
        ]
    steps += [
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
    ]
    return transforms.Compose(steps)


def _mnist_transform(data, label):
    """Cast to float32, move the last axis first (axes (2, 0, 1)) and divide by 255.

    Defined at module level (not as a closure) so it can be pickled.
    """
    return nd.transpose(data.astype(np.float32), (2, 0, 1)) / 255, label.astype(np.float32)


def get_built_in_dataset(name, train=True, input_size=224, batch_size=256, num_workers=32,
                         shuffle=True, **kwargs):
    """Return a built-in popular image classification dataset by name.

    Parameters
    ----------
    name : str
        One of 'cifar', 'cifar10', 'cifar100', 'mnist', 'fashionmnist',
        'imagenet' (case-insensitive).
    train : bool, default True
        Select the training split (True) or the validation/test split.
    input_size : int, default 224
        Input image size; only used for 'imagenet'.
    batch_size : int, default 256
        Only used for 'imagenet', forwarded to `get_data_rec`.
    num_workers : int, default 32
        Only used for 'imagenet', forwarded to `get_data_rec`.
    shuffle : bool, default True
        Only used for 'imagenet', forwarded to `get_data_rec`.
    **kwargs
        Only used for 'imagenet', forwarded to `get_data_rec`.

    Returns
    -------
    A gluon vision dataset for the CIFAR / MNIST families, or the value
    returned by `get_data_rec` for 'imagenet'.

    Raises
    ------
    NotImplementedError
        If `name` is not a supported built-in dataset.
    """
    logger.info('get_built_in_dataset %s', name)
    name = name.lower()

    if name in ('cifar10', 'cifar'):
        transform_split = _cifar_transform(train)
        return gluon.data.vision.CIFAR10(train=train).transform_first(transform_split)
    if name == 'cifar100':
        # NOTE(review): uses the same normalization constants as CIFAR-10,
        # as the original code did.
        transform_split = _cifar_transform(train)
        return gluon.data.vision.CIFAR100(train=train).transform_first(transform_split)
    if name == 'mnist':
        return gluon.data.vision.MNIST(train=train, transform=_mnist_transform)
    if name == 'fashionmnist':
        return gluon.data.vision.FashionMNIST(train=train, transform=_mnist_transform)
    if name == 'imagenet':
        # Please setup the ImageNet dataset following the tutorial from GluonCV
        split = 'train' if train else 'val'
        rec_file = f'/media/ramdisk/rec/{split}.rec'
        rec_file_idx = f'/media/ramdisk/rec/{split}.idx'
        return get_data_rec(input_size, 0.875, rec_file, rec_file_idx,
                            batch_size, num_workers, train, shuffle=shuffle,
                            **kwargs)
    raise NotImplementedError(
        f"Unknown built-in dataset {name!r}; supported names: {', '.join(built_in_datasets)}"
    )