# Source code for autogluon.core.utils.files

import contextlib
import os
from pathlib import Path
import requests
import shutil
import hashlib
import zipfile
import logging
from tqdm import tqdm
import tempfile

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Names exported by ``from <module> import *``. The other module-level helpers
# (check_sha1, make_temp_directory) are not listed here.
__all__ = ['unzip', 'download']


def unzip(zip_file_path, root=os.path.expanduser('./')):
    """Unzips files located at `zip_file_path` into parent directory specified by `root`.

    Parameters
    ----------
    zip_file_path : str
        Path to the zip archive to extract.
    root : str, optional
        Directory the archive contents are extracted into (defaults to './').

    Returns
    -------
    str or tuple of str
        The first path component of the archive entries. If every entry shares
        a single top-level component, that component is returned as a str.
        Otherwise a tuple of the distinct top-level components is returned, in
        the order they first appear in the archive (empty for an empty archive).
    """
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(root)
        # dict.fromkeys de-duplicates in O(n) while preserving first-seen order,
        # instead of a linear ``in`` scan over a list for every entry.
        folders = list(dict.fromkeys(Path(name).parts[0] for name in zf.namelist()))
    return folders[0] if len(folders) == 1 else tuple(folders)
def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download files from a given URL.

    Parameters
    ----------
    url : str
        URL where file is located
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url. If `path` is an existing
        directory, the file is stored inside it under the URL's last path segment.
    overwrite : bool, optional
        Whether to overwrite destination file if one already exists at this location.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits (will ignore existing file when
        hash is specified but doesn't match).

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    RuntimeError
        If the server answers with a status code other than 200.
    UserWarning
        If `sha1_hash` is given and the downloaded content does not match it.
    """
    if path is None:
        fname = url.split('/')[-1]
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path

    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        # exist_ok avoids failing if the directory appears between check and create.
        os.makedirs(dirname, exist_ok=True)

        logger.info('Downloading %s from %s...', fname, url)
        # Using the response as a context manager releases the connection even
        # if the status check or the file write below raises.
        with requests.get(url, stream=True) as r:
            if r.status_code != 200:
                raise RuntimeError("Failed downloading url %s" % url)
            total_length = r.headers.get('content-length')
            chunks = r.iter_content(chunk_size=1024)
            if total_length is not None:
                # Only show a progress bar when the size is known.
                chunks = tqdm(chunks, total=int(int(total_length) / 1024. + 0.5),
                              unit='KB', unit_scale=False, dynamic_ncols=True)
            try:
                with open(fname, 'wb') as f:
                    for chunk in chunks:
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
            except BaseException:
                # Remove the partial file: otherwise a later call without
                # ``overwrite``/``sha1_hash`` would treat it as a complete download.
                with contextlib.suppress(OSError):
                    os.remove(fname)
                raise

        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning('File {} is downloaded but the content hash does not match. '
                              'The repo may be outdated or download may be incomplete. '
                              'If the "repo_url" is overridden, consider switching to '
                              'the default repo.'.format(fname))
    return fname
def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits. Surrounding whitespace is
        ignored and the comparison is case-insensitive.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        # Hash in 1 MiB chunks so large files are never read into memory at once.
        for data in iter(lambda: f.read(1048576), b''):
            sha1.update(data)
    # hexdigest() is always lower-case; normalise the expected value to match.
    # Non-str values are compared as-is (and therefore never match).
    expected = sha1_hash.strip().lower() if isinstance(sha1_hash, str) else sha1_hash
    return sha1.hexdigest() == expected


@contextlib.contextmanager
def make_temp_directory():
    """Create a temporary directory and recursively delete it on exit.

    Yields
    ------
    str
        Path of the directory created by :func:`tempfile.mkdtemp`. The
        directory and its contents are removed with :func:`shutil.rmtree`
        when the ``with`` block exits, whether normally or via an exception.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        yield temp_dir
    finally:
        shutil.rmtree(temp_dir)