TabularPredictor.distill

TabularPredictor.distill(train_data=None, tuning_data=None, augmentation_data=None, time_limit=None, hyperparameters=None, holdout_frac=None, teacher_preds='soft', augment_method='spunge', augment_args={'max_size': 100000, 'size_factor': 5}, models_name_suffix=None, verbosity=None)[source]

[EXPERIMENTAL] Distill AutoGluon's most accurate ensemble-predictor into single models that are simpler/faster and require less memory/compute. Distillation can produce a model that is more accurate than the same model fit directly on the original training data. After calling distill(), additional models will be available in this Predictor, which can be evaluated using predictor.leaderboard(test_data) and deployed with: predictor.predict(test_data, model=MODEL_NAME). This will raise an exception if cache_data=False was previously set in fit().

NOTE: Until catboost v0.24 is released, distill() with CatBoost students in multiclass classification requires you to first install catboost-dev: pip install catboost-dev

Parameters:
  • train_data (str or TabularDataset or pd.DataFrame, default = None) – Same as the train_data argument of fit(). If None, the same training data will be loaded from the fit() call used to produce this Predictor.

  • tuning_data (str or TabularDataset or pd.DataFrame, default = None) – Same as the tuning_data argument of fit(). If tuning_data = None and train_data = None: the same training/validation splits will be loaded from the fit() call used to produce this Predictor, unless bagging/stacking was previously used, in which case a new training/validation split is performed.

  • augmentation_data (TabularDataset or pd.DataFrame, default = None) – An optional extra dataset of unlabeled rows that can be used to augment the dataset on which student models are fit during distillation (ignored if None).

  • time_limit (int, default = None) – Approximately how long (in seconds) the distillation process should run for. If None, no time constraint will be enforced, allowing the distilled models to fully train.

  • hyperparameters (dict or str, default = None) – Specifies which models to use as students and what hyperparameter values to use for them. Same as the hyperparameters argument of fit(). If None, student models will use the same hyperparameters as in the fit() call used to produce this Predictor. Note: distillation is currently only supported for ['GBM', 'NN_TORCH', 'RF', 'CAT'] student models; other models and their hyperparameters are ignored here. See the example following this parameter list.

  • holdout_frac (float) – Same as holdout_frac argument of TabularPredictor.fit().

  • teacher_preds (str, default = 'soft') –

    What form of teacher predictions to distill from (the teacher refers to the most accurate AutoGluon ensemble-predictor). Options include:

    None : only train with the original labels (no data augmentation).
    'hard' : labels are hard teacher predictions given by teacher.predict().
    'soft' : labels are soft teacher predictions given by teacher.predict_proba().
    'onehot' : apply label-smoothing, using the original training labels converted to one-hot vectors for multiclass problems (no data augmentation).

    Note: 'hard' and 'soft' are equivalent for regression problems. If augment_method is not None, teacher predictions are only used to label the augmented data (the training data keeps its original labels).

  • augment_method (str, default = 'spunge') –

    Specifies the method to use for generating augmented data for distilling student models. Options include:

    None : no data augmentation performed.
    'munge' : The MUNGE algorithm (https://www.cs.cornell.edu/~caruana/compression.kdd06.pdf).
    'spunge' : A simpler, more efficient variant of the MUNGE algorithm.

  • augment_args (dict, default = {'size_factor': 5, 'max_size': int(1e5)}) –

    Contains the following kwargs that control the chosen augment_method (these are ignored if augment_method=None):

    'num_augmented_samples' : int, number of augmented datapoints used during distillation. Overrides 'size_factor' and 'max_size' if specified.
    'max_size' : float, the maximum number of augmented datapoints to add (ignored if 'num_augmented_samples' is specified).
    'size_factor' : float, if n = training data sample size, int(n * size_factor) augmented datapoints are added, up to 'max_size'.

    Larger values in augment_args will slow down the runtime of distill() and may produce worse results if the provided time_limit is too small. You can also pass in kwargs for the spunge_augment and munge_augment functions in autogluon.tabular.augmentation.distill_utils.

  • models_name_suffix (str, default = None) – Optional suffix appended to the end of all distilled student models' names. Note: all distilled models will contain the '_DSTL' substring in their names by default.

  • verbosity (int, default = None) – Controls the amount of printed output during distillation (4 = highest, 0 = lowest). Same as the verbosity parameter of TabularPredictor. If None, the verbosity used in the previous fit is reused.
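
For illustration, several of these arguments can be combined in a single distill() call. The sketch below uses arbitrary example values (the choice of student models, augmentation settings, time limit, and name suffix are illustrative, not recommendations):

>>> student_names = predictor.distill(
...     hyperparameters={'GBM': {}, 'NN_TORCH': {}},         # train only GBM and NN_TORCH students
...     teacher_preds='soft',                                # distill from teacher.predict_proba() outputs
...     augment_method='spunge',                             # SPUNGE data augmentation
...     augment_args={'size_factor': 3, 'max_size': 50000},  # add int(3 * n) augmented rows, capped at 50000
...     time_limit=3600,                                     # roughly one hour for the whole distillation
...     models_name_suffix='custom',                         # optional suffix appended to student model names
... )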

Return type:

List of names (str) corresponding to the distilled models.

Examples

>>> from autogluon.tabular import TabularDataset, TabularPredictor
>>> train_data = TabularDataset('train.csv')
>>> predictor = TabularPredictor(label='class').fit(train_data, auto_stack=True)
>>> distilled_model_names = predictor.distill()  # fit distilled student models using default settings
>>> test_data = TabularDataset('test.csv')
>>> ldr = predictor.leaderboard(test_data)  # compare all models (teacher ensemble and distilled students) on test data
>>> model_to_deploy = distilled_model_names[0]
>>> predictor.predict(test_data, model=model_to_deploy)
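
To work with just the distilled students afterwards, the leaderboard DataFrame can be filtered on the default '_DSTL' naming substring; the snippet below assumes the default naming scheme and the standard 'model' column of the leaderboard output:

>>> distilled_ldr = ldr[ldr['model'].str.contains('_DSTL')]    # rows for distilled students only
>>> predictor.predict_proba(test_data, model=model_to_deploy)  # class probabilities from the deployed student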