# Source code for autogluon.core.scheduler.rl_scheduler

import os
import json
import time
import pickle
import logging
import threading
import multiprocessing as mp
from collections import OrderedDict

from .fifo import FIFOScheduler
from .reporter import DistStatusReporter
from .resource import DistributedResource
from .. import Task
from ..decorator import _autogluon_method
from ..searcher import RLSearcher
from ..utils import load, tqdm, try_import_mxnet
from ..utils.default_arguments import check_and_merge_defaults, \
    Integer, Boolean, Float, String, filter_by_key

# Names exported by ``from ... import *``.
__all__ = ['RLScheduler']

logger = logging.getLogger(__name__)


# Option names handed to ``filter_by_key`` when RLScheduler forwards its
# keyword arguments to the parent constructor.
_ARGUMENT_KEYS = {
    'controller_lr',
    'ema_baseline_decay',
    'controller_resource',
    'controller_batch_size',
    'sync',
}

# Values imputed by ``check_and_merge_defaults`` for options the caller did
# not supply.
_DEFAULT_OPTIONS = {
    'resume': False,
    'reward_attr': 'accuracy',
    'checkpoint': './exp/checkpoint.ag',
    'controller_lr': 1e-3,
    'ema_baseline_decay': 0.95,
    'controller_resource': {'num_cpus': 0, 'num_gpus': 0},
    'controller_batch_size': 1,
    'sync': True,
}

# Type/range constraints checked by ``check_and_merge_defaults``.
# 'controller_resource' has no entry here.
_CONSTRAINTS = {
    'resume': Boolean(),
    'reward_attr': String(),
    'checkpoint': String(),
    'controller_lr': Float(0.0, None),
    'ema_baseline_decay': Float(0.0, 1.0),
    'controller_batch_size': Integer(1, None),
    'sync': Boolean(),
}


