Source code for autogluon.core.utils.serialization
"""Serilization and checkpoint"""
import logging
import difflib
import inspect
import os
import io
import shutil
import struct
import sys
import tarfile
import zipfile
import tempfile
import warnings
from contextlib import closing, contextmanager
if sys.version_info[0] == 2:
    import cPickle as pickle
    _string_classes = basestring
else:
    import pickle
    import pathlib
    _string_classes = (str, bytes)
__all__ = ['save', 'load']
DEFAULT_PROTOCOL = 2
LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size
logger = logging.getLogger(__name__)
MAGIC_NUMBER = 0x7df059597099bb7dcf25
PROTOCOL_VERSION = 1001
STORAGE_KEY_SEPARATOR = ','
class SourceChangeWarning(Warning):
    """Warned when a container class's current source differs from the source saved with it."""
    pass
_package_registry = []
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
    """Saves an object to local file.
    Parameters
    ----------
    obj : object
        Python object to save.
    f : string or file object
        A file-like object (has to implement write and flush), or a string containing a file name
    pickle_module : pickle
        Module used for pickling metadata and objects
    pickle_protocol : protocol (optional)
        Protocol can be specified to override the default pickle protocol.
    Examples
    --------
    >>> save(scheduler.state_dict(), checkname)
    """
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
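# A minimal round-trip sketch (not part of the module): because ``save``
# accepts any writable binary file-like object, an in-memory buffer works
# just as well as a path. ``io.BytesIO`` here is an illustration, not a
# requirement of the API:
#
#     >>> import io
#     >>> buf = io.BytesIO()
#     >>> save({'epoch': 3, 'best_reward': 0.92}, buf)
#     >>> _ = buf.seek(0)
#     >>> load(buf)
#     {'epoch': 3, 'best_reward': 0.92}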
def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
    """Loads an object saved with :func:`save` from file.
    Parameters
    ----------
        f (string or file object): a file-like object (has to implement write and flush)
            or a string containing a file name
    Examples
    --------
    >>> scheduler.load_state_dict(load(checkpoint))
    """
    new_fd = False
    if isinstance(f, str) or \
            (sys.version_info[0] == 2 and isinstance(f, unicode)):
        new_fd = True
        f = open(f, 'rb')
    elif (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
        new_fd = True
        f = f.open('rb')
    try:
        return _load(f, map_location, pickle_module, **pickle_load_args)
    finally:
        if new_fd:
            f.close()
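# Sketch of the three supported ``map_location`` forms; the file name and the
# 'gpu' tag below are hypothetical (only 'cpu' is registered in this module):
#
#     >>> load('checkpoint.ag', map_location={'gpu': 'cpu'})       # dict remap
#     >>> load('checkpoint.ag', map_location='cpu')                # fixed target
#     >>> load('checkpoint.ag', map_location=lambda s, loc: None)  # defer to default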
def _with_file_like(f, mode, body):
    """
    Executes a body function with a file object for f, opening
    it in 'mode' if it is a string filename.
    """
    new_fd = False
    if isinstance(f, str) or \
            (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
            (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
        new_fd = True
        f = open(f, mode)
    try:
        return body(f)
    finally:
        if new_fd:
            f.close()
@contextmanager
def mkdtemp():
    """Create a temporary directory and delete it when the block exits.
    Used by the legacy tar-based loader inside :func:`_load`."""
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        shutil.rmtree(path)
def register_package(priority, tagger, deserializer):
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort()
def _cpu_deserialize(obj, location):
    if location == 'cpu':
        return obj
def _cpu_tag(obj):
    if type(obj).__module__ == 'mxnet':
        return 'cpu'
register_package(10, _cpu_tag, _cpu_deserialize)
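# Sketch of registering an additional location handler; the names below are
# hypothetical and only illustrate the (priority, tagger, deserializer)
# contract: the tagger labels a storage at save time, and the deserializer
# either restores the object for its location or returns None to defer to
# lower-priority handlers:
#
#     def _gpu_tag(obj):                            # hypothetical
#         if getattr(obj, 'context', None) == 'gpu':
#             return 'gpu'
#     def _gpu_deserialize(obj, location):          # hypothetical
#         if location == 'gpu':
#             return obj
#     register_package(20, _gpu_tag, _gpu_deserialize)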
def default_restore_location(storage, location):
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError("don't know how to restore data location")
def _is_compressed_file(f):
    compress_modules = ['gzip']
    try:
        return f.__module__ in compress_modules
    except AttributeError:
        return False
def _should_read_directly(f):
    """
    Checks if f is a file that should be read directly. It should be read
    directly if it is backed by a real file (has a fileno) and is not a
    compressed file (e.g. gzip).
    """
    if _is_compressed_file(f):
        return False
    try:
        return f.fileno() >= 0
    except io.UnsupportedOperation:
        return False
    except AttributeError:
        return False
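# For example, an ``io.BytesIO`` buffer has no ``fileno()`` (calling it raises
# ``io.UnsupportedOperation``), so it is never read directly, whereas a file
# object returned by ``open()`` on a regular file is.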
def _check_seekable(f):
    def raise_err_msg(patterns, e):
        for p in patterns:
            if p in str(e):
                msg = (str(e) + ". You can only load from a file that is seekable." +
                                " Please pre-load the data into a buffer like io.BytesIO and" +
                                " try to load from it instead.")
                raise type(e)(msg)
        raise e
    try:
        f.seek(f.tell())
        return True
    except (io.UnsupportedOperation, AttributeError) as e:
        raise_err_msg(["seek", "tell"], e)
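# Sketch of the workaround the error message above suggests for non-seekable
# streams (``stream`` is a hypothetical socket- or pipe-backed file object):
#
#     >>> buffered = io.BytesIO(stream.read())
#     >>> obj = load(buffered)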
def _save(obj, f, pickle_module, pickle_protocol):
    if sys.version_info[0] == 2:
        import StringIO
        if isinstance(f, StringIO.StringIO):
            msg = ('save received unsupported StringIO.StringIO file object, whose '
                   'write method does not return the number of bytes written. '
                   'Please use something like io.BytesIO for save instead.')
            logger.error(msg)
            raise RuntimeError(msg)
    serialized_container_types = {}
    serialized_storages = {}
    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )
    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.dump(obj)
    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    f.flush()
    for key in serialized_storage_keys:
        serialized_storages[key]._write_file(f, _should_read_directly(f))
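# On-disk layout produced by _save, in order: MAGIC_NUMBER, PROTOCOL_VERSION,
# sys_info, the pickled object graph, the sorted list of storage keys, and
# finally the raw bytes of each referenced storage.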
def _load(f, map_location, pickle_module, **pickle_load_args):
    deserialized_objects = {}
    if map_location is None:
        restore_location = default_restore_location
    elif isinstance(map_location, dict):
        def restore_location(storage, location):
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    elif isinstance(map_location, _string_classes):
        def restore_location(storage, location):
            return default_restore_location(storage, map_location)
    else:
        def restore_location(storage, location):
            result = map_location(storage, location)
            if result is None:
                result = default_restore_location(storage, location)
            return result
    def _check_container_source(container_type, source_file, original_source):
        try:
            current_source = inspect.getsource(container_type)
        except Exception:  # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
            return
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            raise IOError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except IOError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute")
            msg = ("source code of class has changed. {}"
                   .format(msg))
            warnings.warn(msg, SourceChangeWarning)
    deserialized_objects = {}
    def maybe_decode_ascii(bytes_str):
        # When using encoding='bytes' in Py3, some **internal** keys stored as
        # strings in Py2 are loaded as bytes. This function decodes them with
        # ascii encoding, one that Py3 uses by default.
        #
        # NOTE: This should only be used on internal keys (e.g., `typename` and
        #       `location` in `persistent_load` below)!
        if isinstance(bytes_str, bytes):
            return bytes_str.decode('ascii')
        return bytes_str
    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]
        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            data_type, root_key, location, size, view_metadata = data
            location = maybe_decode_ascii(location)
            if root_key not in deserialized_objects:
                obj = data_type(size)
                deserialized_objects[root_key] = restore_location(obj, location)
            storage = deserialized_objects[root_key]
            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                if view_key not in deserialized_objects:
                    deserialized_objects[view_key] = storage[offset:offset + view_size]
                return deserialized_objects[view_key]
            else:
                return storage
        else:
            raise RuntimeError("Unknown saved id type: %s" % saved_id[0])
    def legacy_load(f):
        deserialized_objects = {}
        def persistent_load(saved_id):
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]
        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:
            tar.extract('storages', path=tmpdir)
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
                num_storages = pickle_module.load(f, **pickle_load_args)
                for i in range(num_storages):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, location, storage_type = args
                    obj = storage_type._new_with_file(f)
                    obj = restore_location(obj, location)
                    deserialized_objects[key] = obj
                storage_views = pickle_module.load(f, **pickle_load_args)
                for target_cdata, root_cdata, offset, size in storage_views:
                    root = deserialized_objects[root_cdata]
                    deserialized_objects[target_cdata] = root[offset:offset + size]
            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
                num_tensors = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    tensor_type = storage_to_tensor_type(storage)
                    ndim, = struct.unpack('<i', f.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    f.read(4)
                    size = struct.unpack('<{}q'.format(ndim), f.read(8 * ndim))
                    stride = struct.unpack('<{}q'.format(ndim), f.read(8 * ndim))
                    storage_offset, = struct.unpack('<q', f.read(8))
                    tensor = tensor_type().set_(storage, storage_offset, size, stride)
                    deserialized_objects[key] = tensor
            pickle_file = tar.extractfile('pickle')
            unpickler = pickle_module.Unpickler(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result
    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)
    if f_should_read_directly and f.tell() == 0:
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if zipfile.is_zipfile(f):
                raise RuntimeError("Please uncompress the file.")
            # if not a tarfile, reset file offset and proceed
            f.seek(0)
    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)
    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = pickle_module.Unpickler(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()
    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
    offset = f.tell() if f_should_read_directly else None
    for key in deserialized_storage_keys:
        assert key in deserialized_objects
        deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
        offset = None
    return result
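# Sketch: extra keyword arguments to ``load`` are forwarded to the pickle
# module, which is how a checkpoint written under Python 2 can be read under
# Python 3 (the file name here is hypothetical):
#
#     >>> state = load('py2_checkpoint.ag', encoding='latin1')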
