AutoGluon Tabular - In Depth


Tip: If you are new to AutoGluon, review Predicting Columns in a Table - Quick Start to learn the basics of the AutoGluon API. To learn how to add your own custom models to the set that AutoGluon trains, tunes, and ensembles, review Adding a custom model to AutoGluon.

This tutorial describes how you can exert greater control when using AutoGluon’s fit() or predict(). Recall that to maximize predictive performance, you should first try TabularPredictor() and fit() with all default arguments. Then, consider non-default arguments for TabularPredictor(eval_metric=...), and fit(presets=...). Later, you can experiment with other arguments to fit() covered in this in-depth tutorial like hyperparameter_tune_kwargs, hyperparameters, num_stack_levels, num_bag_folds, num_bag_sets, etc.

Using the same census data table as in the Predicting Columns in a Table - Quick Start tutorial, we’ll now predict the occupation of an individual - a multiclass classification problem. Start by importing AutoGluon’s TabularPredictor and TabularDataset, and loading the data.

from autogluon.tabular import TabularDataset, TabularPredictor

import numpy as np

train_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')
subsample_size = 1000  # subsample subset of data for faster demo, try setting this to much larger values
train_data = train_data.sample(n=subsample_size, random_state=0)
print(train_data.head())

label = 'occupation'
print("Summary of occupation column: \n", train_data['occupation'].describe())

test_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')
y_test = test_data[label]
test_data_nolabel = test_data.drop(columns=[label])  # delete label column

metric = 'accuracy' # we specify eval-metric just for demo (unnecessary as it's the default)
age workclass  fnlwgt      education  education-num  \
6118    51   Private   39264   Some-college             10   
23204   58   Private   51662           10th              6   
29590   40   Private  326310   Some-college             10   
18116   37   Private  222450        HS-grad              9   
33964   62   Private  109190      Bachelors             13   

            marital-status        occupation    relationship    race      sex  \
6118    Married-civ-spouse   Exec-managerial            Wife   White   Female   
23204   Married-civ-spouse     Other-service            Wife   White   Female   
29590   Married-civ-spouse      Craft-repair         Husband   White     Male   
18116        Never-married             Sales   Not-in-family   White     Male   
33964   Married-civ-spouse   Exec-managerial         Husband   White     Male   

       capital-gain  capital-loss  hours-per-week  native-country   class  
6118              0             0              40   United-States    >50K  
23204             0             0               8   United-States   <=50K  
29590             0             0              44   United-States   <=50K  
18116             0          2339              40     El-Salvador   <=50K  
33964         15024             0              40   United-States    >50K  
Summary of occupation column: 
 count              1000
unique               15
top        Craft-repair
freq                142
Name: occupation, dtype: object

Specifying hyperparameters and tuning them

Note: We don’t recommend hyperparameter tuning with AutoGluon in most cases. AutoGluon typically achieves its best performance without hyperparameter tuning, by simply specifying presets="best_quality".

We first demonstrate hyperparameter-tuning and how you can provide your own validation dataset that AutoGluon internally relies on to: tune hyperparameters, early-stop iterative training, and construct model ensembles. One reason you may specify validation data is when future test data will stem from a different distribution than training data (and your specified validation data is more representative of the future data that will likely be encountered).

If you don’t have a strong reason to provide your own validation dataset, we recommend you omit the tuning_data argument. This lets AutoGluon automatically select validation data from your provided training set (it uses smart strategies such as stratified sampling). For greater control, you can specify the holdout_frac argument to tell AutoGluon what fraction of the provided training data to hold out for validation.
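To make the idea of a stratified holdout concrete, here is a minimal sketch in plain Python. It is not AutoGluon's internal splitting logic, which may differ; the function name stratified_holdout and the toy labels are hypothetical and exist only for illustration. The point is that each class keeps roughly the same share in the training and validation parts.

```python
import random
from collections import defaultdict


def stratified_holdout(labels, holdout_frac=0.2, seed=0):
    """Split row indices into (train, val) so each class keeps a similar share in both.

    Illustrative helper only; not part of the AutoGluon API.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)

    train_idx, val_idx = [], []
    for idxs in by_class.values():
        rng.shuffle(idxs)
        # Hold out at least one row per class, unless the class has a single row.
        n_val = max(1, round(len(idxs) * holdout_frac)) if len(idxs) > 1 else 0
        val_idx.extend(idxs[:n_val])
        train_idx.extend(idxs[n_val:])
    return sorted(train_idx), sorted(val_idx)


# Toy imbalanced label column: 80 rows of class "a", 20 rows of class "b".
labels = ["a"] * 80 + ["b"] * 20
train_idx, val_idx = stratified_holdout(labels, holdout_frac=0.2)
print(len(train_idx), len(val_idx))
```

With a purely random split, a rare class could end up missing from the validation part; splitting per class avoids that.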

Caution: Since AutoGluon tunes internal knobs based on this validation data, performance estimates reported on this data may be over-optimistic. For unbiased performance estimates, you should always call predict() on a separate dataset (that was never passed to fit()), as we did in the previous Quick-Start tutorial. We also emphasize that most options specified in this tutorial are chosen to minimize runtime for the purposes of demonstration and you should select more reasonable values in order to obtain high-quality models.

fit() trains neural networks and various types of tree ensembles by default. You can specify various hyperparameter values for each type of model. For each hyperparameter, you can either specify a single fixed value, or a search space of values to consider during hyperparameter optimization. Hyperparameters which you do not specify are left at default settings chosen automatically by AutoGluon, which may be fixed values or search spaces.

Refer to the Search Space documentation to learn more about AutoGluon search space.
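As a rough illustration of what searching a real-valued hyperparameter on a log scale means, the sketch below compares uniform and log-uniform sampling over the range 1e-4 to 1e-2 in plain Python. This is not AutoGluon's sampler; the helper sample_log_uniform is a hypothetical name used only here.

```python
import math
import random


def sample_log_uniform(lower, upper, rng):
    """Draw a value whose logarithm is uniform between log(lower) and log(upper)."""
    return math.exp(rng.uniform(math.log(lower), math.log(upper)))


rng = random.Random(0)
lower, upper = 1e-4, 1e-2
uniform_draws = [rng.uniform(lower, upper) for _ in range(10_000)]
log_draws = [sample_log_uniform(lower, upper, rng) for _ in range(10_000)]

# Fraction of draws that land in the lowest decade, below 1e-3.
frac_uniform = sum(x < 1e-3 for x in uniform_draws) / len(uniform_draws)
frac_log = sum(x < 1e-3 for x in log_draws) / len(log_draws)
print(f"uniform: {frac_uniform:.2f}, log-uniform: {frac_log:.2f}")
```

Uniform sampling places only about a tenth of its draws below 1e-3, while log-uniform sampling gives each decade about equal weight. This is why quantities such as learning rates, which span orders of magnitude, are usually searched on a log scale.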

from autogluon.common import space

nn_options = {  # specifies non-default hyperparameter values for neural network models
    'num_epochs': 10,  # number of training epochs (controls training time of NN models)
    'learning_rate': space.Real(1e-4, 1e-2, default=5e-4, log=True),  # learning rate used in training (real-valued hyperparameter searched on log-scale)
    'activation': space.Categorical('relu', 'softrelu', 'tanh'),  # activation function used in NN (categorical hyperparameter, default = first entry)
    'dropout_prob': space.Real(0.0, 0.5, default=0.1),  # dropout probability (real-valued hyperparameter)
}

gbm_options = {  # specifies non-default hyperparameter values for lightGBM gradient boosted trees
    'num_boost_round': 100,  # number of boosting rounds (controls training time of GBM models)
    'num_leaves': space.Int(lower=26, upper=66, default=36),  # number of leaves in trees (integer hyperparameter)
}

hyperparameters = {  # hyperparameters of each model type
                   'GBM': gbm_options,
                   'NN_TORCH': nn_options,  # NOTE: comment this line out if you get errors on Mac OSX
                  }  # When these keys are missing from hyperparameters dict, no models of that type are trained

time_limit = 2*60  # train various models for ~2 min
num_trials = 5  # try at most 5 different hyperparameter configurations for each type of model
search_strategy = 'auto'  # to tune hyperparameters using random search routine with a local scheduler

hyperparameter_tune_kwargs = {  # HPO is not performed unless hyperparameter_tune_kwargs is specified
    'num_trials': num_trials,
    'scheduler' : 'local',
    'searcher': search_strategy,
}  # Refer to TabularPredictor.fit docstring for all valid values

predictor = TabularPredictor(label=label, eval_metric=metric).fit(
    train_data,
    time_limit=time_limit,
    hyperparameters=hyperparameters,
    hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
)
Fitted model: NeuralNetTorch/1535251a ...
0.365	 = Validation score   (accuracy)
3.25s	 = Training   runtime
0.01s	 = Validation runtime
Fitted model: NeuralNetTorch/b5c75068 ...
0.33	 = Validation score   (accuracy)
3.89s	 = Training   runtime
0.01s	 = Validation runtime
Fitted model: NeuralNetTorch/49f5f71e ...
0.34	 = Validation score   (accuracy)
5.14s	 = Training   runtime
0.02s	 = Validation runtime
Fitted model: NeuralNetTorch/a05927f4 ...
0.345	 = Validation score   (accuracy)
3.49s	 = Training   runtime
0.01s	 = Validation runtime
Fitted model: NeuralNetTorch/9360d9d9 ...
0.36	 = Validation score   (accuracy)
3.46s	 = Training   runtime
0.01s	 = Validation runtime
Fitting model: WeightedEnsemble_L2 ... Training model for up to 119.91s of the 95.62s of remaining time.
Ensemble Weights: {'LightGBM/T3': 1.0}
0.375	 = Validation score   (accuracy)
0.03s	 = Training   runtime
0.0s	 = Validation runtime
AutoGluon training complete, total runtime = 24.44s ... Best model: WeightedEnsemble_L2 | Estimated inference throughput: 43593.0 rows/s (200 batch size)
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_022901")

We again demonstrate how to use the trained models to predict on the test data.

y_pred = predictor.predict(test_data_nolabel)
print("Predictions:  ", list(y_pred)[:5])
perf = predictor.evaluate(test_data, auxiliary_metrics=False)
Predictions:   [' Other-service', ' Craft-repair', ' Exec-managerial', ' Sales', ' Other-service']

Use the following to view a summary of what happened during fit(). Now this command will show details of the hyperparameter-tuning process for each type of model:

results = predictor.fit_summary()
*** Summary of fit() ***
Estimated performance of each model:
                      model  score_val eval_metric  pred_time_val  fit_time  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0               LightGBM/T3      0.375    accuracy       0.003743  0.361251                0.003743           0.361251            1       True          3
1               LightGBM/T5      0.375    accuracy       0.004319  0.527203                0.004319           0.527203            1       True          5
2       WeightedEnsemble_L2      0.375    accuracy       0.004588  0.389838                0.000845           0.028587            2       True         11
3               LightGBM/T1      0.370    accuracy       0.003670  0.709906                0.003670           0.709906            1       True          1
4   NeuralNetTorch/1535251a      0.365    accuracy       0.009585  3.249166                0.009585           3.249166            1       True          6
5               LightGBM/T4      0.360    accuracy       0.005905  0.578249                0.005905           0.578249            1       True          4
6   NeuralNetTorch/9360d9d9      0.360    accuracy       0.010587  3.464687                0.010587           3.464687            1       True         10
7               LightGBM/T2      0.355    accuracy       0.004083  0.609532                0.004083           0.609532            1       True          2
8   NeuralNetTorch/a05927f4      0.345    accuracy       0.011839  3.487161                0.011839           3.487161            1       True          9
9   NeuralNetTorch/49f5f71e      0.340    accuracy       0.015467  5.139667                0.015467           5.139667            1       True          8
10  NeuralNetTorch/b5c75068      0.330    accuracy       0.012736  3.888344                0.012736           3.888344            1       True          7
Number of models trained: 11
Types of models trained:
{'LGBModel', 'WeightedEnsembleModel', 'TabularNeuralNetTorchModel'}
Bagging used: False 
Multi-layer stack-ensembling used: False 
Feature Metadata (Processed):
(raw dtype, special dtypes):
('category', [])  : 6 | ['workclass', 'education', 'marital-status', 'relationship', 'race', ...]
('int', [])       : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('int', ['bool']) : 2 | ['sex', 'class']
*** End of fit() summary ***
/home/ci/autogluon/core/src/autogluon/core/utils/plots.py:169: UserWarning: AutoGluon summary plots cannot be created because bokeh is not installed. To see plots, please do: "pip install bokeh==2.0.1"
  warnings.warn('AutoGluon summary plots cannot be created because bokeh is not installed. To see plots, please do: "pip install bokeh==2.0.1"')

In the above example, the predictive performance may be poor because we specified very little training to ensure quick runtimes. You can call fit() multiple times while modifying the above settings to better understand how these choices affect performance outcomes. For example: you can comment out the train_data.sample command or increase subsample_size to train using a larger dataset, increase the num_epochs and num_boost_round hyperparameters, and increase the time_limit (which you should do for all code in these tutorials). To see more detailed output during the execution of fit(), you can also pass in the argument: verbosity = 3.

Model ensembling with stacking/bagging

Beyond hyperparameter-tuning with a correctly-specified evaluation metric, two other methods to boost predictive performance are bagging and stack-ensembling. You’ll often see performance improve if you specify num_bag_folds = 5-10, num_stack_levels = 1 in the call to fit(), but this will increase training times and memory/disk usage.
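For intuition about what k-fold bagging does, here is a conceptual sketch in plain Python with a toy "model" that just predicts the mean of its training targets. It is a simplification under stated assumptions, not AutoGluon's implementation; kfold_indices and fit_mean_model are hypothetical helpers. Each row receives an out-of-fold prediction from a model that never saw it, and predictions of this kind are what a stacked model in the next layer can use as extra input features.

```python
def kfold_indices(n_rows, num_bag_folds):
    """Yield (train_idx, holdout_idx) pairs; every row is held out exactly once."""
    for fold in range(num_bag_folds):
        holdout = [i for i in range(n_rows) if i % num_bag_folds == fold]
        train = [i for i in range(n_rows) if i % num_bag_folds != fold]
        yield train, holdout


def fit_mean_model(y_train):
    """Toy 'model': always predicts the mean of its training targets."""
    mean = sum(y_train) / len(y_train)
    return lambda: mean


y = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
num_bag_folds = 5

oof_pred = [None] * len(y)  # out-of-fold predictions, one per training row
fold_models = []
for train_idx, holdout_idx in kfold_indices(len(y), num_bag_folds):
    model = fit_mean_model([y[i] for i in train_idx])
    fold_models.append(model)
    for i in holdout_idx:
        oof_pred[i] = model()  # prediction from a model that never saw row i

# At inference time, the bagged prediction averages all fold models.
bagged_pred = sum(m() for m in fold_models) / len(fold_models)
print(oof_pred)
print(bagged_pred)
```

The sketch also shows where the extra cost comes from: num_bag_folds models are trained and kept for every model type, instead of one.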

label = 'class'  # Now let's predict the "class" column (binary classification)
test_data_nolabel = test_data.drop(columns=[label])
y_test = test_data[label]
save_path = 'agModels-predictClass'  # folder where to store trained models

predictor = TabularPredictor(label=label, eval_metric=metric).fit(train_data,
    num_bag_folds=5, num_bag_sets=1, num_stack_levels=1,
    hyperparameters = {'NN_TORCH': {'num_epochs': 2}, 'GBM': {'num_boost_round': 20}},  # last argument is just for quick demo here, omit it in real applications
)
No path specified. Models will be saved in: "AutogluonModels/ag-20250107_022925"
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version:  1.2b20250107
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Tue Sep 24 10:00:37 UTC 2024
CPU Count:          8
Memory Avail:       27.83 GB / 30.95 GB (89.9%)
Disk Space Avail:   213.26 GB / 255.99 GB (83.3%)
===================================================
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets. Defaulting to `'medium'`...
	Recommended Presets (For more details refer to https://auto.gluon.ai/stable/tutorials/tabular/tabular-essentials.html#presets):
	presets='experimental' : New in v1.2: Pre-trained foundation model + parallel fits. The absolute best accuracy without consideration for inference speed. Does not support GPU.
	presets='best'         : Maximize accuracy. Recommended for most users. Use in competitions and benchmarks.
	presets='high'         : Strong accuracy with fast inference speed.
	presets='good'         : Good accuracy with very fast inference speed.
	presets='medium'       : Fast training time, ideal for initial prototyping.
Beginning AutoGluon training ...
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_022925"
Train Data Rows:    1000
Train Data Columns: 14
Label Column:       class
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
2 unique label values:  [' >50K', ' <=50K']
If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
Problem Type:       binary
Preprocessing data ...
Selected class <--> label mapping:  class 1 =  >50K, class 0 =  <=50K
Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
	To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
Available Memory:                    28495.09 MB
Train Data (Original)  Memory Usage: 0.56 MB (0.0% of available memory)
Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
Stage 1 Generators:
Fitting AsTypeFeatureGenerator...
Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
Stage 2 Generators:
Fitting FillNaFeatureGenerator...
Stage 3 Generators:
Fitting IdentityFeatureGenerator...
Fitting CategoryFeatureGenerator...
Fitting CategoryMemoryMinimizeFeatureGenerator...
Stage 4 Generators:
Fitting DropUniqueFeatureGenerator...
Stage 5 Generators:
Fitting DropDuplicatesFeatureGenerator...
Types of features in original data (raw dtype, special dtypes):
('int', [])    : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
Types of features in processed data (raw dtype, special dtypes):
('category', [])  : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
('int', [])       : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('int', ['bool']) : 1 | ['sex']
0.1s = Fit runtime
14 features in original data used to generate 14 features in processed data.
Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.11s ...
AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
To change this, specify the eval_metric parameter of Predictor()
User-specified model hyperparameters to be fit:
{
	'NN_TORCH': [{'num_epochs': 2}],
	'GBM': [{'num_boost_round': 20}],
}
AutoGluon will fit 2 stack levels (L1 to L2) ...
Fitting 2 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBM_BAG_L1 ...
Fitting 5 child models (S1F1 - S1F5) | Fitting with ParallelLocalFoldFittingStrategy (5 workers, per: cpus=1, gpus=0, memory=0.01%)
0.823	 = Validation score   (accuracy)
1.6s	 = Training   runtime
0.02s	 = Validation runtime
Fitting model: NeuralNetTorch_BAG_L1 ...
Fitting 5 child models (S1F1 - S1F5) | Fitting with ParallelLocalFoldFittingStrategy (5 workers, per: cpus=1, gpus=0, memory=0.00%)
0.744	 = Validation score   (accuracy)
4.03s	 = Training   runtime
0.06s	 = Validation runtime
Fitting model: WeightedEnsemble_L2 ...
Ensemble Weights: {'LightGBM_BAG_L1': 1.0}
0.823	 = Validation score   (accuracy)
0.03s	 = Training   runtime
0.0s	 = Validation runtime
Fitting 2 L2 models, fit_strategy="sequential" ...
Fitting model: LightGBM_BAG_L2 ...
Fitting 5 child models (S1F1 - S1F5) | Fitting with ParallelLocalFoldFittingStrategy (5 workers, per: cpus=1, gpus=0, memory=0.01%)
0.826	 = Validation score   (accuracy)
0.94s	 = Training   runtime
0.02s	 = Validation runtime
Fitting model: NeuralNetTorch_BAG_L2 ...
Fitting 5 child models (S1F1 - S1F5) | Fitting with ParallelLocalFoldFittingStrategy (5 workers, per: cpus=1, gpus=0, memory=0.00%)
0.748	 = Validation score   (accuracy)
3.78s	 = Training   runtime
0.06s	 = Validation runtime
Fitting model: WeightedEnsemble_L3 ...
Ensemble Weights: {'LightGBM_BAG_L2': 0.889, 'LightGBM_BAG_L1': 0.111}
0.827	 = Validation score   (accuracy)
0.07s	 = Training   runtime
0.0s	 = Validation runtime
AutoGluon training complete, total runtime = 19.78s ... Best model: WeightedEnsemble_L3 | Estimated inference throughput: 2009.2 rows/s (200 batch size)
Disabling decision threshold calibration for metric `accuracy` due to having fewer than 10000 rows of validation data for calibration, to avoid overfitting (1000 rows).
	`accuracy` is generally not improved through threshold calibration. Force calibration via specifying `calibrate_decision_threshold=True`.
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_022925")

You should not provide tuning_data when stacking/bagging, and instead provide all your available data as train_data (which AutoGluon will split in more intelligent ways). num_bag_sets controls how many times the k-fold bagging process is repeated to further reduce variance (increasing this may further boost accuracy but will substantially increase training times, inference latency, and memory/disk usage). Rather than manually searching for good bagging/stacking values yourself, you can specify auto_stack instead, and AutoGluon will automatically select good values for you (auto_stack is used in the best_quality preset):

# Let's also specify the "balanced_accuracy" metric
predictor = TabularPredictor(label=label, eval_metric='balanced_accuracy', path=save_path).fit(
    train_data, auto_stack=True,
    calibrate_decision_threshold=False,  # Disabling for demonstration in next section
    hyperparameters={'FASTAI': {'num_epochs': 10}, 'GBM': {'num_boost_round': 200}}  # last 2 arguments are for quick demo, omit them in real applications
)
predictor.leaderboard(test_data)
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version:  1.2b20250107
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Tue Sep 24 10:00:37 UTC 2024
CPU Count:          8
Memory Avail:       25.98 GB / 30.95 GB (83.9%)
Disk Space Avail:   213.26 GB / 255.99 GB (83.3%)
===================================================
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets. Defaulting to `'medium'`...
	Recommended Presets (For more details refer to https://auto.gluon.ai/stable/tutorials/tabular/tabular-essentials.html#presets):
	presets='experimental' : New in v1.2: Pre-trained foundation model + parallel fits. The absolute best accuracy without consideration for inference speed. Does not support GPU.
	presets='best'         : Maximize accuracy. Recommended for most users. Use in competitions and benchmarks.
	presets='high'         : Strong accuracy with fast inference speed.
	presets='good'         : Good accuracy with very fast inference speed.
	presets='medium'       : Fast training time, ideal for initial prototyping.
Stack configuration (auto_stack=True): num_stack_levels=0, num_bag_folds=8, num_bag_sets=1
Beginning AutoGluon training ...
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass"
Train Data Rows:    1000
Train Data Columns: 14
Label Column:       class
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
2 unique label values:  [' >50K', ' <=50K']
If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
Problem Type:       binary
Preprocessing data ...
Selected class <--> label mapping:  class 1 =  >50K, class 0 =  <=50K
Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
	To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
Available Memory:                    26599.48 MB
Train Data (Original)  Memory Usage: 0.56 MB (0.0% of available memory)
Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
Stage 1 Generators:
Fitting AsTypeFeatureGenerator...
Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
Stage 2 Generators:
Fitting FillNaFeatureGenerator...
Stage 3 Generators:
Fitting IdentityFeatureGenerator...
Fitting CategoryFeatureGenerator...
Fitting CategoryMemoryMinimizeFeatureGenerator...
Stage 4 Generators:
Fitting DropUniqueFeatureGenerator...
Stage 5 Generators:
Fitting DropDuplicatesFeatureGenerator...
Types of features in original data (raw dtype, special dtypes):
('int', [])    : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
Types of features in processed data (raw dtype, special dtypes):
('category', [])  : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
('int', [])       : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('int', ['bool']) : 1 | ['sex']
0.1s = Fit runtime
14 features in original data used to generate 14 features in processed data.
Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.12s ...
AutoGluon will gauge predictive performance using evaluation metric: 'balanced_accuracy'
To change this, specify the eval_metric parameter of Predictor()
User-specified model hyperparameters to be fit:
{
	'FASTAI': [{'num_epochs': 10}],
	'GBM': [{'num_boost_round': 200}],
}
Fitting 2 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBM_BAG_L1 ...
Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.01%)
0.7764	 = Validation score   (balanced_accuracy)
0.96s	 = Training   runtime
0.04s	 = Validation runtime
Fitting model: NeuralNetFastAI_BAG_L1 ...
Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.00%)
0.7414	 = Validation score   (balanced_accuracy)
5.54s	 = Training   runtime
0.11s	 = Validation runtime
Fitting model: WeightedEnsemble_L2 ...
Ensemble Weights: {'LightGBM_BAG_L1': 1.0}
0.7764	 = Validation score   (balanced_accuracy)
0.04s	 = Training   runtime
0.0s	 = Validation runtime
AutoGluon training complete, total runtime = 14.24s ... Best model: WeightedEnsemble_L2 | Estimated inference throughput: 2775.2 rows/s (125 batch size)
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass")
                    model  score_test  score_val        eval_metric  pred_time_test  pred_time_val  fit_time  pred_time_test_marginal  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0         LightGBM_BAG_L1    0.743784   0.776399  balanced_accuracy        0.286168       0.044906  0.961846                 0.286168                0.044906           0.961846            1       True          1
1     WeightedEnsemble_L2    0.743784   0.776399  balanced_accuracy        0.287503       0.045989  1.002118                 0.001335                0.001083           0.040272            2       True          3
2  NeuralNetFastAI_BAG_L1    0.724629   0.741368  balanced_accuracy        2.030228       0.109789  5.542246                 2.030228                0.109789           5.542246            1       True          2

Stacking/bagging will often produce better accuracy than hyperparameter-tuning, but you may try combining both techniques (note: specifying presets='best_quality' in fit() simply sets auto_stack=True).

Decision Threshold Calibration

Major metric score improvements can be achieved in binary classification for metrics such as "f1" and "balanced_accuracy" by adjusting the prediction decision threshold via calibrate_decision_threshold to a value other than the default 0.5.

Below is an example of the "balanced_accuracy" score achieved on the test data with and without calibrating the decision threshold:

print(f'Prior to calibration (predictor.decision_threshold={predictor.decision_threshold}):')
scores = predictor.evaluate(test_data)

calibrated_decision_threshold = predictor.calibrate_decision_threshold()
predictor.set_decision_threshold(calibrated_decision_threshold)

print(f'After calibration (predictor.decision_threshold={predictor.decision_threshold}):')
scores_calibrated = predictor.evaluate(test_data)
Prior to calibration (predictor.decision_threshold=0.5):
After calibration (predictor.decision_threshold=0.25):
Calibrating decision threshold to optimize metric balanced_accuracy | Checking 51 thresholds...
Calibrating decision threshold via fine-grained search | Checking 38 thresholds...
Base Threshold: 0.500	| val: 0.7764
Best Threshold: 0.250	| val: 0.7926
Updating predictor.decision_threshold from 0.5 -> 0.25
	This will impact how prediction probabilities are converted to predictions in binary classification.
	Prediction probabilities of the positive class >0.25 will be predicted as the positive class ( >50K). This can significantly impact metric scores.
	You can update this value via `predictor.set_decision_threshold`.
	You can calculate an optimal decision threshold on the validation data via `predictor.calibrate_decision_threshold()`.
for metric_name in scores:
    metric_score = scores[metric_name]
    metric_score_calibrated = scores_calibrated[metric_name]
    decision_threshold = predictor.decision_threshold
    print(f'decision_threshold={decision_threshold:.3f}\t| metric="{metric_name}"'
          f'\n\ttest_score uncalibrated: {metric_score:.4f}'
          f'\n\ttest_score   calibrated: {metric_score_calibrated:.4f}'
          f'\n\ttest_score        delta: {metric_score_calibrated-metric_score:.4f}')
decision_threshold=0.250	| metric="balanced_accuracy"
	test_score uncalibrated: 0.7438
	test_score   calibrated: 0.8120
	test_score        delta: 0.0682
decision_threshold=0.250	| metric="accuracy"
	test_score uncalibrated: 0.8472
	test_score   calibrated: 0.8162
	test_score        delta: -0.0310
decision_threshold=0.250	| metric="mcc"
	test_score uncalibrated: 0.5457
	test_score   calibrated: 0.5654
	test_score        delta: 0.0197
decision_threshold=0.250	| metric="roc_auc"
	test_score uncalibrated: 0.8990
	test_score   calibrated: 0.8990
	test_score        delta: 0.0000
decision_threshold=0.250	| metric="f1"
	test_score uncalibrated: 0.6294
	test_score   calibrated: 0.6749
	test_score        delta: 0.0454
decision_threshold=0.250	| metric="precision"
	test_score uncalibrated: 0.7411
	test_score   calibrated: 0.5814
	test_score        delta: -0.1597
decision_threshold=0.250	| metric="recall"
	test_score uncalibrated: 0.5470
	test_score   calibrated: 0.8041
	test_score        delta: 0.2571

Notice that calibrating for “balanced_accuracy” substantially improved the “balanced_accuracy” score but harmed the “accuracy” score. Threshold calibration often trades off performance between different metrics, so keep this in mind when choosing which metric to calibrate for.
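The same tradeoff can be reproduced on toy data. The sketch below is plain Python, not AutoGluon's calibration routine; the helper names, the toy probabilities, and the simple grid of candidate thresholds are all assumptions made for illustration. It scans candidate thresholds, keeps the one with the best balanced accuracy on validation data, and compares it with the default of 0.5.

```python
def balanced_accuracy(y_true, y_pred):
    """Mean of recall on the positive class and recall on the negative class."""
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    tn = sum(t == 0 and p == 0 for t, p in zip(y_true, y_pred))
    n_pos = sum(t == 1 for t in y_true)
    n_neg = len(y_true) - n_pos
    return 0.5 * (tp / n_pos + tn / n_neg)


def accuracy(y_true, y_pred):
    """Fraction of rows predicted correctly."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def predict_with_threshold(proba, threshold):
    """Predict the positive class when its probability exceeds the threshold."""
    return [1 if p > threshold else 0 for p in proba]


def calibrate_threshold(y_true, proba, metric, candidates):
    """Return the candidate threshold with the best metric score (ties keep the first)."""
    return max(candidates, key=lambda t: metric(y_true, predict_with_threshold(proba, t)))


# Toy, imbalanced validation data: 3 positives followed by 7 negatives.
y_val = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
proba = [0.9, 0.4, 0.3, 0.45, 0.35, 0.33, 0.1, 0.1, 0.05, 0.02]

candidates = [i / 100 for i in range(1, 100)]
best = calibrate_threshold(y_val, proba, balanced_accuracy, candidates)

for name, thr in [("default", 0.5), ("calibrated", best)]:
    pred = predict_with_threshold(proba, thr)
    print(f"{name}: threshold={thr:.2f} "
          f"balanced_accuracy={balanced_accuracy(y_val, pred):.3f} "
          f"accuracy={accuracy(y_val, pred):.3f}")
```

On this toy data the calibrated threshold falls below 0.5: more positives are recovered, which raises balanced accuracy, but the added false positives lower plain accuracy.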

Instead of calibrating for “balanced_accuracy” specifically, we can calibrate for any metric if we want to maximize the score of that metric:

predictor.set_decision_threshold(0.5)  # Reset decision threshold
for metric_name in ['f1', 'balanced_accuracy', 'mcc']:
    metric_score = predictor.evaluate(test_data, silent=True)[metric_name]
    calibrated_decision_threshold = predictor.calibrate_decision_threshold(metric=metric_name, verbose=False)
    metric_score_calibrated = predictor.evaluate(
        test_data, decision_threshold=calibrated_decision_threshold, silent=True
    )[metric_name]
    print(f'decision_threshold={calibrated_decision_threshold:.3f}\t| metric="{metric_name}"'
          f'\n\ttest_score uncalibrated: {metric_score:.4f}'
          f'\n\ttest_score   calibrated: {metric_score_calibrated:.4f}'
          f'\n\ttest_score        delta: {metric_score_calibrated-metric_score:.4f}')
decision_threshold=0.500	| metric="f1"
	test_score uncalibrated: 0.6294
	test_score   calibrated: 0.6294
	test_score        delta: 0.0000
decision_threshold=0.250	| metric="balanced_accuracy"
	test_score uncalibrated: 0.7438
	test_score   calibrated: 0.8120
	test_score        delta: 0.0682
decision_threshold=0.500	| metric="mcc"
	test_score uncalibrated: 0.5457
	test_score   calibrated: 0.5457
	test_score        delta: 0.0000
Updating predictor.decision_threshold from 0.25 -> 0.5
	This will impact how prediction probabilities are converted to predictions in binary classification.
	Prediction probabilities of the positive class >0.5 will be predicted as the positive class ( >50K). This can significantly impact metric scores.
	You can update this value via `predictor.set_decision_threshold`.
	You can calculate an optimal decision threshold on the validation data via `predictor.calibrate_decision_threshold()`.

Instead of calibrating the decision threshold post-fit, you can have it automatically occur during the fit call by specifying the fit parameter predictor.fit(..., calibrate_decision_threshold=True).

Luckily, AutoGluon will automatically apply decision threshold calibration when beneficial, as the default value is calibrate_decision_threshold="auto". We recommend keeping this value as the default in most cases.

Additional usage examples are below:

# Will use the decision_threshold specified in `predictor.decision_threshold`, can be set via `predictor.set_decision_threshold`
# y_pred = predictor.predict(test_data)
# y_pred_08 = predictor.predict(test_data, decision_threshold=0.8)  # Specify a specific threshold to use only for this predict

# y_pred_proba = predictor.predict_proba(test_data)
# y_pred = predictor.predict_from_proba(y_pred_proba)  # Identical output to calling .predict(test_data)
# y_pred_08 = predictor.predict_from_proba(y_pred_proba, decision_threshold=0.8)  # Identical output to calling .predict(test_data, decision_threshold=0.8)

Prediction options (inference)

Even if you’ve started a new Python session since last calling fit(), you can still load a previously trained predictor from disk:

predictor = TabularPredictor.load(save_path)  # `predictor.path` is another way to get the relative path needed to later load predictor.

Above, save_path is the same folder previously passed to TabularPredictor, in which all the trained models have been saved. You can easily train models on one machine and deploy them on another: simply copy the save_path folder to the new machine and specify its new path in TabularPredictor.load().

To find out the required feature columns to make predictions, call predictor.features():

predictor.features()
['age',
 'workclass',
 'fnlwgt',
 'education',
 'education-num',
 'marital-status',
 'occupation',
 'relationship',
 'race',
 'sex',
 'capital-gain',
 'capital-loss',
 'hours-per-week',
 'native-country']

We can make a prediction on an individual example rather than a full dataset:

datapoint = test_data_nolabel.iloc[[0]]  # Note: .iloc[0] won't work because it returns pandas Series instead of DataFrame
print(datapoint)
predictor.predict(datapoint)
age workclass  fnlwgt education  education-num       marital-status  \
0   31   Private  169085      11th              7   Married-civ-spouse   

  occupation relationship    race      sex  capital-gain  capital-loss  \
0      Sales         Wife   White   Female             0             0   

   hours-per-week  native-country  
0              20   United-States
0     <=50K
Name: class, dtype: object

To output predicted class probabilities instead of predicted classes, you can use:

predictor.predict_proba(datapoint)  # returns a DataFrame that shows which probability corresponds to which class
<=50K >50K
0 0.951059 0.048941

By default, predict() and predict_proba() will utilize the model that AutoGluon thinks is most accurate, which is usually an ensemble of many individual models. Here’s how to see which model this is:

predictor.model_best
'WeightedEnsemble_L2'

We can instead specify a particular model to use for predictions (e.g. to reduce inference latency). Note that a ‘model’ in AutoGluon may refer to, for example, a single Neural Network, a bagged ensemble of many Neural Network copies trained on different training/validation splits, a weighted ensemble that aggregates the predictions of many other models, or a stacker model that operates on predictions output by other models. This is akin to viewing a Random Forest as one ‘model’ when it is in fact an ensemble of many decision trees.
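As a rough illustration of what “a weighted ensemble that aggregates the predictions of many other models” means, the sketch below averages per-class probabilities from two base models using fixed weights. The model names, probabilities, and weights are hypothetical, and the way AutoGluon actually chooses ensemble weights is not shown here.

```python
def weighted_ensemble(proba_by_model, weights):
    """Weighted average of per-class probabilities from several base models."""
    total_weight = sum(weights.values())
    num_classes = len(next(iter(proba_by_model.values())))
    combined = [0.0] * num_classes
    for model_name, weight in weights.items():
        for k, p in enumerate(proba_by_model[model_name]):
            combined[k] += (weight / total_weight) * p
    return combined


# Hypothetical base-model outputs for one row: [P(<=50K), P(>50K)]
proba_by_model = {"model_a": [0.9, 0.1], "model_b": [0.5, 0.5]}
weights = {"model_a": 0.75, "model_b": 0.25}

ensemble_proba = weighted_ensemble(proba_by_model, weights)
print(ensemble_proba)
```

Because the weights are normalized, the combined probabilities still sum to 1. A weight of 1.0 on a single model reduces the ensemble to that model’s predictions.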

Before deciding which model to use, let’s evaluate all of the models AutoGluon has previously trained on our test data:

predictor.leaderboard(test_data)
model score_test score_val eval_metric pred_time_test pred_time_val fit_time pred_time_test_marginal pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order
0 LightGBM_BAG_L1 0.743784 0.776399 balanced_accuracy 0.147397 0.044906 0.961846 0.147397 0.044906 0.961846 1 True 1
1 WeightedEnsemble_L2 0.743784 0.776399 balanced_accuracy 0.148645 0.045989 1.002118 0.001248 0.001083 0.040272 2 True 3
2 NeuralNetFastAI_BAG_L1 0.724629 0.741368 balanced_accuracy 1.242341 0.109789 5.542246 1.242341 0.109789 5.542246 1 True 2

The leaderboard shows each model’s predictive performance on the test data (score_test) and validation data (score_val), as well as the time required to: produce predictions for the test data (pred_time_test), produce predictions on the validation data (pred_time_val), and train only this model (fit_time). Below, we show that a leaderboard can be produced without new data (it just uses the data previously reserved for validation inside fit) and can display extra information about each model:

predictor.leaderboard(extra_info=True)
model score_val eval_metric pred_time_val fit_time pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order ... hyperparameters hyperparameters_fit ag_args_fit features compile_time child_hyperparameters child_hyperparameters_fit child_ag_args_fit ancestors descendants
0 LightGBM_BAG_L1 0.776399 balanced_accuracy 0.044906 0.961846 0.044906 0.961846 1 True 1 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [workclass, capital-loss, hours-per-week, education, fnlwgt, marital-status, capital-gain, occupation, relationship, native-country, age, education-num, sex, race] None {'learning_rate': 0.05, 'num_boost_round': 200} {'num_boost_round': 83} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [] [WeightedEnsemble_L2]
1 WeightedEnsemble_L2 0.776399 balanced_accuracy 0.045989 1.002118 0.001083 0.040272 2 True 3 ... {'use_orig_features': False, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [LightGBM_BAG_L1] None {'ensemble_size': 25, 'subsample_size': 1000000} {'ensemble_size': 1} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [LightGBM_BAG_L1] []
2 NeuralNetFastAI_BAG_L1 0.741368 balanced_accuracy 0.109789 5.542246 0.109789 5.542246 1 True 2 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [workclass, capital-loss, hours-per-week, education, fnlwgt, marital-status, capital-gain, occupation, relationship, native-country, age, education-num, sex, race] None {'layers': None, 'emb_drop': 0.1, 'ps': 0.1, 'bs': 'auto', 'lr': 0.01, 'epochs': 'auto', 'early.stopping.min_delta': 0.0001, 'early.stopping.patience': 20, 'smoothing': 0.0, 'num_epochs': 10} {'epochs': 30, 'best_epoch': 9} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': ['text_ngram', 'text_as_category'], 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [] []

3 rows × 32 columns

The expanded leaderboard shows properties like how many features are used by each model (num_features), which other models are ancestors whose predictions are required inputs for each model (ancestors), and how much memory each model and all its ancestors would occupy if simultaneously persisted (memory_size_w_ancestors). See the leaderboard documentation for full details.
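The numbers in the leaderboard above are consistent with reading pred_time_val as a model’s own (marginal) prediction time plus the marginal times of its ancestors. The sketch below checks this arithmetic with values copied from that output; the helper function is hypothetical and assumes the ancestors list is complete.

```python
# Validation prediction times (seconds) copied from the leaderboard above.
pred_time_val_marginal = {
    "LightGBM_BAG_L1": 0.044906,
    "WeightedEnsemble_L2": 0.001083,
}
# Ancestors as reported in the `ancestors` column.
ancestors = {
    "LightGBM_BAG_L1": [],
    "WeightedEnsemble_L2": ["LightGBM_BAG_L1"],
}


def total_pred_time(model):
    """Own marginal time plus the marginal time of every ancestor."""
    own = pred_time_val_marginal[model]
    inherited = sum(pred_time_val_marginal[a] for a in ancestors[model])
    return own + inherited


for name in pred_time_val_marginal:
    print(f"{name}: {total_pred_time(name):.6f}s")
```

This is why predicting with a base model directly can be faster than predicting with an ensemble built on top of it.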

To show scores for other metrics, you can specify the extra_metrics argument when passing in test_data:

predictor.leaderboard(test_data, extra_metrics=['accuracy', 'balanced_accuracy', 'log_loss'])
model score_test accuracy balanced_accuracy log_loss score_val eval_metric pred_time_test pred_time_val fit_time pred_time_test_marginal pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order
0 LightGBM_BAG_L1 0.743784 0.847170 0.743784 -0.334022 0.776399 balanced_accuracy 0.141542 0.044906 0.961846 0.141542 0.044906 0.961846 1 True 1
1 WeightedEnsemble_L2 0.743784 0.847170 0.743784 -0.334022 0.776399 balanced_accuracy 0.142906 0.045989 1.002118 0.001364 0.001083 0.040272 2 True 3
2 NeuralNetFastAI_BAG_L1 0.724629 0.843792 0.724629 -0.343404 0.741368 balanced_accuracy 1.272334 0.109789 5.542246 1.272334 0.109789 5.542246 1 True 2

Notice that log_loss scores are negative. This is because metrics in AutoGluon are always shown in higher_is_better form: metrics such as log_loss and root_mean_squared_error have their signs FLIPPED, so their values are negative. This spares the user from needing to know, for each metric, whether higher or lower is better when reading the leaderboard.

One additional caveat: log_loss values can be -inf when computed via extra_metrics. This is because the models were not optimized with log_loss in mind during training and may assign a probability of 0 to a class (particularly common with K-Nearest-Neighbors models). Because log_loss gives infinite error when the correct class is given 0 probability, this results in a score of -inf. We therefore recommend not using log_loss as a secondary metric to determine model quality: either use log_loss as the eval_metric or avoid it altogether.
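Both points can be reproduced with a small, library-free sketch. The log_loss function below is a plain textbook implementation written for this illustration (not AutoGluon’s scorer), and the toy probabilities are assumptions.

```python
import math


def log_loss(y_true, proba_rows):
    """Mean negative log-probability assigned to the correct class."""
    total = 0.0
    for label, row in zip(y_true, proba_rows):
        p_correct = row[label]
        if p_correct == 0.0:
            # log(0) is undefined; the error is treated as infinite.
            return math.inf
        total -= math.log(p_correct)
    return total / len(y_true)


def higher_is_better_score(y_true, proba_rows):
    """Flip the sign so that larger values always mean a better model."""
    return -log_loss(y_true, proba_rows)


y_true = [0, 1, 1]
smooth_proba = [[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]]  # never exactly 0
hard_proba = [[0.9, 0.1], [0.2, 0.8], [1.0, 0.0]]    # correct class given 0

score_smooth = higher_is_better_score(y_true, smooth_proba)
score_hard = higher_is_better_score(y_true, hard_proba)
print(score_smooth, score_hard)
```

A single row where the correct class receives probability 0 is enough to drive the whole score to -inf, regardless of how good the other predictions are.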

Here’s how to specify a particular model to use for prediction instead of AutoGluon’s default model-choice:

i = 0  # index of model to use
model_to_use = predictor.model_names()[i]
model_pred = predictor.predict(datapoint, model=model_to_use)
print("Prediction from %s model: %s" % (model_to_use, model_pred.iloc[0]))
Prediction from LightGBM_BAG_L1 model:  <=50K

We can easily access various information about the trained predictor or a particular model:

all_models = predictor.model_names()
model_to_use = all_models[i]
specific_model = predictor._trainer.load_model(model_to_use)

# Objects defined below are dicts of various information (not printed here as they are quite large):
model_info = specific_model.get_info()
predictor_information = predictor.info()

The predictor also remembers which metric predictions should be evaluated with. Given ground-truth labels, evaluation can be done as follows:

y_pred_proba = predictor.predict_proba(test_data_nolabel)
perf = predictor.evaluate_predictions(y_true=y_test, y_pred=y_pred_proba)

Since the label column remains in the test_data DataFrame, we can instead use the shorthand:

perf = predictor.evaluate(test_data)

Interpretability (feature importance)

To better understand our trained predictor, we can estimate the overall importance of each feature:

predictor.feature_importance(test_data)
Computing feature importance via permutation shuffling for 14 features using 5000 rows with 5 shuffle sets...
8.6s	= Expected runtime (1.72s per shuffle set)
5.39s	= Actual runtime (Completed 5 of 5 shuffle sets)
importance stddev p_value n p99_high p99_low
marital-status 0.068704 0.004542 2.279366e-06 5 0.078057 0.059352
capital-gain 0.046431 0.002457 9.369035e-07 5 0.051489 0.041372
education-num 0.042721 0.003485 5.268617e-06 5 0.049898 0.035545
age 0.035115 0.005922 9.348413e-05 5 0.047308 0.022922
occupation 0.033699 0.007890 3.356604e-04 5 0.049945 0.017454
relationship 0.014965 0.003663 3.983866e-04 5 0.022507 0.007423
hours-per-week 0.012270 0.003750 9.287608e-04 5 0.019992 0.004548
capital-loss 0.002217 0.001260 8.531892e-03 5 0.004812 -0.000378
education 0.000319 0.000774 2.045314e-01 5 0.001912 -0.001274
native-country 0.000000 0.000000 5.000000e-01 5 0.000000 0.000000
race -0.000256 0.000268 9.500382e-01 5 0.000296 -0.000807
sex -0.000527 0.001303 7.914146e-01 5 0.002156 -0.003210
workclass -0.001594 0.002576 8.807408e-01 5 0.003709 -0.006898
fnlwgt -0.004992 0.001524 9.990763e-01 5 -0.001855 -0.008130

Computed via permutation-shuffling, these feature importance scores quantify the drop in predictive performance (of the already trained predictor) when one column’s values are randomly shuffled across rows. The top features in this list contribute most to AutoGluon’s accuracy (for predicting whether an individual’s income exceeds $50K). Features with non-positive importance scores hardly contribute to the predictor’s accuracy, or may even be actively harmful to include in the data (consider removing these features from your data and calling fit again). These scores facilitate interpretability of the predictor’s global behavior (which features it relies on for all predictions). To get local explanations regarding which features influence a particular prediction, check out the example notebooks for explaining particular AutoGluon predictions using Shapley values.
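The permutation-shuffling idea can be sketched in a few lines of plain Python. The toy model, the synthetic data, and the helper names below are all assumptions made for illustration; this is not the code behind predictor.feature_importance().

```python
import random


def toy_model(row):
    """Stand-in for an already trained model: only looks at column 0."""
    return 1 if row[0] > 0.5 else 0


def accuracy(rows, labels):
    correct = sum(toy_model(row) == label for row, label in zip(rows, labels))
    return correct / len(labels)


def permutation_importance(rows, labels, column, num_shuffle_sets=5, seed=0):
    """Average drop in accuracy after shuffling one column across rows."""
    rng = random.Random(seed)
    baseline = accuracy(rows, labels)
    drops = []
    for _ in range(num_shuffle_sets):
        shuffled_values = [row[column] for row in rows]
        rng.shuffle(shuffled_values)
        shuffled_rows = [
            row[:column] + [value] + row[column + 1:]
            for row, value in zip(rows, shuffled_values)
        ]
        drops.append(baseline - accuracy(shuffled_rows, labels))
    return sum(drops) / len(drops)


# Synthetic data: column 0 determines the label, column 1 is pure noise.
data_rng = random.Random(42)
rows = [[data_rng.random(), data_rng.random()] for _ in range(200)]
labels = [1 if row[0] > 0.5 else 0 for row in rows]

importance = {
    name: permutation_importance(rows, labels, column)
    for column, name in enumerate(["informative", "noise"])
}
print(importance)
```

Shuffling the column the model relies on destroys its accuracy, while shuffling a column the model ignores changes nothing, which is why the ignored column gets an importance of exactly zero here.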

Before judging whether AutoGluon is more or less interpretable than another solution, we recommend reading The Mythos of Model Interpretability by Zachary Lipton, which covers why often-claimed interpretable models such as trees and linear models are rarely meaningfully more interpretable than more advanced models.

Accelerating inference

We describe multiple ways to reduce the time it takes for AutoGluon to produce predictions.

Before providing code examples, it is important to understand that there are several ways to accelerate inference in AutoGluon. The table below lists the options in order of priority.

| Optimization | Inference Speedup | Cost | Notes |
|---|---|---|---|
| refit_full | At least 8x+, up to 160x (requires bagging) | -Quality, +FitTime | Only provides speedup with bagging enabled. |
| persist | Up to 10x in online-inference | ++MemoryUsage | If memory is not sufficient to persist model, speedup is not gained. Speedup is most effective in online-inference and is not relevant in batch inference. |
| infer_limit | Configurable, ~up to 50x | -Quality (Relative to speedup) | If bagging is enabled, always use refit_full if using infer_limit. |
| distill | ~Equals combined speedup of refit_full and infer_limit set to extreme values | --Quality, ++FitTime | Not compatible with refit_full and infer_limit. |
| feature pruning | Typically at most 1.5x. More if willing to lower quality significantly. | -Quality?, ++FitTime | Dependent on the existence of unimportant features in data. Call predictor.feature_importance(test_data) to gauge which features could be removed. |
| use faster hardware | Usually at most 3x. Depends on hardware (ignoring GPU). | +Hardware | As an example, an EC2 c6i.2xlarge is ~1.6x faster than an m5.2xlarge for a similar price. Laptops in particular might be slow compared to cloud instances. |
| manual hyperparameters adjustment | Usually at most 2x assuming infer_limit is already specified. | ---Quality?, +++UserMLExpertise | Can be very complicated and is not recommended. One potential way to get speedups is to reduce the number of trees in LightGBM, XGBoost, CatBoost, RandomForest, and ExtraTrees. |
| manual data preprocessing | Usually at most 1.2x assuming all other optimizations are specified and setting is online-inference. | ++++UserMLExpertise, ++++UserCode | Only relevant for online-inference. This is not recommended as AutoGluon’s default preprocessing is highly optimized. |

If bagging is enabled (num_bag_folds>0 or num_stack_levels>0 or using ‘best_quality’ preset), the order of inference optimizations should be:

  1. refit_full

  2. persist

  3. infer_limit

If bagging is not enabled (num_bag_folds=0, num_stack_levels=0), the order of inference optimizations should be:

  1. persist

  2. infer_limit

If following these recommendations does not lead to a sufficiently fast model, you may consider the more advanced options in the table.

Keeping models in memory

By default, AutoGluon loads models into memory one at a time and only when they are needed for prediction. This strategy is robust for large stacked/bagged ensembles, but leads to slower prediction times. If you plan to repeatedly make predictions (e.g. on new datapoints one at a time rather than one large test dataset), you can first specify that all models required for inference should be loaded into memory as follows:

predictor.persist()

num_test = 20
preds = np.array(['']*num_test, dtype='object')
for i in range(num_test):
    datapoint = test_data_nolabel.iloc[[i]]
    pred_numpy = predictor.predict(datapoint, as_pandas=False)
    preds[i] = pred_numpy[0]

perf = predictor.evaluate_predictions(y_test[:num_test], preds, auxiliary_metrics=True)
print("Predictions: ", preds)

predictor.unpersist()  # free memory by clearing models, future predict() calls will load models from disk
Predictions:  [' <=50K' ' <=50K' ' >50K' ' <=50K' ' <=50K' ' >50K' ' >50K' ' >50K'
 ' <=50K' ' <=50K' ' <=50K' ' <=50K' ' <=50K' ' <=50K' ' <=50K' ' <=50K'
 ' <=50K' ' >50K' ' >50K' ' <=50K']
Persisting 2 models in memory. Models will require 0.01% of memory.
Unpersisted 2 models: ['WeightedEnsemble_L2', 'LightGBM_BAG_L1']
['WeightedEnsemble_L2', 'LightGBM_BAG_L1']

You can alternatively specify a particular model to persist via the models argument of persist(), or simply set models='all' to simultaneously load every single model that was trained during fit.
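The difference between loading models on demand and persisting them can be illustrated with a toy model store. The ToyModelStore class, its pickle-based storage, and the threshold “model” are hypothetical stand-ins and do not reflect how AutoGluon stores or caches models internally.

```python
import os
import pickle
import tempfile


class ToyModelStore:
    """Loads a pickled 'model' from disk on every predict unless persisted."""

    def __init__(self, model_path):
        self.model_path = model_path
        self.disk_loads = 0
        self._persisted = None

    def _load_from_disk(self):
        self.disk_loads += 1
        with open(self.model_path, "rb") as f:
            return pickle.load(f)

    def persist(self):
        # Load once and keep the model in memory for later predict calls.
        self._persisted = self._load_from_disk()

    def unpersist(self):
        # Drop the in-memory copy; later predict calls reload from disk.
        self._persisted = None

    def predict(self, x):
        model = self._persisted if self._persisted is not None else self._load_from_disk()
        return 1 if x > model["threshold"] else 0


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pkl")
    with open(path, "wb") as f:
        pickle.dump({"threshold": 0.5}, f)

    datapoints = [0.1, 0.7, 0.9]

    lazy_store = ToyModelStore(path)
    lazy_preds = [lazy_store.predict(x) for x in datapoints]
    lazy_loads = lazy_store.disk_loads

    persisted_store = ToyModelStore(path)
    persisted_store.persist()
    persisted_preds = [persisted_store.predict(x) for x in datapoints]
    persisted_loads = persisted_store.disk_loads

print(f"disk loads without persist: {lazy_loads}, with persist: {persisted_loads}")
```

The predictions are identical either way; only the number of disk reads changes, which is where the speedup for one-at-a-time (online) inference comes from.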

Inference speed as a fit constraint

If you know your latency constraint prior to fitting the predictor, you can specify it explicitly as a fit argument. AutoGluon will then automatically train models in a fashion that attempts to satisfy the constraint.

This constraint has two components: infer_limit and infer_limit_batch_size:

  • infer_limit is the time in seconds to predict 1 row of data. For example, infer_limit=0.05 means 50 ms per row of data, or 20 rows / second throughput.

  • infer_limit_batch_size is the number of rows passed at once to predict when calculating per-row speed. This is very important because infer_limit_batch_size=1 (online-inference) is highly suboptimal as various operations have a fixed cost overhead regardless of data size. If you can pass your test data in bulk, you should specify infer_limit_batch_size=10000.
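The arithmetic behind these two settings can be sketched as follows. The fixed overhead and per-row cost are made-up numbers chosen only to show why larger batches make a per-row limit easier to satisfy; they are not measurements of AutoGluon.

```python
def rows_per_second(infer_limit):
    """Throughput implied by a per-row time limit given in seconds."""
    return 1.0 / infer_limit


def per_row_latency(fixed_overhead, per_row_cost, batch_size):
    """Per-row time when a fixed per-call overhead is spread over a batch."""
    return (fixed_overhead + per_row_cost * batch_size) / batch_size


FIXED_OVERHEAD = 0.010   # hypothetical: 10 ms of fixed cost per predict call
PER_ROW_COST = 0.00001   # hypothetical: 10 microseconds of work per row

online = per_row_latency(FIXED_OVERHEAD, PER_ROW_COST, batch_size=1)
batch = per_row_latency(FIXED_OVERHEAD, PER_ROW_COST, batch_size=10000)

print(f"infer_limit=0.05 -> {rows_per_second(0.05):.0f} rows/second")
print(f"per-row latency, batch size 1:     {online * 1000:.3f} ms")
print(f"per-row latency, batch size 10000: {batch * 1000:.3f} ms")
```

With these assumed costs, the fixed overhead dominates at batch size 1 and almost vanishes at batch size 10000, so the same models can meet a much tighter infer_limit in batch inference.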

# At most 0.05 ms per row (20000 rows per second throughput)
infer_limit = 0.00005
# adhere to infer_limit with batches of size 10000 (batch-inference, easier to satisfy infer_limit)
infer_limit_batch_size = 10000
# adhere to infer_limit with batches of size 1 (online-inference, much harder to satisfy infer_limit)
# infer_limit_batch_size = 1  # Note that infer_limit<0.02 when infer_limit_batch_size=1 can be difficult to satisfy.
predictor_infer_limit = TabularPredictor(label=label, eval_metric=metric).fit(
    train_data=train_data,
    time_limit=30,
    infer_limit=infer_limit,
    infer_limit_batch_size=infer_limit_batch_size,
)

# NOTE: If bagging was enabled, it is important to call refit_full at this stage.
#  infer_limit assumes that the user will call refit_full after fit.
# predictor_infer_limit.refit_full()

# NOTE: To align with inference speed calculated during fit, models must be persisted.
predictor_infer_limit.persist()
# Below is an optimized version that only persists the minimum required models for prediction.
# predictor_infer_limit.persist('best')

predictor_infer_limit.leaderboard()
No path specified. Models will be saved in: "AutogluonModels/ag-20250107_023026"
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version:  1.2b20250107
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Tue Sep 24 10:00:37 UTC 2024
CPU Count:          8
Memory Avail:       28.02 GB / 30.95 GB (90.5%)
Disk Space Avail:   213.26 GB / 255.99 GB (83.3%)
===================================================
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets. Defaulting to `'medium'`...
	Recommended Presets (For more details refer to https://auto.gluon.ai/stable/tutorials/tabular/tabular-essentials.html#presets):
	presets='experimental' : New in v1.2: Pre-trained foundation model + parallel fits. The absolute best accuracy without consideration for inference speed. Does not support GPU.
	presets='best'         : Maximize accuracy. Recommended for most users. Use in competitions and benchmarks.
	presets='high'         : Strong accuracy with fast inference speed.
	presets='good'         : Good accuracy with very fast inference speed.
	presets='medium'       : Fast training time, ideal for initial prototyping.
Beginning AutoGluon training ... Time limit = 30s
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023026"
Train Data Rows:    1000
Train Data Columns: 14
Label Column:       class
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
2 unique label values:  [' >50K', ' <=50K']
If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
Problem Type:       binary
Preprocessing data ...
Selected class <--> label mapping:  class 1 =  >50K, class 0 =  <=50K
Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
	To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
Available Memory:                    28689.47 MB
Train Data (Original)  Memory Usage: 0.56 MB (0.0% of available memory)
Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
Stage 1 Generators:
Fitting AsTypeFeatureGenerator...
Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
Stage 2 Generators:
Fitting FillNaFeatureGenerator...
Stage 3 Generators:
Fitting IdentityFeatureGenerator...
Fitting CategoryFeatureGenerator...
Fitting CategoryMemoryMinimizeFeatureGenerator...
Stage 4 Generators:
Fitting DropUniqueFeatureGenerator...
Stage 5 Generators:
Fitting DropDuplicatesFeatureGenerator...
Types of features in original data (raw dtype, special dtypes):
('int', [])    : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
Types of features in processed data (raw dtype, special dtypes):
('category', [])  : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
('int', [])       : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('int', ['bool']) : 1 | ['sex']
0.1s = Fit runtime
14 features in original data used to generate 14 features in processed data.
Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
1.56μs	= Feature Preprocessing Time (1 row | 10000 batch size)
Feature Preprocessing requires 3.12% of the overall inference constraint (0.05ms)
		0.048ms inference time budget remaining for models...
Data preprocessing and feature engineering runtime = 0.3s ...
AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
To change this, specify the eval_metric parameter of Predictor()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 800, Val Rows: 200
User-specified model hyperparameters to be fit:
{
	'NN_TORCH': [{}],
	'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, {'learning_rate': 0.03, 'num_leaves': 128, 'feature_fraction': 0.9, 'min_data_in_leaf': 3, 'ag_args': {'name_suffix': 'Large', 'priority': 0, 'hyperparameter_tune_kwargs': None}}],
	'CAT': [{}],
	'XGB': [{}],
	'FASTAI': [{}],
	'RF': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],
	'XT': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],
	'KNN': [{'weights': 'uniform', 'ag_args': {'name_suffix': 'Unif'}}, {'weights': 'distance', 'ag_args': {'name_suffix': 'Dist'}}],
}
Fitting 13 L1 models, fit_strategy="sequential" ...
Fitting model: KNeighborsUnif ... Training model for up to 29.70s of the 29.70s of remaining time.
0.725	 = Validation score   (accuracy)
0.04s	 = Training   runtime
0.01s	 = Validation runtime
3.002μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
3.002μs	 = Validation runtime (1 row | 10000 batch size)
3.002μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
3.002μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: KNeighborsDist ... Training model for up to 29.64s of the 29.64s of remaining time.
0.71	 = Validation score   (accuracy)
0.05s	 = Training   runtime
0.02s	 = Validation runtime
4.372μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
4.372μs	 = Validation runtime (1 row | 10000 batch size)
4.372μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
4.372μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: LightGBMXT ... Training model for up to 29.57s of the 29.57s of remaining time.
0.85	 = Validation score   (accuracy)
0.41s	 = Training   runtime
0.0s	 = Validation runtime
1.353μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
1.353μs	 = Validation runtime (1 row | 10000 batch size)
1.353μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
1.353μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: LightGBM ... Training model for up to 29.15s of the 29.15s of remaining time.
0.84	 = Validation score   (accuracy)
0.49s	 = Training   runtime
0.0s	 = Validation runtime
1.046μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
1.046μs	 = Validation runtime (1 row | 10000 batch size)
1.046μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
1.046μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: RandomForestGini ... Training model for up to 28.65s of the 28.64s of remaining time.
0.84	 = Validation score   (accuracy)
0.78s	 = Training   runtime
0.05s	 = Validation runtime
8.744μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
8.744μs	 = Validation runtime (1 row | 10000 batch size)
8.744μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
8.744μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: RandomForestEntr ... Training model for up to 27.80s of the 27.80s of remaining time.
0.835	 = Validation score   (accuracy)
0.67s	 = Training   runtime
0.05s	 = Validation runtime
8.809μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
8.809μs	 = Validation runtime (1 row | 10000 batch size)
8.809μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
8.809μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: CatBoost ... Training model for up to 27.07s of the 27.07s of remaining time.
0.86	 = Validation score   (accuracy)
2.01s	 = Training   runtime
0.0s	 = Validation runtime
0.761μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
0.761μs	 = Validation runtime (1 row | 10000 batch size)
0.761μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
0.761μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: ExtraTreesGini ... Training model for up to 25.04s of the 25.04s of remaining time.
0.815	 = Validation score   (accuracy)
0.69s	 = Training   runtime
0.05s	 = Validation runtime
8.779μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
8.779μs	 = Validation runtime (1 row | 10000 batch size)
8.779μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
8.779μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: ExtraTreesEntr ... Training model for up to 24.29s of the 24.29s of remaining time.
0.82	 = Validation score   (accuracy)
0.68s	 = Training   runtime
0.05s	 = Validation runtime
8.763μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
8.763μs	 = Validation runtime (1 row | 10000 batch size)
8.763μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
8.763μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: NeuralNetFastAI ... Training model for up to 23.55s of the 23.55s of remaining time.
No improvement since epoch 7: early stopping
0.84	 = Validation score   (accuracy)
1.08s	 = Training   runtime
0.01s	 = Validation runtime
0.014ms	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
0.014ms	 = Validation runtime (1 row | 10000 batch size)
0.014ms	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
0.014ms	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: XGBoost ... Training model for up to 22.45s of the 22.45s of remaining time.
0.845	 = Validation score   (accuracy)
0.22s	 = Training   runtime
0.01s	 = Validation runtime
2.087μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
2.087μs	 = Validation runtime (1 row | 10000 batch size)
2.087μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
2.087μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: NeuralNetTorch ... Training model for up to 22.21s of the 22.21s of remaining time.
0.855	 = Validation score   (accuracy)
3.46s	 = Training   runtime
0.01s	 = Validation runtime
4.361μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
4.361μs	 = Validation runtime (1 row | 10000 batch size)
4.361μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
4.361μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: LightGBMLarge ... Training model for up to 18.73s of the 18.73s of remaining time.
0.795	 = Validation score   (accuracy)
0.87s	 = Training   runtime
0.0s	 = Validation runtime
4.339μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
4.339μs	 = Validation runtime (1 row | 10000 batch size)
4.339μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
4.339μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Removing 5/13 base models to satisfy inference constraint (constraint=0.046ms) ...
0.071ms	-> 0.066ms	(KNeighborsDist)
0.066ms	-> 0.063ms	(KNeighborsUnif)
0.063ms	-> 0.059ms	(LightGBMLarge)
0.059ms	-> 0.05ms	(ExtraTreesGini)
0.05ms	-> 0.041ms	(ExtraTreesEntr)
Fitting model: WeightedEnsemble_L2 ... Training model for up to 29.70s of the 17.81s of remaining time.
Ensemble Weights: {'CatBoost': 1.0}
0.86	 = Validation score   (accuracy)
0.09s	 = Training   runtime
0.0s	 = Validation runtime
0.052μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
0.813μs	 = Validation runtime (1 row | 10000 batch size)
0.052μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
0.813μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
AutoGluon training complete, total runtime = 12.31s ... Best model: CatBoost | Estimated inference throughput: 54892.1 rows/s (200 batch size)
Disabling decision threshold calibration for metric `accuracy` due to having fewer than 10000 rows of validation data for calibration, to avoid overfitting (200 rows).
	`accuracy` is generally not improved through threshold calibration. Force calibration via specifying `calibrate_decision_threshold=True`.
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023026")
Persisting 1 models in memory. Models will require 0.0% of memory.
    model                score_val  eval_metric  pred_time_val  fit_time  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
 0  CatBoost                 0.860     accuracy       0.003644  2.011232                0.003644           2.011232            1       True          7
 1  WeightedEnsemble_L2      0.860     accuracy       0.004554  2.102227                0.000910           0.090995            2       True         14
 2  NeuralNetTorch           0.855     accuracy       0.009406  3.458127                0.009406           3.458127            1       True         12
 3  LightGBMXT               0.850     accuracy       0.003899  0.412800                0.003899           0.412800            1       True          3
 4  XGBoost                  0.845     accuracy       0.005847  0.221453                0.005847           0.221453            1       True         11
 5  LightGBM                 0.840     accuracy       0.003832  0.486827                0.003832           0.486827            1       True          4
 6  NeuralNetFastAI          0.840     accuracy       0.008778  1.078784                0.008778           1.078784            1       True         10
 7  RandomForestGini         0.840     accuracy       0.047178  0.776440                0.047178           0.776440            1       True          5
 8  RandomForestEntr         0.835     accuracy       0.047157  0.672524                0.047157           0.672524            1       True          6
 9  ExtraTreesEntr           0.820     accuracy       0.046283  0.679242                0.046283           0.679242            1       True          9
10  ExtraTreesGini           0.815     accuracy       0.046544  0.686368                0.046544           0.686368            1       True          8
11  LightGBMLarge            0.795     accuracy       0.004859  0.872005                0.004859           0.872005            1       True         13
12  KNeighborsUnif           0.725     accuracy       0.013559  0.036241                0.013559           0.036241            1       True          1
13  KNeighborsDist           0.710     accuracy       0.015129  0.048683                0.015129           0.048683            1       True          2

Now we can measure the inference speed of the final model and check whether it satisfies the inference constraint.

import time

# Sample a batch of infer_limit_batch_size rows (with replacement) from the test data
test_data_batch = test_data.sample(infer_limit_batch_size, replace=True, ignore_index=True)

# Time a single predict() call on the whole batch
time_start = time.time()
predictor_infer_limit.predict(test_data_batch)
time_end = time.time()

# Convert to per-row latency and compare it against the user-specified limit
infer_time_per_row = (time_end - time_start) / len(test_data_batch)
rows_per_second = 1 / infer_time_per_row
infer_time_per_row_ratio = infer_time_per_row / infer_limit
is_constraint_satisfied = infer_time_per_row_ratio <= 1

print(f'Model is able to predict {round(rows_per_second, 1)} rows per second. (User-specified Throughput = {1 / infer_limit})')
print(f'Model uses {round(infer_time_per_row_ratio * 100, 1)}% of infer_limit time per row.')
print(f'Model satisfies inference constraint: {is_constraint_satisfied}')
Model is able to predict 379221.5 rows per second. (User-specified Throughput = 20000.0)
Model uses 5.3% of infer_limit time per row.
Model satisfies inference constraint: True
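
A single timed call can be noisy, so one might repeat the measurement and take the median. The following is a minimal sketch of that idea and is not part of AutoGluon: `measure_seconds_per_row` and `constraint_report` are hypothetical helper names, and `predict_fn` stands for any callable that accepts a batch. The last lines re-derive the percentage printed above from the reported throughput, assuming `infer_limit = 0.00005` seconds per row (the value implied by the printed user-specified throughput of 20000 rows per second).

```python
import statistics
import time


def measure_seconds_per_row(predict_fn, batch, repeats=5):
    """Return the median per-row latency (seconds) of predict_fn over several runs."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        predict_fn(batch)
        timings.append((time.perf_counter() - start) / len(batch))
    return statistics.median(timings)


def constraint_report(seconds_per_row, infer_limit):
    """Compare a measured per-row latency against a per-row limit (both in seconds)."""
    ratio = seconds_per_row / infer_limit
    return {
        "rows_per_second": 1.0 / seconds_per_row,
        "percent_of_limit": 100.0 * ratio,
        "satisfied": ratio <= 1.0,
    }


# Re-derive the printed summary from the reported throughput of 379221.5 rows/s.
# infer_limit = 0.00005 is an assumption inferred from the printed 20000.0 rows/s.
report = constraint_report(seconds_per_row=1 / 379221.5, infer_limit=0.00005)
print(report)
```

With the objects from this tutorial, the measurement itself would be `measure_seconds_per_row(predictor_infer_limit.predict, test_data_batch)`.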

Using smaller ensemble or faster model for prediction

Without having to retrain any models, one can construct alternative ensembles that aggregate individual models’ predictions with different weighting schemes. These ensembles become smaller (and hence faster for prediction) if they assign nonzero weight to fewer models. You can produce a wide variety of ensembles with different accuracy-speed tradeoffs like this:

additional_ensembles = predictor.fit_weighted_ensemble(expand_pareto_frontier=True)
print("Alternative ensembles you can use for prediction:", additional_ensembles)

predictor.leaderboard(only_pareto_frontier=True)
Alternative ensembles you can use for prediction: ['WeightedEnsemble_L2Best']
Fitting model: WeightedEnsemble_L2Best ...
Ensemble Weights: {'LightGBM_BAG_L1': 1.0}
0.7764	 = Validation score   (balanced_accuracy)
0.02s	 = Training   runtime
0.0s	 = Validation runtime
   model            score_val        eval_metric  pred_time_val  fit_time  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0  LightGBM_BAG_L1   0.776399  balanced_accuracy       0.044906  0.961846                0.044906           0.961846            1       True          1

The resulting leaderboard contains the most accurate model for a given inference latency. You can select whichever model has acceptable latency from the leaderboard and use it for prediction.

model_for_prediction = additional_ensembles[0]
predictions = predictor.predict(test_data, model=model_for_prediction)
predictor.delete_models(models_to_delete=additional_ensembles, dry_run=False)  # delete these extra models so they don't affect rest of tutorial
Deleting model WeightedEnsemble_L2Best. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L2Best will be removed.
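
To make the accuracy-versus-latency selection concrete, here is a small sketch in plain Python. It is an illustration only, not how AutoGluon computes `only_pareto_frontier=True` internally; the helper names `pareto_frontier` and `best_under_latency` and the four rows of scores and timings are hypothetical. Each row mimics the `model`, `score_val`, and `pred_time_val` leaderboard columns.

```python
def pareto_frontier(rows):
    """Keep models that no other model beats on both score and prediction time."""
    frontier = []
    for r in rows:
        dominated = any(
            o["score_val"] >= r["score_val"]
            and o["pred_time_val"] <= r["pred_time_val"]
            and (o["score_val"] > r["score_val"] or o["pred_time_val"] < r["pred_time_val"])
            for o in rows
        )
        if not dominated:
            frontier.append(r)
    # Fastest first, so the list reads as an accuracy-speed tradeoff curve
    return sorted(frontier, key=lambda r: r["pred_time_val"])


def best_under_latency(rows, max_pred_time):
    """Return the most accurate model within the time budget, or None if none fits."""
    eligible = [r for r in rows if r["pred_time_val"] <= max_pred_time]
    return max(eligible, key=lambda r: r["score_val"]) if eligible else None


# Hypothetical leaderboard rows (values invented for illustration)
rows = [
    {"model": "small", "score_val": 0.80, "pred_time_val": 0.01},
    {"model": "medium", "score_val": 0.84, "pred_time_val": 0.05},
    {"model": "slow_weak", "score_val": 0.82, "pred_time_val": 0.20},
    {"model": "large", "score_val": 0.86, "pred_time_val": 0.50},
]

print([r["model"] for r in pareto_frontier(rows)])
print(best_under_latency(rows, max_pred_time=0.1))
```

With a real predictor, comparable rows could be obtained from `predictor.leaderboard().to_dict("records")`, and the chosen name passed as `model=` to `predict()`.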

Collapsing bagged ensembles via refit_full

For an ensemble predictor trained with bagging (as done above), recall that there are ~10 bagged copies of each individual model, trained on different train/validation folds. We can collapse this bag of ~10 models into a single model that’s fit to the full dataset, which can greatly reduce its memory/latency requirements (but may also reduce accuracy). Below we refit such a model for each original model, but you can alternatively do this for just one particular model by specifying the model argument of refit_full().

refit_model_map = predictor.refit_full()
print("Name of each refit-full model corresponding to a previous bagged ensemble:")
print(refit_model_map)
predictor.leaderboard(test_data)
Name of each refit-full model corresponding to a previous bagged ensemble:
{'LightGBM_BAG_L1': 'LightGBM_BAG_L1_FULL', 'NeuralNetFastAI_BAG_L1': 'NeuralNetFastAI_BAG_L1_FULL', 'WeightedEnsemble_L2': 'WeightedEnsemble_L2_FULL'}
Refitting models via `predictor.refit_full` using all of the data (combined train and validation)...
	Models trained in this way will have the suffix "_FULL" and have NaN validation score.
	This process is not bound by time_limit, but should take less time than the original `predictor.fit` call.
	To learn more, refer to the `.refit_full` method docstring which explains how "_FULL" models differ from normal models.
Fitting 1 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBM_BAG_L1_FULL ...
0.32s	 = Training   runtime
Fitting 1 L1 models, fit_strategy="sequential" ...
Fitting model: NeuralNetFastAI_BAG_L1_FULL ...
Metric balanced_accuracy is not supported by this model - using log_loss instead
Stopping at the best epoch learned earlier - 9.
0.39s	 = Training   runtime
Fitting model: WeightedEnsemble_L2_FULL | Skipping fit via cloning parent ...
Ensemble Weights: {'LightGBM_BAG_L1': 1.0}
0.04s	 = Training   runtime
Updated best model to "WeightedEnsemble_L2_FULL" (Previously "WeightedEnsemble_L2"). AutoGluon will default to using "WeightedEnsemble_L2_FULL" for predict() and predict_proba().
Refit complete, total runtime = 0.78s ... Best model: "WeightedEnsemble_L2_FULL"
   model                        score_test  score_val        eval_metric  pred_time_test  pred_time_val  fit_time  pred_time_test_marginal  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0  LightGBM_BAG_L1_FULL           0.750092        NaN  balanced_accuracy        0.027850            NaN  0.319060                 0.027850                     NaN           0.319060            1       True          4
1  WeightedEnsemble_L2_FULL       0.750092        NaN  balanced_accuracy        0.029138            NaN  0.359332                 0.001288                     NaN           0.040272            2       True          6
2  LightGBM_BAG_L1                0.743784   0.776399  balanced_accuracy        0.140372       0.044906  0.961846                 0.140372                0.044906           0.961846            1       True          1
3  WeightedEnsemble_L2            0.743784   0.776399  balanced_accuracy        0.141636       0.045989  1.002118                 0.001264                0.001083           0.040272            2       True          3
4  NeuralNetFastAI_BAG_L1         0.724629   0.741368  balanced_accuracy        1.258992       0.109789  5.542246                 1.258992                0.109789           5.542246            1       True          2
5  NeuralNetFastAI_BAG_L1_FULL    0.700878        NaN  balanced_accuracy        0.295795            NaN  0.386061                 0.295795                     NaN           0.386061            1       True          5

This adds the refit-full models to the leaderboard, and we can opt to use any of them for prediction just like any other model. Note that pred_time_test and pred_time_val list the time taken to produce predictions with each model (in seconds) on the test/validation data. Since the refit-full models were trained using all of the data, there is no internal validation score (score_val) available for them. You can also call refit_full() with non-bagged models to refit the same models to your full dataset (there won’t be memory/latency gains in this case, but test accuracy may improve).
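
The tradeoff can be read directly off the leaderboard above. The sketch below copies the `score_test` and `pred_time_test` values from that output into a plain dictionary and computes, for each pair in `refit_model_map`, the test-time speedup and the change in test score. The helper `compare_refit` is a hypothetical name used only for this illustration.

```python
# (score_test, pred_time_test in seconds), copied from the leaderboard output above
leaderboard = {
    "LightGBM_BAG_L1": (0.743784, 0.140372),
    "LightGBM_BAG_L1_FULL": (0.750092, 0.027850),
    "NeuralNetFastAI_BAG_L1": (0.724629, 1.258992),
    "NeuralNetFastAI_BAG_L1_FULL": (0.700878, 0.295795),
    "WeightedEnsemble_L2": (0.743784, 0.141636),
    "WeightedEnsemble_L2_FULL": (0.750092, 0.029138),
}

# Same mapping that refit_full() printed above
refit_model_map = {
    "LightGBM_BAG_L1": "LightGBM_BAG_L1_FULL",
    "NeuralNetFastAI_BAG_L1": "NeuralNetFastAI_BAG_L1_FULL",
    "WeightedEnsemble_L2": "WeightedEnsemble_L2_FULL",
}


def compare_refit(leaderboard, refit_model_map):
    """Return (bagged_name, speedup_factor, score_change) for each refit pair."""
    rows = []
    for bagged, refit in refit_model_map.items():
        score_bagged, time_bagged = leaderboard[bagged]
        score_refit, time_refit = leaderboard[refit]
        rows.append((bagged, time_bagged / time_refit, score_refit - score_bagged))
    return rows


for name, speedup, delta in compare_refit(leaderboard, refit_model_map):
    print(f"{name}: {speedup:.1f}x faster on test data, score change {delta:+.4f}")
```

In this particular run, every refit model predicts several times faster than its bagged counterpart, while the score change goes in both directions: slightly up for LightGBM and down for NeuralNetFastAI.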

Model distillation

While computationally favorable, single individual models will usually have lower accuracy than weighted/stacked/bagged ensembles. Model distillation offers one way to retain the computational benefits of a single model while enjoying some of the accuracy boost that comes with ensembling. The idea is to train the individual model (which we can call the student) to mimic the predictions of the full stack ensemble (the teacher). Like refit_full(), the distill() function will produce additional models we can opt to use for prediction.

student_models = predictor.distill(time_limit=30)  # specify much longer time limit in real applications
print(student_models)
preds_student = predictor.predict(test_data_nolabel, model=student_models[0])
print(f"predictions from {student_models[0]}:", list(preds_student)[:5])
predictor.leaderboard(test_data)
['RandomForestMSE_DSTL', 'WeightedEnsemble_L2_DSTL']
predictions from RandomForestMSE_DSTL: [' <=50K', ' <=50K', ' >50K', ' <=50K', ' <=50K']
Distilling with teacher='WeightedEnsemble_L2_FULL', teacher_preds=soft, augment_method=spunge ...
SPUNGE: Augmenting training data with 4000 synthetic samples for distillation...
Distilling with each of these student models: ['LightGBM_DSTL', 'CatBoost_DSTL', 'RandomForestMSE_DSTL', 'NeuralNetTorch_DSTL']
Fitting 4 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBM_DSTL ... Training model for up to 30.00s of the 30.00s of remaining time.
Warning: Exception caused LightGBM_DSTL to fail during training... Skipping this model.
pandas dtypes must be int, float or bool.
Fields with bad pandas dtypes: workclass: object, education: object, marital-status: object, occupation: object, relationship: object, race: object, native-country: object
Detailed Traceback:
Traceback (most recent call last):
  File "/home/ci/autogluon/core/src/autogluon/core/trainer/abstract_trainer.py", line 2106, in _train_and_save
    model = self._train_single(**model_fit_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/autogluon/core/src/autogluon/core/trainer/abstract_trainer.py", line 1993, in _train_single
    model = model.fit(X=X, y=y, X_val=X_val, y_val=y_val, X_test=X_test, y_test=y_test, total_resources=total_resources, **model_fit_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/autogluon/core/src/autogluon/core/models/abstract/abstract_model.py", line 925, in fit
    out = self._fit(**kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/home/ci/autogluon/tabular/src/autogluon/tabular/models/lgb/lgb_model.py", line 283, in _fit
    self.model = train_lgb_model(early_stopping_callback_kwargs=early_stopping_callback_kwargs, **train_params)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/autogluon/tabular/src/autogluon/tabular/models/lgb/lgb_utils.py", line 134, in train_lgb_model
    return lgb.train(**train_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/lightgbm/engine.py", line 282, in train
    booster = Booster(params=params, train_set=train_set)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/lightgbm/basic.py", line 3637, in __init__
    train_set.construct()
  File "/home/ci/opt/venv/lib/python3.11/site-packages/lightgbm/basic.py", line 2576, in construct
    self._lazy_init(
  File "/home/ci/opt/venv/lib/python3.11/site-packages/lightgbm/basic.py", line 2106, in _lazy_init
    data, feature_name, categorical_feature, self.pandas_categorical = _data_from_pandas(
                                                                       ^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/lightgbm/basic.py", line 848, in _data_from_pandas
    _pandas_to_numpy(data, target_dtype=target_dtype),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/lightgbm/basic.py", line 794, in _pandas_to_numpy
    _check_for_bad_pandas_dtypes(data.dtypes)
  File "/home/ci/opt/venv/lib/python3.11/site-packages/lightgbm/basic.py", line 784, in _check_for_bad_pandas_dtypes
    raise ValueError(
ValueError: pandas dtypes must be int, float or bool.
Fields with bad pandas dtypes: workclass: object, education: object, marital-status: object, occupation: object, relationship: object, race: object, native-country: object
Fitting model: CatBoost_DSTL ... Training model for up to 29.50s of the 29.50s of remaining time.
Warning: Exception caused CatBoost_DSTL to fail during training... Skipping this model.
features data: pandas.DataFrame column 'workclass' has dtype 'category' but is not in  cat_features list
Detailed Traceback:
Traceback (most recent call last):
  File "/home/ci/autogluon/core/src/autogluon/core/trainer/abstract_trainer.py", line 2106, in _train_and_save
    model = self._train_single(**model_fit_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/autogluon/core/src/autogluon/core/trainer/abstract_trainer.py", line 1993, in _train_single
    model = model.fit(X=X, y=y, X_val=X_val, y_val=y_val, X_test=X_test, y_test=y_test, total_resources=total_resources, **model_fit_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/autogluon/core/src/autogluon/core/models/abstract/abstract_model.py", line 925, in fit
    out = self._fit(**kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/home/ci/autogluon/tabular/src/autogluon/tabular/models/catboost/catboost_model.py", line 136, in _fit
    X_val = Pool(data=X_val, label=y_val, cat_features=cat_features, weight=sample_weight_val)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/catboost/core.py", line 855, in __init__
    self._init(data, label, cat_features, text_features, embedding_features, embedding_features_data, pairs, graph, weight,
  File "/home/ci/opt/venv/lib/python3.11/site-packages/catboost/core.py", line 1491, in _init
    self._init_pool(data, label, cat_features, text_features, embedding_features, embedding_features_data, pairs, graph, weight,
  File "_catboost.pyx", line 4339, in _catboost._PoolBase._init_pool
  File "_catboost.pyx", line 4391, in _catboost._PoolBase._init_pool
  File "_catboost.pyx", line 4200, in _catboost._PoolBase._init_features_order_layout_pool
  File "_catboost.pyx", line 3083, in _catboost._set_features_order_data_pd_data_frame
_catboost.CatBoostError: features data: pandas.DataFrame column 'workclass' has dtype 'category' but is not in  cat_features list
Fitting model: RandomForestMSE_DSTL ... Training model for up to 29.27s of the 29.27s of remaining time.
/home/ci/autogluon/tabular/src/autogluon/tabular/models/rf/rf_model.py:80: FutureWarning: Downcasting object dtype arrays on .fillna, .ffill, .bfill is deprecated and will change in a future version. Call result.infer_objects(copy=False) instead. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`
  X = X.fillna(0).to_numpy(dtype=np.float32)
Note: model has different eval_metric than default.
-0.1103	 = Validation score   (-mean_squared_error)
1.37s	 = Training   runtime
0.05s	 = Validation runtime
Fitting model: NeuralNetTorch_DSTL ... Training model for up to 27.74s of the 27.74s of remaining time.
Warning: Exception caused NeuralNetTorch_DSTL to fail during training... Skipping this model.
Found array with 0 feature(s) (shape=(4800, 0)) while a minimum of 1 is required.
Detailed Traceback:
Traceback (most recent call last):
  File "/home/ci/autogluon/core/src/autogluon/core/trainer/abstract_trainer.py", line 2106, in _train_and_save
    model = self._train_single(**model_fit_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/autogluon/core/src/autogluon/core/trainer/abstract_trainer.py", line 1993, in _train_single
    model = model.fit(X=X, y=y, X_val=X_val, y_val=y_val, X_test=X_test, y_test=y_test, total_resources=total_resources, **model_fit_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/autogluon/core/src/autogluon/core/models/abstract/abstract_model.py", line 925, in fit
    out = self._fit(**kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/home/ci/autogluon/tabular/src/autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py", line 201, in _fit
    train_dataset = self._generate_dataset(X, y, train_params=processor_kwargs, is_train=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/autogluon/tabular/src/autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py", line 669, in _generate_dataset
    dataset = self._process_train_data(
              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/autogluon/tabular/src/autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py", line 744, in _process_train_data
    df = self.processor.fit_transform(df)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/sklearn/utils/_set_output.py", line 316, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/sklearn/base.py", line 1473, in wrapper
    return fit_method(estimator, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/sklearn/compose/_column_transformer.py", line 976, in fit_transform
    result = self._call_func_on_transformers(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/sklearn/compose/_column_transformer.py", line 885, in _call_func_on_transformers
    return Parallel(n_jobs=self.n_jobs)(jobs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/sklearn/utils/parallel.py", line 74, in __call__
    return super().__call__(iterable_with_config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/joblib/parallel.py", line 1918, in __call__
    return output if self.return_generator else list(output)
                                                ^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/joblib/parallel.py", line 1847, in _get_sequential_output
    res = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/sklearn/utils/parallel.py", line 136, in __call__
    return self.function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/sklearn/pipeline.py", line 1310, in _fit_transform_one
    res = transformer.fit_transform(X, y, **params.get("fit_transform", {}))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/sklearn/base.py", line 1473, in wrapper
    return fit_method(estimator, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/sklearn/pipeline.py", line 541, in fit_transform
    return last_step.fit_transform(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/sklearn/utils/_set_output.py", line 316, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/sklearn/base.py", line 1098, in fit_transform
    return self.fit(X, **fit_params).transform(X)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/autogluon/tabular/src/autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py", line 726, in fit
    self._fit(X, handle_unknown="ignore")
  File "/home/ci/autogluon/tabular/src/autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py", line 194, in _fit
    X_list, n_samples, n_features = self._check_X(X)
                                    ^^^^^^^^^^^^^^^^
  File "/home/ci/autogluon/tabular/src/autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py", line 165, in _check_X
    X_temp = check_array(X, dtype=None, force_all_finite=False)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ci/opt/venv/lib/python3.11/site-packages/sklearn/utils/validation.py", line 1096, in check_array
    raise ValueError(
ValueError: Found array with 0 feature(s) (shape=(4800, 0)) while a minimum of 1 is required.
Distilling with each of these student models: ['WeightedEnsemble_L2_DSTL']
Fitting model: WeightedEnsemble_L2_DSTL ... Training model for up to 30.00s of the 27.47s of remaining time.
Ensemble Weights: {'RandomForestMSE_DSTL': 1.0}
Note: model has different eval_metric than default.
-0.1103	 = Validation score   (-mean_squared_error)
0.0s	 = Training   runtime
0.0s	 = Validation runtime
Distilled model leaderboard:
model  score_val         eval_metric  pred_time_val  fit_time  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0      RandomForestMSE_DSTL   0.718252  mean_squared_error        0.04559  1.366632                 0.04559           1.366632            1       True          7
1  WeightedEnsemble_L2_DSTL   0.718252  mean_squared_error        0.04622  1.369297                 0.00063           0.002665            2       True          8
   model                        score_test  score_val         eval_metric  pred_time_test  pred_time_val  fit_time  pred_time_test_marginal  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0  LightGBM_BAG_L1_FULL           0.750092        NaN   balanced_accuracy        0.026449            NaN  0.319060                 0.026449                     NaN           0.319060            1       True          4
1  WeightedEnsemble_L2_FULL       0.750092        NaN   balanced_accuracy        0.027704            NaN  0.359332                 0.001254                     NaN           0.040272            2       True          6
2  LightGBM_BAG_L1                0.743784   0.776399   balanced_accuracy        0.142558       0.044906  0.961846                 0.142558                0.044906           0.961846            1       True          1
3  WeightedEnsemble_L2            0.743784   0.776399   balanced_accuracy        0.143791       0.045989  1.002118                 0.001233                0.001083           0.040272            2       True          3
4  RandomForestMSE_DSTL           0.732074   0.718252  mean_squared_error        0.174362       0.045590  1.366632                 0.174362                0.045590           1.366632            1       True          7
5  WeightedEnsemble_L2_DSTL       0.732074   0.718252  mean_squared_error        0.176342       0.046220  1.369297                 0.001980                0.000630           0.002665            2       True          8
6  NeuralNetFastAI_BAG_L1         0.724629   0.741368   balanced_accuracy        1.242111       0.109789  5.542246                 1.242111                0.109789           5.542246            1       True          2
7  NeuralNetFastAI_BAG_L1_FULL    0.700878        NaN   balanced_accuracy        0.312339            NaN  0.386061                 0.312339                     NaN           0.386061            1       True          5
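
The log above reports `teacher_preds=soft` and scores the student with mean squared error, which suggests the student is fit as a regressor on the teacher's predicted probabilities rather than on hard class labels. The toy example below illustrates that general idea in plain Python. It is a sketch under stated assumptions, not AutoGluon's distillation code: `teacher_proba` is a made-up teacher, and `fit_stump` is a hypothetical one-split regression "student".

```python
import math


def teacher_proba(x):
    """Made-up teacher: a smooth probability of the positive class."""
    return 1.0 / (1.0 + math.exp(-(x - 4.5)))


def fit_stump(xs, targets):
    """Fit a one-split regression stump by minimizing squared error.

    Returns (threshold, prediction_below, prediction_above).
    """
    pairs = sorted(zip(xs, targets))
    best = None
    for i in range(1, len(pairs)):
        left = [t for _, t in pairs[:i]]
        right = [t for _, t in pairs[i:]]
        left_mean = sum(left) / len(left)
        right_mean = sum(right) / len(right)
        sse = sum((t - left_mean) ** 2 for t in left) + sum((t - right_mean) ** 2 for t in right)
        split = (pairs[i - 1][0] + pairs[i][0]) / 2
        if best is None or sse < best[0]:
            best = (sse, split, left_mean, right_mean)
    return best[1], best[2], best[3]


xs = list(range(10))
soft_targets = [teacher_proba(x) for x in xs]                  # teacher probabilities
hard_targets = [1.0 if p >= 0.5 else 0.0 for p in soft_targets]  # teacher class labels

# Student trained on soft targets keeps some of the teacher's uncertainty
threshold, low, high = fit_stump(xs, soft_targets)
print("soft-label student:", threshold, low, high)

# Student trained on hard targets predicts only 0 or 1
print("hard-label student:", fit_stump(xs, hard_targets))
```

Both students split at the same point, but the one trained on soft targets outputs intermediate probabilities on each side of the split instead of exactly 0 and 1.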

Faster presets or hyperparameters

If you know at the outset that inference latency or memory will be an issue, you can adjust the training process so that fit() does not produce unwieldy models, instead of trying to speed up a cumbersome trained model at prediction time.

One option is to specify more lightweight presets:

presets = ['good_quality', 'optimize_for_deployment']
predictor_light = TabularPredictor(label=label, eval_metric=metric).fit(train_data, presets=presets, time_limit=30)
No path specified. Models will be saved in: "AutogluonModels/ag-20250107_023055"
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version:  1.2b20250107
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Tue Sep 24 10:00:37 UTC 2024
CPU Count:          8
Memory Avail:       27.48 GB / 30.95 GB (88.8%)
Disk Space Avail:   213.12 GB / 255.99 GB (83.3%)
===================================================
Presets specified: ['good_quality', 'optimize_for_deployment']
Setting dynamic_stacking from 'auto' to True. Reason: Enable dynamic_stacking when use_bag_holdout is disabled. (use_bag_holdout=False)
Stack configuration (auto_stack=True): num_stack_levels=1, num_bag_folds=8, num_bag_sets=1
Note: `save_bag_folds=False`! This will greatly reduce peak disk usage during fit (by ~8x), but runs the risk of an out-of-memory error during model refit if memory is small relative to the data size.
	You can avoid this risk by setting `save_bag_folds=True`.
DyStack is enabled (dynamic_stacking=True). AutoGluon will try to determine whether the input data is affected by stacked overfitting and enable or disable stacking as a consequence.
This is used to identify the optimal `num_stack_levels` value. Copies of AutoGluon will be fit on subsets of the data. Then holdout validation data is used to detect stacked overfitting.
Running DyStack for up to 7s of the 30s of remaining time (25%).
Context path: "/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023055/ds_sub_fit/sub_fit_ho"
Leaderboard on holdout data (DyStack):
model  score_holdout  score_val eval_metric  pred_time_test pred_time_val  fit_time  pred_time_test_marginal pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0    LightGBMXT_BAG_L1_FULL       0.866071   0.862613    accuracy        0.006509          None  0.433840                 0.006509                   None           0.433840            1       True          1
1  WeightedEnsemble_L3_FULL       0.866071   0.862613    accuracy        0.007941          None  0.437421                 0.001432                   None           0.003581            3       True          4
2  WeightedEnsemble_L2_FULL       0.866071   0.862613    accuracy        0.008017          None  0.437812                 0.001508                   None           0.003972            2       True          3
3      LightGBM_BAG_L1_FULL       0.839286   0.861486    accuracy        0.005775          None  0.164539                 0.005775                   None           0.164539            1       True          2
1	 = Optimal   num_stack_levels (Stacked Overfitting Occurred: False)
12s	 = DyStack   runtime |	18s	 = Remaining runtime
Starting main fit with num_stack_levels=1.
	For future fit calls on this dataset, you can skip DyStack to save time: `predictor.fit(..., dynamic_stacking=False, num_stack_levels=1)`
Beginning AutoGluon training ... Time limit = 18s
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023055"
Train Data Rows:    1000
Train Data Columns: 14
Label Column:       class
Problem Type:       binary
Preprocessing data ...
Selected class <--> label mapping:  class 1 =  >50K, class 0 =  <=50K
Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
	To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
Available Memory:                    27956.76 MB
Train Data (Original)  Memory Usage: 0.56 MB (0.0% of available memory)
Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
Stage 1 Generators:
Fitting AsTypeFeatureGenerator...
Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
Stage 2 Generators:
Fitting FillNaFeatureGenerator...
Stage 3 Generators:
Fitting IdentityFeatureGenerator...
Fitting CategoryFeatureGenerator...
Fitting CategoryMemoryMinimizeFeatureGenerator...
Stage 4 Generators:
Fitting DropUniqueFeatureGenerator...
Stage 5 Generators:
Fitting DropDuplicatesFeatureGenerator...
Types of features in original data (raw dtype, special dtypes):
('int', [])    : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
Types of features in processed data (raw dtype, special dtypes):
('category', [])  : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
('int', [])       : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('int', ['bool']) : 1 | ['sex']
0.1s = Fit runtime
14 features in original data used to generate 14 features in processed data.
Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.11s ...
AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
To change this, specify the eval_metric parameter of Predictor()
User-specified model hyperparameters to be fit:
{
	'NN_TORCH': [{}],
	'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, {'learning_rate': 0.03, 'num_leaves': 128, 'feature_fraction': 0.9, 'min_data_in_leaf': 3, 'ag_args': {'name_suffix': 'Large', 'priority': 0, 'hyperparameter_tune_kwargs': None}}],
	'CAT': [{}],
	'XGB': [{}],
	'FASTAI': [{}],
	'RF': [{'criterion': 'gini', 'max_depth': 15, 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'max_depth': 15, 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'max_depth': 15, 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],
	'XT': [{'criterion': 'gini', 'max_depth': 15, 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'max_depth': 15, 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'max_depth': 15, 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],
}
AutoGluon will fit 2 stack levels (L1 to L2) ...
Fitting 11 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 11.96s of the 17.94s of remaining time.
Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.01%)
0.859	 = Validation score   (accuracy)
0.74s	 = Training   runtime
0.05s	 = Validation runtime
Fitting model: LightGBM_BAG_L1 ... Training model for up to 7.83s of the 13.81s of remaining time.
Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.01%)
0.859	 = Validation score   (accuracy)
0.97s	 = Training   runtime
0.05s	 = Validation runtime
Fitting model: RandomForestGini_BAG_L1 ... Training model for up to 3.12s of the 9.10s of remaining time.
0.835	 = Validation score   (accuracy)
0.84s	 = Training   runtime
0.11s	 = Validation runtime
Fitting model: RandomForestEntr_BAG_L1 ... Training model for up to 2.14s of the 8.12s of remaining time.
0.835	 = Validation score   (accuracy)
0.59s	 = Training   runtime
0.11s	 = Validation runtime
Fitting model: CatBoost_BAG_L1 ... Training model for up to 1.41s of the 7.40s of remaining time.
Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.02%)
Time limit exceeded... Skipping CatBoost_BAG_L1.
Fitting model: WeightedEnsemble_L2 ... Training model for up to 17.94s of the 2.97s of remaining time.
Ensemble Weights: {'LightGBMXT_BAG_L1': 1.0}
0.859	 = Validation score   (accuracy)
0.05s	 = Training   runtime
0.0s	 = Validation runtime
Fitting 11 L2 models, fit_strategy="sequential" ...
Fitting model: LightGBMXT_BAG_L2 ... Training model for up to 2.91s of the 2.90s of remaining time.
Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.01%)
0.862	 = Validation score   (accuracy)
0.64s	 = Training   runtime
0.04s	 = Validation runtime
Fitting model: WeightedEnsemble_L3 ... Training model for up to 17.94s of the -1.23s of remaining time.
Ensemble Weights: {'LightGBMXT_BAG_L2': 1.0}
0.862	 = Validation score   (accuracy)
0.06s	 = Training   runtime
0.0s	 = Validation runtime
AutoGluon training complete, total runtime = 19.36s ... Best model: WeightedEnsemble_L3 | Estimated inference throughput: 743.8 rows/s (125 batch size)
Automatically performing refit_full as a post-fit operation (due to `.fit(..., refit_full=True)`
Refitting models via `predictor.refit_full` using all of the data (combined train and validation)...
	Models trained in this way will have the suffix "_FULL" and have NaN validation score.
	This process is not bound by time_limit, but should take less time than the original `predictor.fit` call.
	To learn more, refer to the `.refit_full` method docstring which explains how "_FULL" models differ from normal models.
Fitting 1 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBMXT_BAG_L1_FULL ...
2025-01-07 02:31:27,025	ERROR worker.py:422 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): The worker died unexpectedly while executing this task. Check python-core-worker-*.log files for more information.
2025-01-07 02:31:27,026	ERROR worker.py:422 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): The worker died unexpectedly while executing this task. Check python-core-worker-*.log files for more information.
2025-01-07 02:31:27,028	ERROR worker.py:422 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): The worker died unexpectedly while executing this task. Check python-core-worker-*.log files for more information.
2025-01-07 02:31:27,029	ERROR worker.py:422 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): The worker died unexpectedly while executing this task. Check python-core-worker-*.log files for more information.
2025-01-07 02:31:27,034	ERROR worker.py:422 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): The worker died unexpectedly while executing this task. Check python-core-worker-*.log files for more information.
2025-01-07 02:31:27,036	ERROR worker.py:422 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): The worker died unexpectedly while executing this task. Check python-core-worker-*.log files for more information.
2025-01-07 02:31:27,038	ERROR worker.py:422 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): The worker died unexpectedly while executing this task. Check python-core-worker-*.log files for more information.
0.68s	 = Training   runtime
Fitting 1 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBM_BAG_L1_FULL ...
0.3s	 = Training   runtime
Fitting model: RandomForestGini_BAG_L1_FULL | Skipping fit via cloning parent ...
0.84s	 = Training   runtime
0.11s	 = Validation runtime
Fitting model: RandomForestEntr_BAG_L1_FULL | Skipping fit via cloning parent ...
0.59s	 = Training   runtime
0.11s	 = Validation runtime
Fitting 1 L2 models, fit_strategy="sequential" ...
Fitting model: LightGBMXT_BAG_L2_FULL ...
0.24s	 = Training   runtime
Updated best model to "LightGBMXT_BAG_L2_FULL" (Previously "WeightedEnsemble_L3"). AutoGluon will default to using "LightGBMXT_BAG_L2_FULL" for predict() and predict_proba().
Refit complete, total runtime = 1.33s ... Best model: "LightGBMXT_BAG_L2_FULL"
Disabling decision threshold calibration for metric `accuracy` due to having fewer than 10000 rows of validation data for calibration, to avoid overfitting (1000 rows).
	`accuracy` is generally not improved through threshold calibration. Force calibration via specifying `calibrate_decision_threshold=True`.
Deleting model LightGBMXT_BAG_L1. All files under /home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023055/models/LightGBMXT_BAG_L1 will be removed.
Deleting model LightGBM_BAG_L1. All files under /home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023055/models/LightGBM_BAG_L1 will be removed.
Deleting model RandomForestGini_BAG_L1. All files under /home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023055/models/RandomForestGini_BAG_L1 will be removed.
Deleting model RandomForestEntr_BAG_L1. All files under /home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023055/models/RandomForestEntr_BAG_L1 will be removed.
Deleting model WeightedEnsemble_L2. All files under /home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023055/models/WeightedEnsemble_L2 will be removed.
Deleting model LightGBMXT_BAG_L2. All files under /home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023055/models/LightGBMXT_BAG_L2 will be removed.
Deleting model WeightedEnsemble_L3. All files under /home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023055/models/WeightedEnsemble_L3 will be removed.
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023055")

Another option is to specify more lightweight hyperparameters:

predictor_light = TabularPredictor(label=label, eval_metric=metric).fit(train_data, hyperparameters='very_light', time_limit=30)
No path specified. Models will be saved in: "AutogluonModels/ag-20250107_023128"
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version:  1.2b20250107
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Tue Sep 24 10:00:37 UTC 2024
CPU Count:          8
Memory Avail:       27.43 GB / 30.95 GB (88.6%)
Disk Space Avail:   213.10 GB / 255.99 GB (83.2%)
===================================================
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets. Defaulting to `'medium'`...
	Recommended Presets (For more details refer to https://auto.gluon.ai/stable/tutorials/tabular/tabular-essentials.html#presets):
	presets='experimental' : New in v1.2: Pre-trained foundation model + parallel fits. The absolute best accuracy without consideration for inference speed. Does not support GPU.
	presets='best'         : Maximize accuracy. Recommended for most users. Use in competitions and benchmarks.
	presets='high'         : Strong accuracy with fast inference speed.
	presets='good'         : Good accuracy with very fast inference speed.
	presets='medium'       : Fast training time, ideal for initial prototyping.
Beginning AutoGluon training ... Time limit = 30s
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023128"
Train Data Rows:    1000
Train Data Columns: 14
Label Column:       class
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
2 unique label values:  [' >50K', ' <=50K']
If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
Problem Type:       binary
Preprocessing data ...
Selected class <--> label mapping:  class 1 =  >50K, class 0 =  <=50K
Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
	To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
Available Memory:                    28093.52 MB
Train Data (Original)  Memory Usage: 0.56 MB (0.0% of available memory)
Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
Stage 1 Generators:
Fitting AsTypeFeatureGenerator...
Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
Stage 2 Generators:
Fitting FillNaFeatureGenerator...
Stage 3 Generators:
Fitting IdentityFeatureGenerator...
Fitting CategoryFeatureGenerator...
Fitting CategoryMemoryMinimizeFeatureGenerator...
Stage 4 Generators:
Fitting DropUniqueFeatureGenerator...
Stage 5 Generators:
Fitting DropDuplicatesFeatureGenerator...
Types of features in original data (raw dtype, special dtypes):
('int', [])    : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
Types of features in processed data (raw dtype, special dtypes):
('category', [])  : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
('int', [])       : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('int', ['bool']) : 1 | ['sex']
0.1s = Fit runtime
14 features in original data used to generate 14 features in processed data.
Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.11s ...
AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
To change this, specify the eval_metric parameter of Predictor()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 800, Val Rows: 200
User-specified model hyperparameters to be fit:
{
	'NN_TORCH': [{}],
	'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, {'learning_rate': 0.03, 'num_leaves': 128, 'feature_fraction': 0.9, 'min_data_in_leaf': 3, 'ag_args': {'name_suffix': 'Large', 'priority': 0, 'hyperparameter_tune_kwargs': None}}],
	'CAT': [{}],
	'XGB': [{}],
	'FASTAI': [{}],
}
Fitting 7 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBMXT ... Training model for up to 29.89s of the 29.88s of remaining time.
0.85	 = Validation score   (accuracy)
0.31s	 = Training   runtime
0.0s	 = Validation runtime
Fitting model: LightGBM ... Training model for up to 29.56s of the 29.56s of remaining time.
0.84	 = Validation score   (accuracy)
0.42s	 = Training   runtime
0.0s	 = Validation runtime
Fitting model: CatBoost ... Training model for up to 29.13s of the 29.13s of remaining time.
0.86	 = Validation score   (accuracy)
2.06s	 = Training   runtime
0.0s	 = Validation runtime
Fitting model: NeuralNetFastAI ... Training model for up to 27.06s of the 27.06s of remaining time.
No improvement since epoch 7: early stopping
0.84	 = Validation score   (accuracy)
0.91s	 = Training   runtime
0.01s	 = Validation runtime
Fitting model: XGBoost ... Training model for up to 26.13s of the 26.13s of remaining time.
0.845	 = Validation score   (accuracy)
0.19s	 = Training   runtime
0.01s	 = Validation runtime
Fitting model: NeuralNetTorch ... Training model for up to 25.92s of the 25.92s of remaining time.
0.855	 = Validation score   (accuracy)
2.47s	 = Training   runtime
0.01s	 = Validation runtime
Fitting model: LightGBMLarge ... Training model for up to 23.44s of the 23.44s of remaining time.
0.795	 = Validation score   (accuracy)
0.8s	 = Training   runtime
0.01s	 = Validation runtime
Fitting model: WeightedEnsemble_L2 ... Training model for up to 29.89s of the 22.58s of remaining time.
Ensemble Weights: {'CatBoost': 1.0}
0.86	 = Validation score   (accuracy)
0.07s	 = Training   runtime
0.0s	 = Validation runtime
AutoGluon training complete, total runtime = 7.51s ... Best model: WeightedEnsemble_L2 | Estimated inference throughput: 40056.4 rows/s (200 batch size)
Disabling decision threshold calibration for metric `accuracy` due to having fewer than 10000 rows of validation data for calibration, to avoid overfitting (200 rows).
	`accuracy` is generally not improved through threshold calibration. Force calibration via specifying `calibrate_decision_threshold=True`.
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023128")

Here you can set hyperparameters to ‘light’, ‘very_light’, or ‘toy’ to obtain progressively smaller (but less accurate) models and predictors. Advanced users may instead manually specify particular models’ hyperparameters to make them faster or smaller.
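As a rough illustration of manually specifying hyperparameters, the sketch below passes a dict to fit() instead of a preset name. The models and values chosen here (a LightGBM model capped at 50 boosting rounds and a small, shallow random forest) are assumptions for demonstration rather than recommended settings, and the helper name fit_small_predictor is hypothetical.

```python
# Illustrative (not tuned) hyperparameters: fewer boosting rounds and
# smaller, shallower forests usually give faster, smaller models.
small_hyperparameters = {
    'GBM': {'num_boost_round': 50},
    'RF': {'n_estimators': 50, 'max_depth': 10},
}


def fit_small_predictor(train_data, label, eval_metric='accuracy', time_limit=30):
    """Hypothetical helper: fit a TabularPredictor using only the models above."""
    # Imported here so the configuration can be inspected without AutoGluon installed.
    from autogluon.tabular import TabularPredictor

    return TabularPredictor(label=label, eval_metric=eval_metric).fit(
        train_data,
        hyperparameters=small_hyperparameters,
        time_limit=time_limit,
    )


# Usage, assuming train_data and label are defined as earlier in this tutorial:
# predictor_small = fit_small_predictor(train_data, label)
```

Only the model types that appear as keys in the dict are trained, so dropping a key is another way to shrink the resulting predictor.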

Finally, you may also exclude specific unwieldy models from being trained at all. Below we exclude models that tend to be slower (K Nearest Neighbors, Neural Networks):

excluded_model_types = ['KNN', 'NN_TORCH']
predictor_light = TabularPredictor(label=label, eval_metric=metric).fit(train_data, excluded_model_types=excluded_model_types, time_limit=30)
No path specified. Models will be saved in: "AutogluonModels/ag-20250107_023135"
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version:  1.2b20250107
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Tue Sep 24 10:00:37 UTC 2024
CPU Count:          8
Memory Avail:       27.41 GB / 30.95 GB (88.6%)
Disk Space Avail:   213.10 GB / 255.99 GB (83.2%)
===================================================
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets. Defaulting to `'medium'`...
	Recommended Presets (For more details refer to https://auto.gluon.ai/stable/tutorials/tabular/tabular-essentials.html#presets):
	presets='experimental' : New in v1.2: Pre-trained foundation model + parallel fits. The absolute best accuracy without consideration for inference speed. Does not support GPU.
	presets='best'         : Maximize accuracy. Recommended for most users. Use in competitions and benchmarks.
	presets='high'         : Strong accuracy with fast inference speed.
	presets='good'         : Good accuracy with very fast inference speed.
	presets='medium'       : Fast training time, ideal for initial prototyping.
Beginning AutoGluon training ... Time limit = 30s
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023135"
Train Data Rows:    1000
Train Data Columns: 14
Label Column:       class
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
2 unique label values:  [' >50K', ' <=50K']
If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
Problem Type:       binary
Preprocessing data ...
Selected class <--> label mapping:  class 1 =  >50K, class 0 =  <=50K
Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
	To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
Available Memory:                    28070.80 MB
Train Data (Original)  Memory Usage: 0.56 MB (0.0% of available memory)
Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
Stage 1 Generators:
Fitting AsTypeFeatureGenerator...
Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
Stage 2 Generators:
Fitting FillNaFeatureGenerator...
Stage 3 Generators:
Fitting IdentityFeatureGenerator...
Fitting CategoryFeatureGenerator...
Fitting CategoryMemoryMinimizeFeatureGenerator...
Stage 4 Generators:
Fitting DropUniqueFeatureGenerator...
Stage 5 Generators:
Fitting DropDuplicatesFeatureGenerator...
Types of features in original data (raw dtype, special dtypes):
('int', [])    : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
Types of features in processed data (raw dtype, special dtypes):
('category', [])  : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
('int', [])       : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('int', ['bool']) : 1 | ['sex']
0.1s = Fit runtime
14 features in original data used to generate 14 features in processed data.
Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.12s ...
AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
To change this, specify the eval_metric parameter of Predictor()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 800, Val Rows: 200
User-specified model hyperparameters to be fit:
{
	'NN_TORCH': [{}],
	'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, {'learning_rate': 0.03, 'num_leaves': 128, 'feature_fraction': 0.9, 'min_data_in_leaf': 3, 'ag_args': {'name_suffix': 'Large', 'priority': 0, 'hyperparameter_tune_kwargs': None}}],
	'CAT': [{}],
	'XGB': [{}],
	'FASTAI': [{}],
	'RF': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],
	'XT': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],
	'KNN': [{'weights': 'uniform', 'ag_args': {'name_suffix': 'Unif'}}, {'weights': 'distance', 'ag_args': {'name_suffix': 'Dist'}}],
}
Excluded models: ['KNN', 'NN_TORCH'] (Specified by `excluded_model_types`)
Fitting 10 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBMXT ... Training model for up to 29.88s of the 29.88s of remaining time.
0.85	 = Validation score   (accuracy)
0.33s	 = Training   runtime
0.0s	 = Validation runtime
Fitting model: LightGBM ... Training model for up to 29.54s of the 29.54s of remaining time.
0.84	 = Validation score   (accuracy)
0.42s	 = Training   runtime
0.0s	 = Validation runtime
Fitting model: RandomForestGini ... Training model for up to 29.11s of the 29.11s of remaining time.
0.84	 = Validation score   (accuracy)
0.66s	 = Training   runtime
0.05s	 = Validation runtime
Fitting model: RandomForestEntr ... Training model for up to 28.38s of the 28.38s of remaining time.
0.835	 = Validation score   (accuracy)
0.61s	 = Training   runtime
0.05s	 = Validation runtime
Fitting model: CatBoost ... Training model for up to 27.72s of the 27.71s of remaining time.
0.86	 = Validation score   (accuracy)
1.99s	 = Training   runtime
0.0s	 = Validation runtime
Fitting model: ExtraTreesGini ... Training model for up to 25.72s of the 25.72s of remaining time.
0.815	 = Validation score   (accuracy)
0.57s	 = Training   runtime
0.05s	 = Validation runtime
Fitting model: ExtraTreesEntr ... Training model for up to 25.09s of the 25.09s of remaining time.
0.82	 = Validation score   (accuracy)
0.6s	 = Training   runtime
0.05s	 = Validation runtime
Fitting model: NeuralNetFastAI ... Training model for up to 24.42s of the 24.42s of remaining time.
No improvement since epoch 7: early stopping
0.84	 = Validation score   (accuracy)
0.92s	 = Training   runtime
0.01s	 = Validation runtime
Fitting model: XGBoost ... Training model for up to 23.48s of the 23.48s of remaining time.
0.845	 = Validation score   (accuracy)
0.2s	 = Training   runtime
0.01s	 = Validation runtime
Fitting model: LightGBMLarge ... Training model for up to 23.27s of the 23.27s of remaining time.
0.795	 = Validation score   (accuracy)
0.79s	 = Training   runtime
0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L2 ... Training model for up to 29.88s of the 22.43s of remaining time.
Ensemble Weights: {'RandomForestGini': 0.429, 'CatBoost': 0.286, 'LightGBMXT': 0.143, 'ExtraTreesEntr': 0.143}
0.875	 = Validation score   (accuracy)
0.08s	 = Training   runtime
0.0s	 = Validation runtime
AutoGluon training complete, total runtime = 7.68s ... Best model: WeightedEnsemble_L2 | Estimated inference throughput: 1957.5 rows/s (200 batch size)
Disabling decision threshold calibration for metric `accuracy` due to having fewer than 10000 rows of validation data for calibration, to avoid overfitting (200 rows).
	`accuracy` is generally not improved through threshold calibration. Force calibration via specifying `calibrate_decision_threshold=True`.
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250107_023135")

(Advanced) Cache preprocessed data

If you repeatedly predict on the same data, you can cache the preprocessed version of the data and pass it directly to predictor.predict for faster inference:

test_data_preprocessed = predictor.transform_features(test_data)

# The following call will be faster than a normal predict call because we are skipping the preprocessing stage.
predictions = predictor.predict(test_data_preprocessed, transform_features=False)

Note that this is only useful in situations where you are repeatedly predicting on the same data. If this significantly speeds up your use-case, consider whether your current approach makes sense or if a cache on the predictions is a better solution.
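To check whether caching actually helps in your setting, you could time both calls. The helper below is a minimal sketch using only the standard library; the name mean_runtime and the choice of three repeats are assumptions, not part of AutoGluon.

```python
import time


def mean_runtime(fn, repeats=3):
    """Return the mean wall-clock time in seconds of calling fn() `repeats` times."""
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats


# Usage, assuming predictor, test_data and test_data_preprocessed exist as above:
# raw_time = mean_runtime(lambda: predictor.predict(test_data))
# cached_time = mean_runtime(
#     lambda: predictor.predict(test_data_preprocessed, transform_features=False))
# print(f"raw: {raw_time:.3f}s, cached: {cached_time:.3f}s")
```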

(Advanced) Disable preprocessing

If you would rather do data preprocessing outside of TabularPredictor, you can disable TabularPredictor’s preprocessing entirely via:

predictor.fit(..., feature_generator=None, feature_metadata=YOUR_CUSTOM_FEATURE_METADATA)

Be warned that this removes ALL guardrails on data sanitization. It is very likely that you will run into errors doing this unless you are very familiar with AutoGluon.

One instance where this can be helpful is if you have many problems that reuse the exact same data with the exact same features. If you had 30 tasks that reuse the same features, you could fit an autogluon.features feature generator once on the data, and then, when you need to predict on the 30 tasks, preprocess the data only once and send the preprocessed data to all 30 predictors.
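The shared-preprocessing idea might look roughly like the sketch below. The helper name preprocess_once is hypothetical, and the commented usage makes assumptions that you should verify against your AutoGluon version: that the fitted generator's feature_metadata attribute is suitable to pass as feature_metadata, and that task_labels and train_labels exist in your own code.

```python
def preprocess_once(train_features, test_features):
    """Fit one feature generator and apply it to both feature tables.

    Both arguments are DataFrames holding only the shared feature columns
    (label columns removed).
    """
    # Imported lazily so this sketch can be loaded without AutoGluon installed.
    from autogluon.features.generators import AutoMLPipelineFeatureGenerator

    feature_generator = AutoMLPipelineFeatureGenerator()
    train_processed = feature_generator.fit_transform(X=train_features)
    test_processed = feature_generator.transform(test_features)
    return feature_generator, train_processed, test_processed


# Hypothetical usage for several tasks that share features but differ in label:
# generator, X_train, X_test = preprocess_once(train_features, test_features)
# for task_label in task_labels:  # assumed list of label column names
#     task_train = X_train.assign(**{task_label: train_labels[task_label]})
#     predictor = TabularPredictor(label=task_label).fit(
#         task_train, feature_generator=None,
#         feature_metadata=generator.feature_metadata)
#     predictions = predictor.predict(X_test)
```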

If you encounter memory issues

To reduce memory usage during training, you may try the following strategies individually or in combination (these may harm accuracy):

  • In fit(), set excluded_model_types = ['KNN', 'XT', 'RF'] (or some subset of these models).

  • Try different presets in fit().

  • In fit(), set hyperparameters = 'light' or hyperparameters = 'very_light'.

  • Text fields in your table require substantial memory for N-gram featurization. To mitigate this in fit(), you can either: (1) add 'ignore_text' to your presets list (to ignore text features), or (2) specify the argument:

import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from autogluon.features.generators import AutoMLPipelineFeatureGenerator

MAX_NGRAM = 1000  # upper bound on the number of N-gram features kept
feature_generator = AutoMLPipelineFeatureGenerator(vectorizer=CountVectorizer(min_df=30, ngram_range=(1, 3), max_features=MAX_NGRAM, dtype=np.uint8))

for example, with MAX_NGRAM = 1000 (try various values under 10000 to reduce the number of N-gram features used to represent each text field).

In addition to reducing memory usage, many of the above strategies can also be used to reduce training times.
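Several of these options can be combined in one fit() call. The sketch below collects them in a dict; the particular combination and the name low_memory_fit_kwargs are assumptions for illustration, and each option can be dropped independently if accuracy suffers too much.

```python
# A combination of the memory-saving options listed above (may reduce accuracy).
low_memory_fit_kwargs = {
    'excluded_model_types': ['KNN', 'XT', 'RF'],
    'hyperparameters': 'light',
    'presets': ['medium_quality', 'ignore_text'],
}

# Usage, assuming train_data, label and metric are defined as earlier:
# from autogluon.tabular import TabularPredictor
# predictor_low_mem = TabularPredictor(label=label, eval_metric=metric).fit(
#     train_data, **low_memory_fit_kwargs)
```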

To reduce memory usage during inference:

  • If you are producing predictions for a large test dataset, break the test data into smaller chunks as demonstrated in the FAQ.

  • If models have been previously persisted in memory but inference speed is not a major concern, call predictor.unpersist().

  • If models have been previously persisted in memory, bagging was used in fit(), and inference speed is a concern: call predictor.refit_full() and use one of the refit-full models for prediction (ensure this is the only model persisted in memory).
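The chunking strategy in the first bullet can be sketched as follows. This is not the FAQ's code; the helper names chunk_bounds and predict_in_chunks and the default chunk size are assumptions, and the data argument is expected to be a pandas DataFrame (or anything else that supports len() and .iloc slicing).

```python
def chunk_bounds(n_rows, chunk_size):
    """Yield (start, stop) row positions that cover n_rows in chunks of chunk_size."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    for start in range(0, n_rows, chunk_size):
        yield start, min(start + chunk_size, n_rows)


def predict_in_chunks(predict_fn, data, chunk_size=10000):
    """Call predict_fn on consecutive row slices of data; return the results as a list."""
    return [predict_fn(data.iloc[start:stop])
            for start, stop in chunk_bounds(len(data), chunk_size)]


# Usage, assuming predictor and test_data_nolabel exist as earlier in this tutorial:
# import pandas as pd
# chunks = predict_in_chunks(predictor.predict, test_data_nolabel, chunk_size=1000)
# predictions = pd.concat(chunks)
```

Only one chunk of preprocessed data is held in memory at a time, at the cost of some per-call overhead.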

If you encounter disk space issues

To reduce disk usage, you may try the following strategies individually or in combination:

  • Make sure to delete all predictor.path folders from previous fit() runs! These can eat up your free space if you call fit() many times. If you didn’t specify path, AutoGluon still automatically saved its models to a folder called: “AutogluonModels/ag-[TIMESTAMP]”, where TIMESTAMP records when fit() was called, so make sure to also delete these folders if you run low on free space.

  • Call predictor.save_space() to delete auxiliary files produced during fit().

  • Call predictor.delete_models(models_to_keep='best', dry_run=False) if you only intend to use this predictor for inference going forward (will delete files required for non-prediction-related functionality like fit_summary).

  • In fit(), you can add 'optimize_for_deployment' to the presets list, which will automatically invoke the previous two strategies after training.

  • Most of the above strategies to reduce memory usage will also reduce disk usage (but may harm accuracy).
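To see which earlier runs take up the most space before deleting anything, a small standard-library helper can list the saved folders by size. This is a minimal sketch; the function names are hypothetical, and it assumes the default "AutogluonModels" location with "ag-" prefixed run folders described above.

```python
from pathlib import Path


def directory_size_mb(path):
    """Return the total size in MB of all files under path."""
    total_bytes = sum(f.stat().st_size for f in Path(path).rglob('*') if f.is_file())
    return total_bytes / (1024 ** 2)


def list_run_folders(root='AutogluonModels'):
    """Return (folder, size_mb) pairs for each ag-* folder under root, largest first."""
    root_path = Path(root)
    if not root_path.is_dir():
        return []
    folders = [p for p in root_path.glob('ag-*') if p.is_dir()]
    return sorted(((p, directory_size_mb(p)) for p in folders),
                  key=lambda item: item[1], reverse=True)


# Usage:
# for folder, size_mb in list_run_folders():
#     print(f"{folder}: {size_mb:.1f} MB")
# Remove a folder you no longer need with shutil.rmtree(folder).
```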

References

The following paper describes how AutoGluon internally operates on tabular data:

Erickson et al. AutoGluon-Tabular: Robust and Accurate AutoML for Structured Data. arXiv, 2020.

Next Steps

If you are interested in deployment optimization, refer to the Predicting Columns in a Table - Deployment Optimization tutorial.