class RLScheduler(FIFOScheduler):
    r"""Scheduler that uses Reinforcement Learning with a LSTM controller
    created based on the provided search spaces

    Parameters
    ----------
    train_fn : callable
        A task launch function for training. Note: please add the `@ag.args`
        decorator to the original function.
    args : object (optional)
        Default arguments for launching train_fn.
    resource : dict
        Computation resources. For example, `{'num_cpus':2, 'num_gpus':1}`
    searcher : object (optional)
        Must be an `RLSearcher`. Any other value is ignored (with a warning)
        and an `RLSearcher` is created from `train_fn.kwspaces`.
    time_attr : str
        A training result attr to use for comparing time.
        Note that you can pass in something non-temporal such as
        `training_epoch` as a measure of progress, the only requirement
        is that the attribute should increase monotonically.
    reward_attr : str
        The training result objective value attribute. As with `time_attr`,
        this may refer to any objective value. Stopping procedures will use
        this attribute. Default: 'accuracy'.
    controller_resource : dict
        Resources reserved on the master node for training the controller.
        Default: `{'num_cpus': 0, 'num_gpus': 0}`. With `num_gpus > 0` the
        controller is placed on the reserved GPUs, otherwise on CPU.
    controller_batch_size : int
        Number of configurations sampled per controller update in sync mode,
        and the maximum number of concurrently running trials in async mode.
        Default: 1.
    controller_lr : float
        Base learning rate of the controller's Adam optimizer; the effective
        learning rate is `controller_lr * controller_batch_size`.
        Default: 1e-3.
    ema_baseline_decay : float
        Decay of the exponential moving average used as reward baseline.
        Default: 0.95.
    sync : bool
        If True (default), trials are run in synchronized batches; otherwise
        every trial runs in its own thread and updates the controller
        independently.
    resume : bool
        If True, the scheduler state is loaded from `checkpoint` at the end
        of construction. Default: False.
    checkpoint : str
        Checkpoint file path. Default: './exp/checkpoint.ag'.
    dist_ip_addrs : list of str
        IP addresses of remote machines.

    Examples
    --------
    >>> import numpy as np
    >>> import autogluon.core as ag
    >>>
    >>> @ag.args(
    ...     lr=ag.space.Real(1e-3, 1e-2, log=True),
    ...     wd=ag.space.Real(1e-3, 1e-2))
    >>> def train_fn(args, reporter):
    ...     print('lr: {}, wd: {}'.format(args.lr, args.wd))
    ...     for e in range(10):
    ...         dummy_accuracy = 1 - np.power(1.8, -np.random.uniform(e, 2*e))
    ...         reporter(epoch=e+1, accuracy=dummy_accuracy, lr=args.lr, wd=args.wd)
    ...
    >>> scheduler = ag.scheduler.RLScheduler(train_fn,
    ...                                      resource={'num_cpus': 2, 'num_gpus': 0},
    ...                                      num_trials=20,
    ...                                      reward_attr='accuracy',
    ...                                      time_attr='epoch')
    >>> scheduler.run()
    >>> scheduler.join_jobs()
    >>> scheduler.get_training_curves(plot=True)
    """

    def __init__(self, train_fn, **kwargs):
        try_import_mxnet()
        import mxnet as mx
        assert isinstance(train_fn, _autogluon_method), 'Please use @ag.args ' + \
            'to decorate your training script.'
        # Check values and impute default values (only for arguments new to
        # this class)
        kwargs = check_and_merge_defaults(
            kwargs, set(), _DEFAULT_OPTIONS, _CONSTRAINTS,
            dict_name='scheduler_options')
        resume = kwargs['resume']
        self.ema_baseline_decay = kwargs['ema_baseline_decay']
        self.sync = kwargs['sync']
        # create RL searcher if not passed
        searcher = kwargs.get('searcher')
        if not isinstance(searcher, RLSearcher):
            if searcher is not None:
                logger.warning("Argument 'searcher' must be of type RLSearcher. Ignoring 'searcher' and creating searcher here.")
            kwargs['searcher'] = RLSearcher(
                train_fn.kwspaces, reward_attribute=kwargs['reward_attr'])
        # Pass resume=False here. Resume needs members of this object to be
        # created
        kwargs['resume'] = False
        super().__init__(
            train_fn=train_fn, **filter_by_key(kwargs, _ARGUMENT_KEYS))
        # reserve controller computation resource on master node
        master_node = self.managers.remote_manager.get_master_node()
        controller_resource = kwargs['controller_resource']
        self.controller_resource = DistributedResource(**controller_resource)
        # The reservation has side effects, so it must not live inside an
        # `assert` expression (asserts are stripped under `python -O`).
        reserved = self.managers.resource_manager.reserve_resource(
            master_node, self.controller_resource)
        if not reserved:
            # AssertionError is kept for backward compatibility with callers.
            raise AssertionError(
                "Not Enough Resource on Master Node for Training Controller")
        if controller_resource['num_gpus'] > 0:
            self.controller_ctx = [
                mx.gpu(i) for i in self.controller_resource.gpu_ids]
        else:
            self.controller_ctx = [mx.cpu()]
        # controller setup
        self.controller = self.searcher.controller
        self.controller.collect_params().reset_ctx(self.controller_ctx)
        controller_batch_size = kwargs['controller_batch_size']
        learning_rate = kwargs['controller_lr'] * controller_batch_size
        self.controller_optimizer = mx.gluon.Trainer(
            self.controller.collect_params(), 'adam',
            optimizer_params={'learning_rate': learning_rate})
        self.controller_batch_size = controller_batch_size
        self.baseline = None
        self.lock = mp.Lock()
        # async buffers
        if not self.sync:
            self.mp_count = mp.Value('i', 0)  # trials currently in flight
            self.mp_seed = mp.Value('i', 0)   # per-trial random seed counter
            self.mp_fail = mp.Value('i', 0)   # failed controller updates
        if resume:
            checkpoint = kwargs.get('checkpoint')
            if os.path.isfile(checkpoint):
                self.load_state_dict(load(checkpoint))
            else:
                msg = 'checkpoint path {} is not available for resume.'.format(checkpoint)
                # Not inside an exception handler, so use error() rather
                # than exception() (which would append 'NoneType: None').
                logger.error(msg)

    def run(self, **kwargs):
        """Run multiple number of trials

        Parameters
        ----------
        num_trials : int (optional)
            Overrides `self.num_trials` for this and subsequent runs.
        """
        self.num_trials = kwargs.get('num_trials', self.num_trials)
        logger.info('Starting Experiments')
        logger.info('Num of Finished Tasks is {}'.format(self.num_finished_tasks))
        logger.info('Num of Pending Tasks is {}'.format(self.num_trials - self.num_finished_tasks))
        if self.sync:
            self._run_sync()
        else:
            self._run_async()

    def _run_sync(self):
        """Run all trials in synchronized batches, updating the controller
        with a policy-gradient step after every batch.
        """
        try_import_mxnet()
        import mxnet as mx
        decay = self.ema_baseline_decay
        # Full batches followed by one smaller batch holding the remainder.
        # (The original computed the remainder as num_trials % num_trials,
        # which is always 0, so the remaining trials were never run.)
        num_full_batches, remainder = divmod(
            self.num_trials, self.controller_batch_size)
        batch_sizes = [self.controller_batch_size] * num_full_batches
        if remainder:
            batch_sizes.append(remainder)

        for batch_size in tqdm(batch_sizes):
            with mx.autograd.record():
                # sample batch_size number of configurations
                configs, log_probs, entropies = self.controller.sample(
                    batch_size, with_details=True)
                # schedule the training tasks and gather the reward
                rewards = self.sync_schedule_tasks(configs)
                # subtract baseline
                if self.baseline is None:
                    self.baseline = rewards[0]
                avg_rewards = mx.nd.array(
                    [reward - self.baseline for reward in rewards],
                    ctx=self.controller.context)
                # EMA baseline
                for reward in rewards:
                    self.baseline = decay * self.baseline + (1 - decay) * reward
                # negative policy gradient
                log_probs = log_probs.sum(axis=1)
                loss = - log_probs * avg_rewards
                loss = loss.sum()  # or loss.mean()

            # update
            loss.backward()
            self.controller_optimizer.step(batch_size)
            logger.debug('controller loss: {}'.format(loss.asscalar()))

    def _run_async(self):
        """Run every trial in its own thread, with at most
        `controller_batch_size` trials in flight; each trial performs its
        own controller update when it finishes.
        """
        try_import_mxnet()
        import mxnet as mx

        def _async_run_trial():
            # The launcher loop below has already counted this trial in
            # mp_count; the slot is always released in `finally`, otherwise
            # a failing trial would stall the launcher forever.
            try:
                with self.mp_seed.get_lock():
                    self.mp_seed.value += 1
                    seed = self.mp_seed.value
                mx.random.seed(seed)
                with mx.autograd.record():
                    # sample one configuration
                    with self.lock:
                        config, log_prob, entropy = self.controller.sample(with_details=True)
                    config = config[0]
                    task = Task(self.train_fn, {'args': self.args, 'config': config},
                                DistributedResource(**self.resource))
                    # start training task
                    reporter = DistStatusReporter(remote=task.resources.node)
                    task.args['reporter'] = reporter
                    task_thread = self.add_job(task)

                    # run reporter
                    last_result = None
                    config = task.args['config']
                    while task_thread.is_alive():
                        reported_result = reporter.fetch()
                        if reported_result.get('done', False):
                            reporter.move_on()
                            task_thread.join()
                            break
                        self._add_training_result(task.task_id, reported_result,
                                                  task.args['config'])
                        reporter.move_on()
                        last_result = reported_result

                    if last_result is None:
                        # Nothing was reported, so there is no reward to
                        # learn from; skip the controller update.
                        logger.warning(
                            'Task %s reported no result; skipping controller update.',
                            task.task_id)
                        return

                    self.searcher.update(config, **last_result)
                    reward = last_result[self._reward_attr]
                    with self.lock:
                        if self.baseline is None:
                            self.baseline = reward
                    avg_reward = mx.nd.array([reward - self.baseline],
                                             ctx=self.controller.context)
                    # negative policy gradient
                    with self.lock:
                        loss = -log_prob * avg_reward.reshape(-1, 1)
                        loss = loss.sum()

                # update
                logger.debug('loss %s', loss)
                with self.lock:
                    try:
                        loss.backward()
                        self.controller_optimizer.step(1)
                    except Exception:
                        self.mp_fail.value += 1
                        logger.warning('Exception during backward {}.'.format(self.mp_fail.value))

                # ema
                with self.lock:
                    decay = self.ema_baseline_decay
                    self.baseline = decay * self.baseline + (1 - decay) * reward
            finally:
                with self.mp_count.get_lock():
                    self.mp_count.value -= 1

        reporter_threads = []
        for _ in range(self.num_trials):
            # Wait for a free slot and claim it here, in the launcher, so
            # that a new thread cannot be started before the previous one
            # has been counted. `+=` on a Value is not atomic, hence the
            # explicit lock.
            while True:
                with self.mp_count.get_lock():
                    if self.mp_count.value < self.controller_batch_size:
                        self.mp_count.value += 1
                        break
                time.sleep(0.2)
            reporter_thread = threading.Thread(target=_async_run_trial)
            reporter_thread.start()
            reporter_threads.append(reporter_thread)

        for p in reporter_threads:
            p.join()

    def sync_schedule_tasks(self, configs):
        """Launch one task per configuration, wait for all of them and
        return their rewards.

        Parameters
        ----------
        configs : iterable
            Configurations to evaluate.

        Returns
        -------
        list
            The last reported `reward_attr` value of every configuration, in
            the order of `configs`.

        Raises
        ------
        KeyError
            If a configuration's task reported no result (e.g. it reported a
            traceback first), since no reward is recorded for it.
        """
        rewards = []
        results = {}

        def _run_reporter(task, task_job, reporter):
            # Relay the task's reports into the training history and keep
            # the last one as the task's final result.
            last_result = None
            config = task.args['config']
            while not task_job.done():
                reported_result = reporter.fetch()
                if 'traceback' in reported_result:
                    # Not inside an exception handler: error(), not
                    # exception().
                    logger.error(reported_result['traceback'])
                    reporter.move_on()
                    break
                if reported_result.get('done', False):
                    reporter.move_on()
                    break
                self._add_training_result(task.task_id, reported_result,
                                          task.args['config'])
                reporter.move_on()
                last_result = reported_result
            if last_result is not None:
                self.searcher.update(config, **last_result)
                with self.lock:
                    results[pickle.dumps(config)] = \
                        last_result[self._reward_attr]

        # launch the tasks
        tasks = []
        task_jobs = []
        reporter_threads = []
        for config in configs:
            logger.debug('scheduling config: {}'.format(config))
            # create task
            task = Task(self.train_fn, {'args': self.args, 'config': config},
                        DistributedResource(**self.resource))
            reporter = DistStatusReporter()
            task.args['reporter'] = reporter
            task_job = self.add_job(task)
            # run reporter
            reporter_thread = threading.Thread(
                target=_run_reporter, args=(task, task_job, reporter))
            reporter_thread.start()
            tasks.append(task)
            task_jobs.append(task_job)
            reporter_threads.append(reporter_thread)

        for p1, p2 in zip(task_jobs, reporter_threads):
            p1.result()
            p2.join()
        with self.managers.lock:
            for task in tasks:
                self.finished_tasks.append({'TASK_ID': task.task_id,
                                            'Config': task.args['config']})
        if self._checkpoint is not None:
            logger.debug('Saving Checkerpoint')
            self.save()

        for config in configs:
            rewards.append(results[pickle.dumps(config)])

        return rewards

    def add_job(self, task, **kwargs):
        """Adding a training task to the scheduler.

        Args:
            task (:class:`autogluon.scheduler.Task`): a new training task

        Returns:
            The job object returned by `start_distributed_job`.
        """
        cls = RLScheduler
        cls.managers.request_resources(task.resources)
        # main process
        job = cls.jobs.start_distributed_job(task, cls.managers)
        return job

    def join_tasks(self):
        """No-op in this scheduler."""
        pass

    def state_dict(self, destination=None):
        """Returns a dictionary containing a whole state of the Scheduler

        Parameters
        ----------
        destination : dict (optional)
            Dictionary to fill; a new `OrderedDict` is created when omitted.

        Examples
        --------
        >>> ag.save(scheduler.state_dict(), 'checkpoint.ag')
        """
        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()
        logger.debug('\nState_Dict self.finished_tasks: {}'.format(self.finished_tasks))
        destination['finished_tasks'] = pickle.dumps(self.finished_tasks)
        destination['baseline'] = pickle.dumps(self.baseline)
        destination['TASK_ID'] = Task.TASK_ID.value
        destination['searcher'] = self.searcher.state_dict()
        destination['training_history'] = json.dumps(self.training_history)
        if self.visualizer == 'mxboard' or self.visualizer == 'tensorboard':
            destination['visualizer'] = json.dumps(self.mxboard._scalar_dict)
        return destination

    def load_state_dict(self, state_dict):
        """Load from the saved state dict.

        Examples
        --------
        >>> scheduler.load_state_dict(ag.load('checkpoint.ag'))
        """
        # NOTE: the entries are unpickled; only load checkpoints from a
        # trusted source.
        self.finished_tasks = pickle.loads(state_dict['finished_tasks'])
        # NOTE(review): 'baseline' is written by state_dict() but is not
        # restored here (the restore was disabled in the original code), so
        # after a resume the baseline is re-initialised from the first
        # reward. Confirm whether this is intended.
        Task.set_id(state_dict['TASK_ID'])
        self.searcher.load_state_dict(state_dict['searcher'])
        self.training_history = json.loads(state_dict['training_history'])
        if self.visualizer == 'mxboard' or self.visualizer == 'tensorboard':
            self.mxboard._scalar_dict = json.loads(state_dict['visualizer'])
        logger.debug('Loading Searcher State {}'.format(self.searcher))