"""Source code for ``autogluon.core.utils.files``: zip extraction, file download, and SHA-1 verification helpers."""
import contextlib
import os
from pathlib import Path
import requests
import shutil
import hashlib
import zipfile
import logging
from tqdm import tqdm
import tempfile
# Module-level logger; download() reports what it is fetching through it.
logger = logging.getLogger(__name__)
# Names exported by ``from <module> import *``. check_sha1 and make_temp_directory
# are still importable by name but are not re-exported here.
__all__ = ['unzip', 'download']
def unzip(zip_file_path, root=os.path.expanduser('./')):
    """Unzips files located at `zip_file_path` into parent directory specified by `root`.

    Parameters
    ----------
    zip_file_path : str or path-like or file-like
        Archive to extract; anything accepted by :class:`zipfile.ZipFile`.
    root : str, optional
        Directory to extract into. Defaults to the current directory.

    Returns
    -------
    str or tuple of str
        The distinct first path components of the archive members, in the order
        they first appear in the archive. A single string when every member
        shares one top-level entry, otherwise a tuple (empty for an empty archive).
    """
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(root)
        # dict.fromkeys keeps first-seen order while de-duplicating in O(n),
        # instead of an O(n^2) "not in list" scan over every member name.
        top_level = list(dict.fromkeys(Path(name).parts[0] for name in zf.namelist()))
    return top_level[0] if len(top_level) == 1 else tuple(top_level)
def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download files from a given URL.

    Parameters
    ----------
    url : str
        URL where file is located
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url. If `path` is an existing
        directory, the file is stored inside it under the url's last component.
    overwrite : bool, optional
        Whether to overwrite destination file if one already exists at this location.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits (will ignore existing file when hash is specified
        but doesn't match).

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    RuntimeError
        If the server responds with a status code other than 200.
    UserWarning
        If `sha1_hash` is given and the downloaded content does not match it.
        The mismatching file is left at the destination path.
    """
    import uuid

    if path is None:
        fname = url.split('/')[-1]
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path

    # Reuse an existing file unless asked to overwrite or its hash is wrong.
    if not (overwrite or not os.path.exists(fname)
            or (sha1_hash and not check_sha1(fname, sha1_hash))):
        return fname

    dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
    # exist_ok avoids failing when another process creates the directory first.
    os.makedirs(dirname, exist_ok=True)
    logger.info('Downloading %s from %s...', fname, url)
    # The context manager closes the streamed response and releases its connection.
    with requests.get(url, stream=True) as r:
        if r.status_code != 200:
            raise RuntimeError("Failed downloading url %s"%url)
        # Stream into a uniquely-named sibling file and move it into place only
        # once complete, so an interrupted transfer never leaves a truncated file
        # at `fname` that a later call (without sha1_hash) would silently reuse.
        tmp_fname = '{}.{}.part'.format(fname, uuid.uuid4().hex)
        try:
            with open(tmp_fname, 'wb') as f:
                total_length = r.headers.get('content-length')
                chunks = r.iter_content(chunk_size=1024)
                if total_length is not None:
                    # Progress is counted in 1 KB chunks, rounded to the nearest chunk.
                    chunks = tqdm(chunks,
                                  total=int(int(total_length) / 1024. + 0.5),
                                  unit='KB', unit_scale=False, dynamic_ncols=True)
                for chunk in chunks:
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
            os.replace(tmp_fname, fname)
        except BaseException:
            with contextlib.suppress(OSError):
                os.remove(tmp_fname)
            raise
    if sha1_hash and not check_sha1(fname, sha1_hash):
        raise UserWarning('File {} is downloaded but the content hash does not match. '
                          'The repo may be outdated or download may be incomplete. '
                          'If the "repo_url" is overridden, consider switching to '
                          'the default repo.'.format(fname))
    return fname
def check_sha1(filename, sha1_hash):
    """Return whether the SHA-1 digest of a file's bytes equals an expected value.

    The file is hashed in 1 MiB pieces, so large files are never loaded into
    memory all at once.

    Parameters
    ----------
    filename : str
        Path to the file to hash.
    sha1_hash : str
        Expected SHA-1 digest as hexadecimal digits (compared exactly).

    Returns
    -------
    bool
        True when the file's hex digest equals `sha1_hash`.
    """
    piece_size = 1048576  # 1 MiB per read
    digest = hashlib.sha1()
    with open(filename, 'rb') as stream:
        # iter() with a b'' sentinel stops at end of file.
        for piece in iter(lambda: stream.read(piece_size), b''):
            digest.update(piece)
    return digest.hexdigest() == sha1_hash
@contextlib.contextmanager
def make_temp_directory():
    """Yield the path of a newly created temporary directory.

    The directory is created with ``tempfile.mkdtemp()`` on entry. On exit it is
    removed, together with its contents, via ``shutil.rmtree``. Removal happens
    whether the ``with`` body completes normally or raises.
    """
    scratch_dir = tempfile.mkdtemp()
    with contextlib.ExitStack() as cleanup:
        # Registered callbacks run on exit even when the body raises.
        cleanup.callback(shutil.rmtree, scratch_dir)
        yield scratch_dir
