AutoGluon Tabular - In Depth


Tip: If you are new to AutoGluon, review Predicting Columns in a Table - Quick Start to learn the basics of the AutoGluon API. To learn how to add your own custom models to the set that AutoGluon trains, tunes, and ensembles, review Adding a custom model to AutoGluon.

This tutorial describes how you can exert greater control when using AutoGluon’s fit() or predict(). Recall that to maximize predictive performance, you should first try TabularPredictor() and fit() with all default arguments. Then, consider non-default arguments for TabularPredictor(eval_metric=...) and fit(presets=...). Later, you can experiment with other fit() arguments covered in this in-depth tutorial, such as hyperparameters, num_stack_levels, num_bag_folds, and num_bag_sets.

Start by importing AutoGluon’s TabularPredictor and TabularDataset, and loading the data.

from autogluon.tabular import TabularDataset, TabularPredictor

train_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')
subsample_size = 1000  # subsample subset of data for faster demo, try setting this to much larger values
train_data = train_data.sample(n=subsample_size, random_state=0)
test_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')

label = 'class'  # Now let's predict the "class" column (binary classification)
metric = "balanced_accuracy"  # Specify the evaluation metric you want to optimize
X_test = test_data.drop(columns=[label])
y_test = test_data[label]

train_data.head()
age workclass fnlwgt education education-num marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country class
6118 51 Private 39264 Some-college 10 Married-civ-spouse Exec-managerial Wife White Female 0 0 40 United-States >50K
23204 58 Private 51662 10th 6 Married-civ-spouse Other-service Wife White Female 0 0 8 United-States <=50K
29590 40 Private 326310 Some-college 10 Married-civ-spouse Craft-repair Husband White Male 0 0 44 United-States <=50K
18116 37 Private 222450 HS-grad 9 Never-married Sales Not-in-family White Male 0 2339 40 El-Salvador <=50K
33964 62 Private 109190 Bachelors 13 Married-civ-spouse Exec-managerial Husband White Male 15024 0 40 United-States >50K

Model ensembling with stacking/bagging

Two methods to boost predictive performance are bagging and stack-ensembling. You’ll often see performance improve if you specify num_bag_folds = 8, num_stack_levels = 1 in the call to fit(), but this will increase training times and memory/disk usage.

save_path = 'agModels-predictClass'  # folder where to store trained models

predictor = TabularPredictor(label=label, eval_metric=metric, path=save_path).fit(train_data,
    calibrate_decision_threshold=False,  # Disabled here for demonstration in the next section
    num_bag_folds=8, num_bag_sets=1, num_stack_levels=1,  # or simply set `auto_stack=True` or a preset such as `presets="best"`
    time_limit=180,  # a brief 3-minute time limit for demonstration
)
predictor.leaderboard(test_data)
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version:  1.5.0b20251219
Python Version:     3.12.10
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.9.1+cu128
CUDA Version:       12.8
GPU Memory:         GPU 0: 14.57/14.57 GB
Total GPU Memory:   Free: 14.57 GB, Allocated: 0.00 GB, Total: 14.57 GB
GPU Count:          1
Memory Avail:       28.49 GB / 30.95 GB (92.1%)
Disk Space Avail:   202.55 GB / 255.99 GB (79.1%)
===================================================
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets. Defaulting to `'medium'`...
	Recommended Presets (For more details refer to https://auto.gluon.ai/stable/tutorials/tabular/tabular-essentials.html#presets):
	presets='extreme'  : New in v1.5: The state-of-the-art for tabular data. Massively better than 'best' on datasets <100000 samples by using new Tabular Foundation Models (TFMs) meta-learned on https://tabarena.ai: TabPFNv2, TabICL, Mitra, TabDPT, and TabM. Absolute best accuracy. Requires a GPU. Recommended 64 GB CPU memory and 32+ GB GPU memory.
	presets='best'     : Maximize accuracy. Recommended for most users. Use in competitions and benchmarks.
	presets='best_v150': New in v1.5: Better quality than 'best' and 5x+ faster to train. Give it a try!
	presets='high'     : Strong accuracy with fast inference speed.
	presets='high_v150': New in v1.5: Better quality than 'high' and 5x+ faster to train. Give it a try!
	presets='good'     : Good accuracy with very fast inference speed.
	presets='medium'   : Fast training time, ideal for initial prototyping.
Using hyperparameters preset: hyperparameters='default'
Beginning AutoGluon training ... Time limit = 180s
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass"
Train Data Rows:    1000
Train Data Columns: 14
Label Column:       class
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
	2 unique label values:  [' >50K', ' <=50K']
	If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
Problem Type:       binary
Preprocessing data ...
Selected class <--> label mapping:  class 1 =  >50K, class 0 =  <=50K
	Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
	To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
	Available Memory:                    29151.89 MB
	Train Data (Original)  Memory Usage: 0.50 MB (0.0% of available memory)
	Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
	Stage 1 Generators:
		Fitting AsTypeFeatureGenerator...
			Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
	Stage 2 Generators:
		Fitting FillNaFeatureGenerator...
	Stage 3 Generators:
		Fitting IdentityFeatureGenerator...
		Fitting CategoryFeatureGenerator...
			Fitting CategoryMemoryMinimizeFeatureGenerator...
	Stage 4 Generators:
		Fitting DropUniqueFeatureGenerator...
	Stage 5 Generators:
		Fitting DropDuplicatesFeatureGenerator...
	Types of features in original data (raw dtype, special dtypes):
		('int', [])    : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
		('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
	Types of features in processed data (raw dtype, special dtypes):
		('category', [])  : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
		('int', [])       : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
		('int', ['bool']) : 1 | ['sex']
	0.1s = Fit runtime
	14 features in original data used to generate 14 features in processed data.
	Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.09s ...
AutoGluon will gauge predictive performance using evaluation metric: 'balanced_accuracy'
	To change this, specify the eval_metric parameter of Predictor()
User-specified model hyperparameters to be fit:
{
	'NN_TORCH': [{}],
	'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, {'learning_rate': 0.03, 'num_leaves': 128, 'feature_fraction': 0.9, 'min_data_in_leaf': 3, 'ag_args': {'name_suffix': 'Large', 'priority': 0, 'hyperparameter_tune_kwargs': None}}],
	'CAT': [{}],
	'XGB': [{}],
	'FASTAI': [{}],
	'RF': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],
	'XT': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],
}
AutoGluon will fit 2 stack levels (L1 to L2) ...
Fitting 11 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 119.91s of the 179.90s of remaining time.
	Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.06%)
/home/ci/opt/venv/lib/python3.12/site-packages/ray/_private/worker.py:2062: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0
  warnings.warn(
	0.7788	 = Validation score   (balanced_accuracy)
	1.63s	 = Training   runtime
	0.06s	 = Validation runtime
Fitting model: LightGBM_BAG_L1 ... Training model for up to 111.69s of the 171.69s of remaining time.
	Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.07%)
	0.7797	 = Validation score   (balanced_accuracy)
	1.78s	 = Training   runtime
	0.05s	 = Validation runtime
Fitting model: RandomForestGini_BAG_L1 ... Training model for up to 106.35s of the 166.35s of remaining time.
	Fitting 1 model on all data (use_child_oof=True) | Fitting with cpus=8, gpus=0
	0.7482	 = Validation score   (balanced_accuracy)
	0.71s	 = Training   runtime
	0.12s	 = Validation runtime
Fitting model: RandomForestEntr_BAG_L1 ... Training model for up to 105.49s of the 165.49s of remaining time.
	Fitting 1 model on all data (use_child_oof=True) | Fitting with cpus=8, gpus=0
	0.7567	 = Validation score   (balanced_accuracy)
	0.64s	 = Training   runtime
	0.12s	 = Validation runtime
Fitting model: CatBoost_BAG_L1 ... Training model for up to 104.71s of the 164.71s of remaining time.
	Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=1.40%)
	0.7687	 = Validation score   (balanced_accuracy)
	7.08s	 = Training   runtime
	0.04s	 = Validation runtime
Fitting model: ExtraTreesGini_BAG_L1 ... Training model for up to 94.49s of the 154.49s of remaining time.
	Fitting 1 model on all data (use_child_oof=True) | Fitting with cpus=8, gpus=0
	0.7515	 = Validation score   (balanced_accuracy)
	0.65s	 = Training   runtime
	0.12s	 = Validation runtime
Fitting model: ExtraTreesEntr_BAG_L1 ... Training model for up to 93.69s of the 153.69s of remaining time.
	Fitting 1 model on all data (use_child_oof=True) | Fitting with cpus=8, gpus=0
	0.7306	 = Validation score   (balanced_accuracy)
	0.64s	 = Training   runtime
	0.12s	 = Validation runtime
Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 92.90s of the 152.89s of remaining time.
	Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.00%)
	0.7251	 = Validation score   (balanced_accuracy)
	5.79s	 = Training   runtime
	0.11s	 = Validation runtime
Fitting model: XGBoost_BAG_L1 ... Training model for up to 83.85s of the 143.85s of remaining time.
	Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.05%)
	0.7862	 = Validation score   (balanced_accuracy)
	1.58s	 = Training   runtime
	0.08s	 = Validation runtime
Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 78.03s of the 138.02s of remaining time.
	Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.00%)
	0.8095	 = Validation score   (balanced_accuracy)
	14.54s	 = Training   runtime
	0.12s	 = Validation runtime
Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 60.11s of the 120.11s of remaining time.
	Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.26%)
	0.7473	 = Validation score   (balanced_accuracy)
	2.59s	 = Training   runtime
	0.06s	 = Validation runtime
Fitting model: WeightedEnsemble_L2 ... Training model for up to 179.91s of the 114.05s of remaining time.
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/26.3 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 1.0}
	0.8095	 = Validation score   (balanced_accuracy)
	0.24s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting 11 L2 models, fit_strategy="sequential" ...
Fitting model: LightGBMXT_BAG_L2 ... Training model for up to 113.80s of the 113.79s of remaining time.
	Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.07%)
	0.7862	 = Validation score   (balanced_accuracy)
	1.63s	 = Training   runtime
	0.06s	 = Validation runtime
Fitting model: LightGBM_BAG_L2 ... Training model for up to 108.80s of the 108.80s of remaining time.
	Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.07%)
	0.8017	 = Validation score   (balanced_accuracy)
	2.48s	 = Training   runtime
	0.06s	 = Validation runtime
Fitting model: RandomForestGini_BAG_L2 ... Training model for up to 102.73s of the 102.72s of remaining time.
	Fitting 1 model on all data (use_child_oof=True) | Fitting with cpus=8, gpus=0
	0.766	 = Validation score   (balanced_accuracy)
	0.66s	 = Training   runtime
	0.12s	 = Validation runtime
Fitting model: RandomForestEntr_BAG_L2 ... Training model for up to 101.92s of the 101.91s of remaining time.
	Fitting 1 model on all data (use_child_oof=True) | Fitting with cpus=8, gpus=0
	0.7763	 = Validation score   (balanced_accuracy)
	0.65s	 = Training   runtime
	0.12s	 = Validation runtime
Fitting model: CatBoost_BAG_L2 ... Training model for up to 101.13s of the 101.12s of remaining time.
	Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=1.44%)
	0.7939	 = Validation score   (balanced_accuracy)
	6.79s	 = Training   runtime
	0.05s	 = Validation runtime
Fitting model: ExtraTreesGini_BAG_L2 ... Training model for up to 91.10s of the 91.09s of remaining time.
	Fitting 1 model on all data (use_child_oof=True) | Fitting with cpus=8, gpus=0
	0.7483	 = Validation score   (balanced_accuracy)
	0.66s	 = Training   runtime
	0.12s	 = Validation runtime
Fitting model: ExtraTreesEntr_BAG_L2 ... Training model for up to 90.29s of the 90.28s of remaining time.
	Fitting 1 model on all data (use_child_oof=True) | Fitting with cpus=8, gpus=0
	0.7588	 = Validation score   (balanced_accuracy)
	0.63s	 = Training   runtime
	0.13s	 = Validation runtime
Fitting model: NeuralNetFastAI_BAG_L2 ... Training model for up to 89.50s of the 89.49s of remaining time.
	Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.00%)
	0.7489	 = Validation score   (balanced_accuracy)
	5.7s	 = Training   runtime
	0.13s	 = Validation runtime
Fitting model: XGBoost_BAG_L2 ... Training model for up to 80.62s of the 80.61s of remaining time.
	Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.06%)
	0.7971	 = Validation score   (balanced_accuracy)
	2.57s	 = Training   runtime
	0.08s	 = Validation runtime
Fitting model: NeuralNetTorch_BAG_L2 ... Training model for up to 74.10s of the 74.10s of remaining time.
	Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.00%)
	0.8126	 = Validation score   (balanced_accuracy)
	26.72s	 = Training   runtime
	0.16s	 = Validation runtime
Fitting model: LightGBMLarge_BAG_L2 ... Training model for up to 44.07s of the 44.06s of remaining time.
	Fitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy (8 workers, per: cpus=1, gpus=0, memory=0.29%)
	0.7834	 = Validation score   (balanced_accuracy)
	4.12s	 = Training   runtime
	0.07s	 = Validation runtime
Fitting model: WeightedEnsemble_L3 ... Training model for up to 179.91s of the 36.57s of remaining time.
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/26.9 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L2': 0.538, 'XGBoost_BAG_L2': 0.308, 'NeuralNetTorch_BAG_L1': 0.077, 'ExtraTreesEntr_BAG_L2': 0.077}
	0.8151	 = Validation score   (balanced_accuracy)
	0.25s	 = Training   runtime
	0.0s	 = Validation runtime
AutoGluon training complete, total runtime = 143.7s ... Best model: WeightedEnsemble_L3 | Estimated inference throughput: 160.7 rows/s (125 batch size)
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass")
model score_test score_val eval_metric pred_time_test pred_time_val fit_time pred_time_test_marginal pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order
0 NeuralNetTorch_BAG_L2 0.786935 0.812563 balanced_accuracy 4.646739 1.109388 61.752962 0.640417 0.161899 26.718968 2 True 22
1 WeightedEnsemble_L3 0.775800 0.815083 balanced_accuracy 5.033798 1.314348 65.195856 0.002993 0.000917 0.248854 3 True 24
2 NeuralNetTorch_BAG_L1 0.767081 0.809518 balanced_accuracy 0.450947 0.120358 14.543094 0.450947 0.120358 14.543094 1 True 10
3 WeightedEnsemble_L2 0.767081 0.809518 balanced_accuracy 0.453117 0.121753 14.781031 0.002170 0.001396 0.237937 2 True 12
4 CatBoost_BAG_L2 0.759693 0.793914 balanced_accuracy 4.054651 0.996059 41.823847 0.048330 0.048570 6.789853 2 True 17
5 LightGBMXT_BAG_L1 0.758639 0.778835 balanced_accuracy 0.524297 0.060408 1.634013 0.524297 0.060408 1.634013 1 True 1
6 CatBoost_BAG_L1 0.755940 0.768712 balanced_accuracy 0.096915 0.042109 7.078955 0.096915 0.042109 7.078955 1 True 5
7 XGBoost_BAG_L2 0.754565 0.797148 balanced_accuracy 4.276884 1.023809 37.601204 0.270563 0.076320 2.567210 2 True 21
8 LightGBM_BAG_L2 0.751866 0.801663 balanced_accuracy 4.272264 1.002546 37.509482 0.265942 0.055056 2.475489 2 True 14
9 ExtraTreesEntr_BAG_L2 0.749119 0.758758 balanced_accuracy 4.119826 1.075213 35.660824 0.113504 0.127723 0.626830 2 True 19
10 LightGBMXT_BAG_L2 0.748702 0.786227 balanced_accuracy 4.246110 1.005201 36.667126 0.239789 0.057711 1.633132 2 True 13
11 ExtraTreesGini_BAG_L2 0.748127 0.748257 balanced_accuracy 4.123268 1.072443 35.692085 0.116947 0.124954 0.658091 2 True 18
12 XGBoost_BAG_L1 0.747715 0.786227 balanced_accuracy 0.484042 0.077999 1.575752 0.484042 0.077999 1.575752 1 True 9
13 RandomForestEntr_BAG_L2 0.746406 0.776336 balanced_accuracy 4.116608 1.068860 35.682957 0.110287 0.121370 0.648963 2 True 16
14 RandomForestGini_BAG_L2 0.746051 0.765961 balanced_accuracy 4.118439 1.069075 35.697951 0.112118 0.121586 0.663958 2 True 15
15 LightGBMLarge_BAG_L2 0.745884 0.783413 balanced_accuracy 4.497786 1.017043 39.151877 0.491464 0.069554 4.117883 2 True 23
16 RandomForestGini_BAG_L1 0.745745 0.748194 balanced_accuracy 0.118708 0.121088 0.708103 0.118708 0.121088 0.708103 1 True 3
17 LightGBM_BAG_L1 0.745232 0.779696 balanced_accuracy 0.297431 0.053448 1.778871 0.297431 0.053448 1.778871 1 True 2
18 RandomForestEntr_BAG_L1 0.743132 0.756678 balanced_accuracy 0.100914 0.121509 0.635930 0.100914 0.121509 0.635930 1 True 4
19 NeuralNetFastAI_BAG_L2 0.741699 0.748866 balanced_accuracy 5.185719 1.076216 40.729774 1.179397 0.128727 5.695780 2 True 20
20 LightGBMLarge_BAG_L1 0.735861 0.747333 balanced_accuracy 0.472554 0.060283 2.592734 0.472554 0.060283 2.592734 1 True 11
21 ExtraTreesGini_BAG_L1 0.731460 0.751491 balanced_accuracy 0.118771 0.120046 0.646260 0.118771 0.120046 0.646260 1 True 6
22 NeuralNetFastAI_BAG_L1 0.728689 0.725071 balanced_accuracy 1.701302 0.108622 5.791547 1.701302 0.108622 5.791547 1 True 8
23 ExtraTreesEntr_BAG_L1 0.725517 0.730553 balanced_accuracy 0.112994 0.121904 0.641470 0.112994 0.121904 0.641470 1 True 7

You should not provide tuning_data when stacking/bagging; instead, provide all of your available data as train_data (which AutoGluon will split in more intelligent ways). num_bag_sets controls how many times the k-fold bagging process is repeated to further reduce variance (increasing this may further boost accuracy, but will substantially increase training time, inference latency, and memory/disk usage). Rather than manually searching for good bagging/stacking values yourself, specify auto_stack instead (which is used in the best_quality preset) and AutoGluon will automatically select good values for you.
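A minimal sketch of auto_stack, reusing label, metric, and train_data from above (the time_limit is illustrative):

predictor_auto = TabularPredictor(label=label, eval_metric=metric).fit(
    train_data,
    auto_stack=True,  # AutoGluon selects num_bag_folds / num_bag_sets / num_stack_levels automatically
    time_limit=180,  # illustrative; longer limits generally yield better models
)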

Decision Threshold Calibration

In binary classification, major score improvements can be achieved for metrics such as "f1" and "balanced_accuracy" by adjusting the prediction decision threshold away from the default of 0.5 via calibrate_decision_threshold.

Below is an example of the "balanced_accuracy" score achieved on the test data with and without calibrating the decision threshold:

print(f'Prior to calibration (predictor.decision_threshold={predictor.decision_threshold}):')
scores = predictor.evaluate(test_data)

calibrated_decision_threshold = predictor.calibrate_decision_threshold()
predictor.set_decision_threshold(calibrated_decision_threshold)

print(f'After calibration (predictor.decision_threshold={predictor.decision_threshold}):')
scores_calibrated = predictor.evaluate(test_data)
Prior to calibration (predictor.decision_threshold=0.5):
After calibration (predictor.decision_threshold=0.381):
Calibrating decision threshold to optimize metric balanced_accuracy | Checking 51 thresholds...
Calibrating decision threshold via fine-grained search | Checking 38 thresholds...
	Base Threshold: 0.500	| val: 0.8151
	Best Threshold: 0.381	| val: 0.8260
Updating predictor.decision_threshold from 0.5 -> 0.381
	This will impact how prediction probabilities are converted to predictions in binary classification.
	Prediction probabilities of the positive class >0.381 will be predicted as the positive class ( >50K). This can significantly impact metric scores.
	You can update this value via `predictor.set_decision_threshold`.
	You can calculate an optimal decision threshold on the validation data via `predictor.calibrate_decision_threshold()`.
for metric_name in scores:
    metric_score = scores[metric_name]
    metric_score_calibrated = scores_calibrated[metric_name]
    decision_threshold = predictor.decision_threshold
    print(f'decision_threshold={decision_threshold:.3f}\t| metric="{metric_name}"'
          f'\n\ttest_score uncalibrated: {metric_score:.4f}'
          f'\n\ttest_score   calibrated: {metric_score_calibrated:.4f}'
          f'\n\ttest_score        delta: {metric_score_calibrated-metric_score:.4f}')
decision_threshold=0.381	| metric="balanced_accuracy"
	test_score uncalibrated: 0.7758
	test_score   calibrated: 0.7987
	test_score        delta: 0.0229
decision_threshold=0.381	| metric="accuracy"
	test_score uncalibrated: 0.8504
	test_score   calibrated: 0.8373
	test_score        delta: -0.0131
decision_threshold=0.381	| metric="mcc"
	test_score uncalibrated: 0.5731
	test_score   calibrated: 0.5728
	test_score        delta: -0.0003
decision_threshold=0.381	| metric="roc_auc"
	test_score uncalibrated: 0.9037
	test_score   calibrated: 0.9037
	test_score        delta: 0.0000
decision_threshold=0.381	| metric="f1"
	test_score uncalibrated: 0.6679
	test_score   calibrated: 0.6791
	test_score        delta: 0.0112
decision_threshold=0.381	| metric="precision"
	test_score uncalibrated: 0.7059
	test_score   calibrated: 0.6384
	test_score        delta: -0.0675
decision_threshold=0.381	| metric="recall"
	test_score uncalibrated: 0.6337
	test_score   calibrated: 0.7252
	test_score        delta: 0.0915

Notice that calibrating for “balanced_accuracy” substantially improved the “balanced_accuracy” metric score, but it harmed the “accuracy” score. Threshold calibration often trades off performance on one metric against another, and the user should keep this in mind.

Instead of calibrating for “balanced_accuracy” specifically, we can calibrate the threshold for any metric whose score we want to maximize:

predictor.set_decision_threshold(0.5)  # Reset decision threshold
for metric_name in ['f1', 'balanced_accuracy', 'mcc']:
    metric_score = predictor.evaluate(test_data, silent=True)[metric_name]
    calibrated_decision_threshold = predictor.calibrate_decision_threshold(metric=metric_name, verbose=False)
    metric_score_calibrated = predictor.evaluate(
        test_data, decision_threshold=calibrated_decision_threshold, silent=True
    )[metric_name]
    print(f'decision_threshold={calibrated_decision_threshold:.3f}\t| metric="{metric_name}"'
          f'\n\ttest_score uncalibrated: {metric_score:.4f}'
          f'\n\ttest_score   calibrated: {metric_score_calibrated:.4f}'
          f'\n\ttest_score        delta: {metric_score_calibrated-metric_score:.4f}')
decision_threshold=0.500	| metric="f1"
	test_score uncalibrated: 0.6679
	test_score   calibrated: 0.6679
	test_score        delta: 0.0000
decision_threshold=0.381	| metric="balanced_accuracy"
	test_score uncalibrated: 0.7758
	test_score   calibrated: 0.7987
	test_score        delta: 0.0229
decision_threshold=0.500	| metric="mcc"
	test_score uncalibrated: 0.5731
	test_score   calibrated: 0.5731
	test_score        delta: 0.0000
Updating predictor.decision_threshold from 0.381 -> 0.5
	This will impact how prediction probabilities are converted to predictions in binary classification.
	Prediction probabilities of the positive class >0.5 will be predicted as the positive class ( >50K). This can significantly impact metric scores.
	You can update this value via `predictor.set_decision_threshold`.
	You can calculate an optimal decision threshold on the validation data via `predictor.calibrate_decision_threshold()`.

Instead of calibrating the decision threshold post-fit, you can have it occur automatically during fit by specifying predictor.fit(..., calibrate_decision_threshold=True).
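A minimal sketch of fit-time calibration, reusing label, metric, and train_data from above (the short time_limit is illustrative):

predictor_calibrated = TabularPredictor(label=label, eval_metric=metric).fit(
    train_data,
    calibrate_decision_threshold=True,  # calibrate the threshold on validation data at the end of fit
    time_limit=60,  # illustrative
)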

In fact, AutoGluon applies decision threshold calibration automatically when it is beneficial, since the default value is calibrate_decision_threshold="auto". We recommend keeping this default in most cases.

Additional usage examples are below:

# Will use the decision_threshold specified in `predictor.decision_threshold`, can be set via `predictor.set_decision_threshold`
y_pred = predictor.predict(test_data)
y_pred_08 = predictor.predict(test_data, decision_threshold=0.8)  # Specify a specific threshold to use only for this predict

y_pred_proba = predictor.predict_proba(test_data)
y_pred = predictor.predict_from_proba(y_pred_proba)  # Identical output to calling .predict(test_data)
y_pred_08 = predictor.predict_from_proba(y_pred_proba, decision_threshold=0.8)  # Identical output to calling .predict(test_data, decision_threshold=0.8)

Prediction options (inference)

Even if you’ve started a new Python session since last calling fit(), you can still load a previously trained predictor from disk:

predictor = TabularPredictor.load(save_path)  # `predictor.path` is another way to get the relative path needed to later load the predictor.

Above, save_path is the same folder previously passed to TabularPredictor, in which all of the trained models were saved. You can easily train models on one machine and deploy them on another: simply copy the save_path folder to the new machine and specify its new path in TabularPredictor.load().
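A minimal deployment sketch, where /tmp/agModels-deployed is a hypothetical destination standing in for the path on the new machine:

import shutil

from autogluon.tabular import TabularPredictor

shutil.copytree(save_path, '/tmp/agModels-deployed')  # stand-in for copying the folder to the new machine
predictor_deployed = TabularPredictor.load('/tmp/agModels-deployed')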

To find out the required feature columns to make predictions, call predictor.features():

predictor.features()
['age',
 'workclass',
 'fnlwgt',
 'education',
 'education-num',
 'marital-status',
 'occupation',
 'relationship',
 'race',
 'sex',
 'capital-gain',
 'capital-loss',
 'hours-per-week',
 'native-country']

We can make a prediction on an individual example rather than a full dataset:

datapoint = X_test.iloc[[0]]  # Note: .iloc[0] won't work because it returns a pandas Series instead of a DataFrame
datapoint
age workclass fnlwgt education education-num marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country
0 31 Private 169085 11th 7 Married-civ-spouse Sales Wife White Female 0 0 20 United-States
predictor.predict(datapoint)
0     <=50K
Name: class, dtype: object

To output predicted class probabilities instead of predicted classes, you can use:

predictor.predict_proba(datapoint)  # returns a DataFrame that shows which probability corresponds to which class
<=50K >50K
0 0.948482 0.051518

By default, predict() and predict_proba() will utilize the model that AutoGluon thinks is most accurate, which is usually an ensemble of many individual models. Here’s how to see which model this is:

predictor.model_best
'WeightedEnsemble_L3'

We can instead specify a particular model to use for predictions (e.g. to reduce inference latency). Note that a ‘model’ in AutoGluon may refer to, for example, a single Neural Network, a bagged ensemble of many Neural Network copies trained on different training/validation splits, a weighted ensemble that aggregates the predictions of many other models, or a stacker model that operates on predictions output by other models. This is akin to viewing a Random Forest as one ‘model’ when it is in fact an ensemble of many decision trees.
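As a minimal sketch of the syntax, assuming 'NeuralNetTorch_BAG_L1' appears among the names returned by predictor.model_names():

y_pred_model = predictor.predict(datapoint, model='NeuralNetTorch_BAG_L1')  # bypasses the best model/ensemble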

Before deciding which model to use, let’s evaluate all of the models AutoGluon has previously trained on our test data:

predictor.leaderboard(test_data)
model score_test score_val eval_metric pred_time_test pred_time_val fit_time pred_time_test_marginal pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order
0 NeuralNetTorch_BAG_L2 0.786935 0.812563 balanced_accuracy 3.392177 1.109388 61.752962 0.591998 0.161899 26.718968 2 True 22
1 WeightedEnsemble_L3 0.775800 0.815083 balanced_accuracy 3.742407 1.314348 65.195856 0.003925 0.000917 0.248854 3 True 24
2 NeuralNetTorch_BAG_L1 0.767081 0.809518 balanced_accuracy 0.397695 0.120358 14.543094 0.397695 0.120358 14.543094 1 True 10
3 WeightedEnsemble_L2 0.767081 0.809518 balanced_accuracy 0.399440 0.121753 14.781031 0.001745 0.001396 0.237937 2 True 12
4 CatBoost_BAG_L2 0.759693 0.793914 balanced_accuracy 2.843507 0.996059 41.823847 0.043328 0.048570 6.789853 2 True 17
5 LightGBMXT_BAG_L1 0.758639 0.778835 balanced_accuracy 0.365914 0.060408 1.634013 0.365914 0.060408 1.634013 1 True 1
6 CatBoost_BAG_L1 0.755940 0.768712 balanced_accuracy 0.048838 0.042109 7.078955 0.048838 0.042109 7.078955 1 True 5
7 XGBoost_BAG_L2 0.754565 0.797148 balanced_accuracy 3.043562 1.023809 37.601204 0.243383 0.076320 2.567210 2 True 21
8 LightGBM_BAG_L2 0.751866 0.801663 balanced_accuracy 3.023854 1.002546 37.509482 0.223675 0.055056 2.475489 2 True 14
9 ExtraTreesEntr_BAG_L2 0.749119 0.758758 balanced_accuracy 2.903102 1.075213 35.660824 0.102923 0.127723 0.626830 2 True 19
10 LightGBMXT_BAG_L2 0.748702 0.786227 balanced_accuracy 2.999880 1.005201 36.667126 0.199700 0.057711 1.633132 2 True 13
11 ExtraTreesGini_BAG_L2 0.748127 0.748257 balanced_accuracy 2.903032 1.072443 35.692085 0.102853 0.124954 0.658091 2 True 18
12 XGBoost_BAG_L1 0.747715 0.786227 balanced_accuracy 0.266736 0.077999 1.575752 0.266736 0.077999 1.575752 1 True 9
13 RandomForestEntr_BAG_L2 0.746406 0.776336 balanced_accuracy 2.900270 1.068860 35.682957 0.100091 0.121370 0.648963 2 True 16
14 RandomForestGini_BAG_L2 0.746051 0.765961 balanced_accuracy 2.900634 1.069075 35.697951 0.100455 0.121586 0.663958 2 True 15
15 LightGBMLarge_BAG_L2 0.745884 0.783413 balanced_accuracy 3.203288 1.017043 39.151877 0.403109 0.069554 4.117883 2 True 23
16 RandomForestGini_BAG_L1 0.745745 0.748194 balanced_accuracy 0.100837 0.121088 0.708103 0.100837 0.121088 0.708103 1 True 3
17 LightGBM_BAG_L1 0.745232 0.779696 balanced_accuracy 0.274780 0.053448 1.778871 0.274780 0.053448 1.778871 1 True 2
18 RandomForestEntr_BAG_L1 0.743132 0.756678 balanced_accuracy 0.099790 0.121509 0.635930 0.099790 0.121509 0.635930 1 True 4
19 NeuralNetFastAI_BAG_L2 0.741699 0.748866 balanced_accuracy 3.889222 1.076216 40.729774 1.089043 0.128727 5.695780 2 True 20
20 LightGBMLarge_BAG_L1 0.735861 0.747333 balanced_accuracy 0.369758 0.060283 2.592734 0.369758 0.060283 2.592734 1 True 11
21 ExtraTreesGini_BAG_L1 0.731460 0.751491 balanced_accuracy 0.103992 0.120046 0.646260 0.103992 0.120046 0.646260 1 True 6
22 NeuralNetFastAI_BAG_L1 0.728689 0.725071 balanced_accuracy 1.037741 0.108622 5.791547 1.037741 0.108622 5.791547 1 True 8
23 ExtraTreesEntr_BAG_L1 0.725517 0.730553 balanced_accuracy 0.103856 0.121904 0.641470 0.103856 0.121904 0.641470 1 True 7

The leaderboard shows each model’s predictive performance on the test data (score_test) and validation data (score_val), as well as the time required to: produce predictions on the test data (pred_time_test), produce predictions on the validation data (pred_time_val), and train only this model (fit_time). Below, we show that a leaderboard can also be produced without new data (it simply uses the data previously reserved for validation inside fit) and can display extra information about each model:

predictor.leaderboard(extra_info=True)
model score_val eval_metric pred_time_val fit_time pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order ... hyperparameters hyperparameters_fit ag_args_fit features compile_time child_hyperparameters child_hyperparameters_fit child_ag_args_fit ancestors descendants
0 WeightedEnsemble_L3 0.815083 balanced_accuracy 1.314348 65.195856 0.000917 0.248854 3 True 24 ... {'use_orig_features': False, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [NeuralNetTorch_BAG_L1, XGBoost_BAG_L2, NeuralNetTorch_BAG_L2, ExtraTreesEntr_BAG_L2] None {'ensemble_size': 25, 'subsample_size': 1000000} {'ensemble_size': 13} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [CatBoost_BAG_L1, ExtraTreesEntr_BAG_L2, NeuralNetTorch_BAG_L1, ExtraTreesEntr_BAG_L1, LightGBM_BAG_L1, LightGBMXT_BAG_L1, NeuralNetTorch_BAG_L2, NeuralNetFastAI_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, XGBoost_BAG_L2, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1] []
1 NeuralNetTorch_BAG_L2 0.812563 balanced_accuracy 1.109388 61.752962 0.161899 26.718968 2 True 22 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [capital-gain, LightGBM_BAG_L1, marital-status, age, education-num, CatBoost_BAG_L1, sex, fnlwgt, ExtraTreesEntr_BAG_L1, occupation, NeuralNetFastAI_BAG_L1, race, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1, workclass, education, NeuralNetTorch_BAG_L1, native-country, capital-loss, LightGBMXT_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, hours-per-week, relationship] None {'num_epochs': 1000, 'epochs_wo_improve': None, 'activation': 'relu', 'embedding_size_factor': 1.0, 'embed_exponent': 0.56, 'max_embedding_dim': 100, 'y_range': None, 'y_range_extend': 0.05, 'dropout_prob': 0.1, 'optimizer': 'adam', 'learning_rate': 0.0003, 'weight_decay': 1e-06, 'proc.embed_min_categories': 4, 'proc.impute_strategy': 'median', 'proc.max_category_levels': 100, 'proc.skew_threshold': 0.99, 'use_ngram_features': False, 'num_layers': 4, 'hidden_size': 128, 'max_batch_size': 512, 'use_batchnorm': False, 'loss_function': 'auto', 'seed_value': 0} {'batch_size': 32, 'num_epochs': 36} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': ['text_ngram', 'text_as_category'], 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [CatBoost_BAG_L1, NeuralNetTorch_BAG_L1, ExtraTreesEntr_BAG_L1, LightGBM_BAG_L1, LightGBMXT_BAG_L1, NeuralNetFastAI_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1] [WeightedEnsemble_L3]
2 NeuralNetTorch_BAG_L1 0.809518 balanced_accuracy 0.120358 14.543094 0.120358 14.543094 1 True 10 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [age, education, education-num, sex, fnlwgt, relationship, occupation, capital-gain, capital-loss, native-country, race, workclass, hours-per-week, marital-status] None {'num_epochs': 1000, 'epochs_wo_improve': None, 'activation': 'relu', 'embedding_size_factor': 1.0, 'embed_exponent': 0.56, 'max_embedding_dim': 100, 'y_range': None, 'y_range_extend': 0.05, 'dropout_prob': 0.1, 'optimizer': 'adam', 'learning_rate': 0.0003, 'weight_decay': 1e-06, 'proc.embed_min_categories': 4, 'proc.impute_strategy': 'median', 'proc.max_category_levels': 100, 'proc.skew_threshold': 0.99, 'use_ngram_features': False, 'num_layers': 4, 'hidden_size': 128, 'max_batch_size': 512, 'use_batchnorm': False, 'loss_function': 'auto', 'seed_value': 0} {'batch_size': 32, 'num_epochs': 14} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': ['text_ngram', 'text_as_category'], 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [] [NeuralNetFastAI_BAG_L2, ExtraTreesGini_BAG_L2, ExtraTreesEntr_BAG_L2, WeightedEnsemble_L2, RandomForestEntr_BAG_L2, LightGBMXT_BAG_L2, WeightedEnsemble_L3, RandomForestGini_BAG_L2, NeuralNetTorch_BAG_L2, LightGBM_BAG_L2, CatBoost_BAG_L2, LightGBMLarge_BAG_L2, XGBoost_BAG_L2]
3 WeightedEnsemble_L2 0.809518 balanced_accuracy 0.121753 14.781031 0.001396 0.237937 2 True 12 ... {'use_orig_features': False, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [NeuralNetTorch_BAG_L1] None {'ensemble_size': 25, 'subsample_size': 1000000} {'ensemble_size': 1} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [NeuralNetTorch_BAG_L1] []
4 LightGBM_BAG_L2 0.801663 balanced_accuracy 1.002546 37.509482 0.055056 2.475489 2 True 14 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [capital-gain, LightGBM_BAG_L1, marital-status, age, education-num, CatBoost_BAG_L1, sex, fnlwgt, ExtraTreesEntr_BAG_L1, occupation, NeuralNetFastAI_BAG_L1, race, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1, workclass, education, NeuralNetTorch_BAG_L1, native-country, capital-loss, LightGBMXT_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, hours-per-week, relationship] None {'learning_rate': 0.05, 'seed': 0} {'num_boost_round': 111} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [CatBoost_BAG_L1, NeuralNetTorch_BAG_L1, ExtraTreesEntr_BAG_L1, LightGBM_BAG_L1, LightGBMXT_BAG_L1, NeuralNetFastAI_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1] []
5 XGBoost_BAG_L2 0.797148 balanced_accuracy 1.023809 37.601204 0.076320 2.567210 2 True 21 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [capital-gain, LightGBM_BAG_L1, marital-status, age, education-num, CatBoost_BAG_L1, sex, fnlwgt, ExtraTreesEntr_BAG_L1, occupation, NeuralNetFastAI_BAG_L1, race, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1, workclass, education, NeuralNetTorch_BAG_L1, native-country, capital-loss, LightGBMXT_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, hours-per-week, relationship] None {'n_estimators': 10000, 'learning_rate': 0.1, 'n_jobs': -1, 'proc.max_category_levels': 100, 'objective': 'binary:logistic', 'booster': 'gbtree', 'seed': 0} {'n_estimators': 41} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [CatBoost_BAG_L1, NeuralNetTorch_BAG_L1, ExtraTreesEntr_BAG_L1, LightGBM_BAG_L1, LightGBMXT_BAG_L1, NeuralNetFastAI_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1] [WeightedEnsemble_L3]
6 CatBoost_BAG_L2 0.793914 balanced_accuracy 0.996059 41.823847 0.048570 6.789853 2 True 17 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [capital-gain, LightGBM_BAG_L1, marital-status, age, education-num, CatBoost_BAG_L1, sex, fnlwgt, ExtraTreesEntr_BAG_L1, occupation, NeuralNetFastAI_BAG_L1, race, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1, workclass, education, NeuralNetTorch_BAG_L1, native-country, capital-loss, LightGBMXT_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, hours-per-week, relationship] None {'iterations': 10000, 'learning_rate': 0.05, 'allow_writing_files': False, 'eval_metric': 'BalancedAccuracy', 'random_seed': 0} {'iterations': 69} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [CatBoost_BAG_L1, NeuralNetTorch_BAG_L1, ExtraTreesEntr_BAG_L1, LightGBM_BAG_L1, LightGBMXT_BAG_L1, NeuralNetFastAI_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1] []
7 XGBoost_BAG_L1 0.786227 balanced_accuracy 0.077999 1.575752 0.077999 1.575752 1 True 9 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [age, education, education-num, sex, fnlwgt, relationship, occupation, capital-gain, capital-loss, native-country, race, workclass, hours-per-week, marital-status] None {'n_estimators': 10000, 'learning_rate': 0.1, 'n_jobs': -1, 'proc.max_category_levels': 100, 'objective': 'binary:logistic', 'booster': 'gbtree', 'seed': 0} {'n_estimators': 87} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [] [NeuralNetFastAI_BAG_L2, ExtraTreesGini_BAG_L2, ExtraTreesEntr_BAG_L2, RandomForestEntr_BAG_L2, LightGBMXT_BAG_L2, WeightedEnsemble_L3, RandomForestGini_BAG_L2, NeuralNetTorch_BAG_L2, LightGBM_BAG_L2, CatBoost_BAG_L2, LightGBMLarge_BAG_L2, XGBoost_BAG_L2]
8 LightGBMXT_BAG_L2 0.786227 balanced_accuracy 1.005201 36.667126 0.057711 1.633132 2 True 13 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [capital-gain, LightGBM_BAG_L1, marital-status, age, education-num, CatBoost_BAG_L1, sex, fnlwgt, ExtraTreesEntr_BAG_L1, occupation, NeuralNetFastAI_BAG_L1, race, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1, workclass, education, NeuralNetTorch_BAG_L1, native-country, capital-loss, LightGBMXT_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, hours-per-week, relationship] None {'learning_rate': 0.05, 'extra_trees': True, 'seed': 0} {'num_boost_round': 112} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [CatBoost_BAG_L1, NeuralNetTorch_BAG_L1, ExtraTreesEntr_BAG_L1, LightGBM_BAG_L1, LightGBMXT_BAG_L1, NeuralNetFastAI_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1] []
9 LightGBMLarge_BAG_L2 0.783413 balanced_accuracy 1.017043 39.151877 0.069554 4.117883 2 True 23 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [capital-gain, LightGBM_BAG_L1, marital-status, age, education-num, CatBoost_BAG_L1, sex, fnlwgt, ExtraTreesEntr_BAG_L1, occupation, NeuralNetFastAI_BAG_L1, race, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1, workclass, education, NeuralNetTorch_BAG_L1, native-country, capital-loss, LightGBMXT_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, hours-per-week, relationship] None {'learning_rate': 0.03, 'num_leaves': 128, 'feature_fraction': 0.9, 'min_data_in_leaf': 3, 'seed': 0} {'num_boost_round': 164} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [CatBoost_BAG_L1, NeuralNetTorch_BAG_L1, ExtraTreesEntr_BAG_L1, LightGBM_BAG_L1, LightGBMXT_BAG_L1, NeuralNetFastAI_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1] []
10 LightGBM_BAG_L1 0.779696 balanced_accuracy 0.053448 1.778871 0.053448 1.778871 1 True 2 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [age, education, education-num, sex, fnlwgt, relationship, occupation, capital-gain, capital-loss, native-country, race, workclass, hours-per-week, marital-status] None {'learning_rate': 0.05, 'seed': 0} {'num_boost_round': 123} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [] [NeuralNetFastAI_BAG_L2, ExtraTreesGini_BAG_L2, ExtraTreesEntr_BAG_L2, RandomForestEntr_BAG_L2, LightGBMXT_BAG_L2, WeightedEnsemble_L3, RandomForestGini_BAG_L2, NeuralNetTorch_BAG_L2, LightGBM_BAG_L2, CatBoost_BAG_L2, LightGBMLarge_BAG_L2, XGBoost_BAG_L2]
11 LightGBMXT_BAG_L1 0.778835 balanced_accuracy 0.060408 1.634013 0.060408 1.634013 1 True 1 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [age, education, education-num, sex, fnlwgt, relationship, occupation, capital-gain, capital-loss, native-country, race, workclass, hours-per-week, marital-status] None {'learning_rate': 0.05, 'extra_trees': True, 'seed': 0} {'num_boost_round': 226} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [] [NeuralNetFastAI_BAG_L2, ExtraTreesGini_BAG_L2, ExtraTreesEntr_BAG_L2, RandomForestEntr_BAG_L2, LightGBMXT_BAG_L2, WeightedEnsemble_L3, RandomForestGini_BAG_L2, NeuralNetTorch_BAG_L2, LightGBM_BAG_L2, CatBoost_BAG_L2, LightGBMLarge_BAG_L2, XGBoost_BAG_L2]
12 RandomForestEntr_BAG_L2 0.776336 balanced_accuracy 1.068860 35.682957 0.121370 0.648963 2 True 16 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'use_child_oof': True, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [capital-gain, LightGBM_BAG_L1, marital-status, age, education-num, CatBoost_BAG_L1, sex, fnlwgt, ExtraTreesEntr_BAG_L1, occupation, NeuralNetFastAI_BAG_L1, race, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1, workclass, education, NeuralNetTorch_BAG_L1, native-country, capital-loss, LightGBMXT_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, hours-per-week, relationship] None {'n_estimators': 300, 'max_leaf_nodes': 15000, 'n_jobs': -1, 'bootstrap': True, 'criterion': 'entropy', 'random_state': 0} {'n_estimators': 300} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [CatBoost_BAG_L1, NeuralNetTorch_BAG_L1, ExtraTreesEntr_BAG_L1, LightGBM_BAG_L1, LightGBMXT_BAG_L1, NeuralNetFastAI_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1] []
13 CatBoost_BAG_L1 0.768712 balanced_accuracy 0.042109 7.078955 0.042109 7.078955 1 True 5 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [age, education, education-num, sex, fnlwgt, relationship, occupation, capital-gain, capital-loss, native-country, race, workclass, hours-per-week, marital-status] None {'iterations': 10000, 'learning_rate': 0.05, 'allow_writing_files': False, 'eval_metric': 'BalancedAccuracy', 'random_seed': 0} {'iterations': 133} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [] [NeuralNetFastAI_BAG_L2, ExtraTreesGini_BAG_L2, ExtraTreesEntr_BAG_L2, RandomForestEntr_BAG_L2, LightGBMXT_BAG_L2, WeightedEnsemble_L3, RandomForestGini_BAG_L2, NeuralNetTorch_BAG_L2, LightGBM_BAG_L2, CatBoost_BAG_L2, LightGBMLarge_BAG_L2, XGBoost_BAG_L2]
14 RandomForestGini_BAG_L2 0.765961 balanced_accuracy 1.069075 35.697951 0.121586 0.663958 2 True 15 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'use_child_oof': True, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [capital-gain, LightGBM_BAG_L1, marital-status, age, education-num, CatBoost_BAG_L1, sex, fnlwgt, ExtraTreesEntr_BAG_L1, occupation, NeuralNetFastAI_BAG_L1, race, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1, workclass, education, NeuralNetTorch_BAG_L1, native-country, capital-loss, LightGBMXT_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, hours-per-week, relationship] None {'n_estimators': 300, 'max_leaf_nodes': 15000, 'n_jobs': -1, 'bootstrap': True, 'criterion': 'gini', 'random_state': 0} {'n_estimators': 300} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [CatBoost_BAG_L1, NeuralNetTorch_BAG_L1, ExtraTreesEntr_BAG_L1, LightGBM_BAG_L1, LightGBMXT_BAG_L1, NeuralNetFastAI_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1] []
15 ExtraTreesEntr_BAG_L2 0.758758 balanced_accuracy 1.075213 35.660824 0.127723 0.626830 2 True 19 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'use_child_oof': True, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [capital-gain, LightGBM_BAG_L1, marital-status, age, education-num, CatBoost_BAG_L1, sex, fnlwgt, ExtraTreesEntr_BAG_L1, occupation, NeuralNetFastAI_BAG_L1, race, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1, workclass, education, NeuralNetTorch_BAG_L1, native-country, capital-loss, LightGBMXT_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, hours-per-week, relationship] None {'n_estimators': 300, 'max_leaf_nodes': 15000, 'n_jobs': -1, 'bootstrap': True, 'criterion': 'entropy', 'random_state': 0} {'n_estimators': 300} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [CatBoost_BAG_L1, NeuralNetTorch_BAG_L1, ExtraTreesEntr_BAG_L1, LightGBM_BAG_L1, LightGBMXT_BAG_L1, NeuralNetFastAI_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1] [WeightedEnsemble_L3]
16 RandomForestEntr_BAG_L1 0.756678 balanced_accuracy 0.121509 0.635930 0.121509 0.635930 1 True 4 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'use_child_oof': True, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [age, education, education-num, sex, fnlwgt, relationship, occupation, capital-gain, capital-loss, native-country, race, workclass, hours-per-week, marital-status] None {'n_estimators': 300, 'max_leaf_nodes': 15000, 'n_jobs': -1, 'bootstrap': True, 'criterion': 'entropy', 'random_state': 0} {'n_estimators': 300} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [] [NeuralNetFastAI_BAG_L2, ExtraTreesGini_BAG_L2, ExtraTreesEntr_BAG_L2, RandomForestEntr_BAG_L2, LightGBMXT_BAG_L2, WeightedEnsemble_L3, RandomForestGini_BAG_L2, NeuralNetTorch_BAG_L2, LightGBM_BAG_L2, CatBoost_BAG_L2, LightGBMLarge_BAG_L2, XGBoost_BAG_L2]
17 ExtraTreesGini_BAG_L1 0.751491 balanced_accuracy 0.120046 0.646260 0.120046 0.646260 1 True 6 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'use_child_oof': True, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [age, education, education-num, sex, fnlwgt, relationship, occupation, capital-gain, capital-loss, native-country, race, workclass, hours-per-week, marital-status] None {'n_estimators': 300, 'max_leaf_nodes': 15000, 'n_jobs': -1, 'bootstrap': True, 'criterion': 'gini', 'random_state': 0} {'n_estimators': 300} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [] [NeuralNetFastAI_BAG_L2, ExtraTreesGini_BAG_L2, ExtraTreesEntr_BAG_L2, RandomForestEntr_BAG_L2, LightGBMXT_BAG_L2, WeightedEnsemble_L3, RandomForestGini_BAG_L2, NeuralNetTorch_BAG_L2, LightGBM_BAG_L2, CatBoost_BAG_L2, LightGBMLarge_BAG_L2, XGBoost_BAG_L2]
18 NeuralNetFastAI_BAG_L2 0.748866 balanced_accuracy 1.076216 40.729774 0.128727 5.695780 2 True 20 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [capital-gain, LightGBM_BAG_L1, marital-status, age, education-num, CatBoost_BAG_L1, sex, fnlwgt, ExtraTreesEntr_BAG_L1, occupation, NeuralNetFastAI_BAG_L1, race, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1, workclass, education, NeuralNetTorch_BAG_L1, native-country, capital-loss, LightGBMXT_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, hours-per-week, relationship] None {'layers': None, 'emb_drop': 0.1, 'ps': 0.1, 'bs': 'auto', 'lr': 0.01, 'epochs': 'auto', 'early.stopping.min_delta': 0.0001, 'early.stopping.patience': 20, 'smoothing': 0.0, 'random_seed': 0} {'epochs': 30, 'best_epoch': 9} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': ['text_ngram', 'text_as_category'], 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [CatBoost_BAG_L1, NeuralNetTorch_BAG_L1, ExtraTreesEntr_BAG_L1, LightGBM_BAG_L1, LightGBMXT_BAG_L1, NeuralNetFastAI_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1] []
19 ExtraTreesGini_BAG_L2 0.748257 balanced_accuracy 1.072443 35.692085 0.124954 0.658091 2 True 18 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'use_child_oof': True, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [capital-gain, LightGBM_BAG_L1, marital-status, age, education-num, CatBoost_BAG_L1, sex, fnlwgt, ExtraTreesEntr_BAG_L1, occupation, NeuralNetFastAI_BAG_L1, race, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1, workclass, education, NeuralNetTorch_BAG_L1, native-country, capital-loss, LightGBMXT_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, hours-per-week, relationship] None {'n_estimators': 300, 'max_leaf_nodes': 15000, 'n_jobs': -1, 'bootstrap': True, 'criterion': 'gini', 'random_state': 0} {'n_estimators': 300} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [CatBoost_BAG_L1, NeuralNetTorch_BAG_L1, ExtraTreesEntr_BAG_L1, LightGBM_BAG_L1, LightGBMXT_BAG_L1, NeuralNetFastAI_BAG_L1, XGBoost_BAG_L1, RandomForestGini_BAG_L1, ExtraTreesGini_BAG_L1, RandomForestEntr_BAG_L1] []
20 RandomForestGini_BAG_L1 0.748194 balanced_accuracy 0.121088 0.708103 0.121088 0.708103 1 True 3 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'use_child_oof': True, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [age, education, education-num, sex, fnlwgt, relationship, occupation, capital-gain, capital-loss, native-country, race, workclass, hours-per-week, marital-status] None {'n_estimators': 300, 'max_leaf_nodes': 15000, 'n_jobs': -1, 'bootstrap': True, 'criterion': 'gini', 'random_state': 0} {'n_estimators': 300} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [] [NeuralNetFastAI_BAG_L2, ExtraTreesGini_BAG_L2, ExtraTreesEntr_BAG_L2, RandomForestEntr_BAG_L2, LightGBMXT_BAG_L2, WeightedEnsemble_L3, RandomForestGini_BAG_L2, NeuralNetTorch_BAG_L2, LightGBM_BAG_L2, CatBoost_BAG_L2, LightGBMLarge_BAG_L2, XGBoost_BAG_L2]
21 LightGBMLarge_BAG_L1 0.747333 balanced_accuracy 0.060283 2.592734 0.060283 2.592734 1 True 11 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [age, education, education-num, sex, fnlwgt, relationship, occupation, capital-gain, capital-loss, native-country, race, workclass, hours-per-week, marital-status] None {'learning_rate': 0.03, 'num_leaves': 128, 'feature_fraction': 0.9, 'min_data_in_leaf': 3, 'seed': 0} {'num_boost_round': 148} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [] []
22 ExtraTreesEntr_BAG_L1 0.730553 balanced_accuracy 0.121904 0.641470 0.121904 0.641470 1 True 7 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'use_child_oof': True, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [age, education, education-num, sex, fnlwgt, relationship, occupation, capital-gain, capital-loss, native-country, race, workclass, hours-per-week, marital-status] None {'n_estimators': 300, 'max_leaf_nodes': 15000, 'n_jobs': -1, 'bootstrap': True, 'criterion': 'entropy', 'random_state': 0} {'n_estimators': 300} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [] [NeuralNetFastAI_BAG_L2, ExtraTreesGini_BAG_L2, ExtraTreesEntr_BAG_L2, RandomForestEntr_BAG_L2, LightGBMXT_BAG_L2, WeightedEnsemble_L3, RandomForestGini_BAG_L2, NeuralNetTorch_BAG_L2, LightGBM_BAG_L2, CatBoost_BAG_L2, LightGBMLarge_BAG_L2, XGBoost_BAG_L2]
23 NeuralNetFastAI_BAG_L1 0.725071 balanced_accuracy 0.108622 5.791547 0.108622 5.791547 1 True 8 ... {'use_orig_features': True, 'valid_stacker': True, 'max_base_models': 0, 'max_base_models_per_type': 'auto', 'save_bag_folds': True, 'stratify': 'auto', 'bin': 'auto', 'n_bins': None, 'vary_seed_across_folds': False, 'model_random_seed': 0} {} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': None, 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None, 'drop_unique': False} [age, education, education-num, sex, fnlwgt, relationship, occupation, capital-gain, capital-loss, native-country, race, workclass, hours-per-week, marital-status] None {'layers': None, 'emb_drop': 0.1, 'ps': 0.1, 'bs': 'auto', 'lr': 0.01, 'epochs': 'auto', 'early.stopping.min_delta': 0.0001, 'early.stopping.patience': 20, 'smoothing': 0.0, 'random_seed': 0} {'epochs': 30, 'best_epoch': 9} {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['bool', 'int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': ['text_ngram', 'text_as_category'], 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None} [] [NeuralNetFastAI_BAG_L2, ExtraTreesGini_BAG_L2, ExtraTreesEntr_BAG_L2, RandomForestEntr_BAG_L2, LightGBMXT_BAG_L2, WeightedEnsemble_L3, RandomForestGini_BAG_L2, NeuralNetTorch_BAG_L2, LightGBM_BAG_L2, CatBoost_BAG_L2, LightGBMLarge_BAG_L2, XGBoost_BAG_L2]

24 rows × 32 columns

The expanded leaderboard shows properties like how many features are used by each model (num_features), which other models are ancestors whose predictions are required inputs for each model (ancestors), and how much memory each model and all its ancestors would occupy if simultaneously persisted (memory_size_w_ancestors). See the leaderboard documentation for full details.
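
For instance, you can pull just the columns discussed above out of the expanded leaderboard; a minimal sketch, assuming the extra_info=True call returns the leaderboard as a pandas DataFrame (as in the run above):

lb = predictor.leaderboard(test_data, extra_info=True)  # expanded leaderboard
lb[['model', 'num_features', 'memory_size_w_ancestors', 'ancestors']]  # inspect a few of the extra columns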

To show scores for other metrics, you can specify the extra_metrics argument when passing in test_data:

predictor.leaderboard(test_data, extra_metrics=['accuracy', 'balanced_accuracy', 'log_loss'])
model score_test accuracy balanced_accuracy log_loss score_val eval_metric pred_time_test pred_time_val fit_time pred_time_test_marginal pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order
0 NeuralNetTorch_BAG_L2 0.786935 0.843177 0.786935 -0.331951 0.812563 balanced_accuracy 3.444195 1.109388 61.752962 0.600324 0.161899 26.718968 2 True 22
1 WeightedEnsemble_L3 0.775800 0.850445 0.775800 -0.324749 0.815083 balanced_accuracy 3.799388 1.314348 65.195856 0.002771 0.000917 0.248854 3 True 24
2 NeuralNetTorch_BAG_L1 0.767081 0.836012 0.767081 -0.349641 0.809518 balanced_accuracy 0.403701 0.120358 14.543094 0.403701 0.120358 14.543094 1 True 10
3 WeightedEnsemble_L2 0.767081 0.836012 0.767081 -0.349641 0.809518 balanced_accuracy 0.405428 0.121753 14.781031 0.001726 0.001396 0.237937 2 True 12
4 CatBoost_BAG_L2 0.759693 0.857611 0.759693 -0.385308 0.793914 balanced_accuracy 2.891997 0.996059 41.823847 0.048126 0.048570 6.789853 2 True 17
5 LightGBMXT_BAG_L1 0.758639 0.854642 0.758639 -0.320179 0.778835 balanced_accuracy 0.364142 0.060408 1.634013 0.364142 0.060408 1.634013 1 True 1
6 CatBoost_BAG_L1 0.755940 0.860272 0.755940 -0.403003 0.768712 balanced_accuracy 0.047042 0.042109 7.078955 0.047042 0.042109 7.078955 1 True 5
7 XGBoost_BAG_L2 0.754565 0.850241 0.754565 -0.339983 0.797148 balanced_accuracy 3.093371 1.023809 37.601204 0.249500 0.076320 2.567210 2 True 21
8 LightGBM_BAG_L2 0.751866 0.849524 0.751866 -0.349900 0.801663 balanced_accuracy 3.072136 1.002546 37.509482 0.228265 0.055056 2.475489 2 True 14
9 ExtraTreesEntr_BAG_L2 0.749119 0.850548 0.749119 -0.328483 0.758758 balanced_accuracy 2.946793 1.075213 35.660824 0.102922 0.127723 0.626830 2 True 19
10 LightGBMXT_BAG_L2 0.748702 0.853311 0.748702 -0.326402 0.786227 balanced_accuracy 3.040785 1.005201 36.667126 0.196914 0.057711 1.633132 2 True 13
11 ExtraTreesGini_BAG_L2 0.748127 0.851981 0.748127 -0.335611 0.748257 balanced_accuracy 2.946531 1.072443 35.692085 0.102660 0.124954 0.658091 2 True 18
12 XGBoost_BAG_L1 0.747715 0.850445 0.747715 -0.324872 0.786227 balanced_accuracy 0.257175 0.077999 1.575752 0.257175 0.077999 1.575752 1 True 9
13 RandomForestEntr_BAG_L2 0.746406 0.850036 0.746406 -0.342309 0.776336 balanced_accuracy 2.944619 1.068860 35.682957 0.100748 0.121370 0.648963 2 True 16
14 RandomForestGini_BAG_L2 0.746051 0.850855 0.746051 -0.343055 0.765961 balanced_accuracy 2.944460 1.069075 35.697951 0.100589 0.121586 0.663958 2 True 15
15 LightGBMLarge_BAG_L2 0.745884 0.842666 0.745884 -0.397783 0.783413 balanced_accuracy 3.243814 1.017043 39.151877 0.399943 0.069554 4.117883 2 True 23
16 RandomForestGini_BAG_L1 0.745745 0.843587 0.745745 -0.339655 0.748194 balanced_accuracy 0.132324 0.121088 0.708103 0.132324 0.121088 0.708103 1 True 3
17 LightGBM_BAG_L1 0.745232 0.846658 0.745232 -0.335439 0.779696 balanced_accuracy 0.275485 0.053448 1.778871 0.275485 0.053448 1.778871 1 True 2
18 RandomForestEntr_BAG_L1 0.743132 0.841642 0.743132 -0.342622 0.756678 balanced_accuracy 0.101026 0.121509 0.635930 0.101026 0.121509 0.635930 1 True 4
19 NeuralNetFastAI_BAG_L2 0.741699 0.848296 0.741699 -0.330586 0.748866 balanced_accuracy 3.913914 1.076216 40.729774 1.070043 0.128727 5.695780 2 True 20
20 LightGBMLarge_BAG_L1 0.735861 0.836217 0.735861 -0.372562 0.747333 balanced_accuracy 0.371449 0.060283 2.592734 0.371449 0.060283 2.592734 1 True 11
21 ExtraTreesGini_BAG_L1 0.731460 0.835398 0.731460 -0.354638 0.751491 balanced_accuracy 0.103176 0.120046 0.646260 0.103176 0.120046 0.646260 1 True 6
22 NeuralNetFastAI_BAG_L1 0.728689 0.845225 0.728689 -0.344521 0.725071 balanced_accuracy 1.056540 0.108622 5.791547 1.056540 0.108622 5.791547 1 True 8
23 ExtraTreesEntr_BAG_L1 0.725517 0.832224 0.725517 -0.351604 0.730553 balanced_accuracy 0.103258 0.121904 0.641470 0.103258 0.121904 0.641470 1 True 7

Notice that the log_loss scores are negative. This is because metrics in AutoGluon are always shown in higher_is_better form: error metrics such as log_loss and root_mean_squared_error have their signs FLIPPED, so their values appear negative. This convention lets you compare leaderboard scores without first having to know whether a given metric is a loss or a score.
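
To see the flip concretely, you can compare AutoGluon's reported log_loss against the raw (lower-is-better) value from scikit-learn; a minimal sketch, where the labels argument aligns sklearn's class ordering with the predict_proba columns:

from sklearn.metrics import log_loss

proba = predictor.predict_proba(X_test)
raw_log_loss = log_loss(y_test, proba, labels=list(proba.columns))  # positive, lower is better
print(f"sklearn log_loss: {raw_log_loss:.4f} | AutoGluon reports: {-raw_log_loss:.4f}")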

Here’s how to specify a particular model to use for prediction instead of AutoGluon’s default model choice:

i = 0  # index of model to use
all_models = predictor.model_names()
model_to_use = all_models[i]
model_pred = predictor.predict(datapoint, model=model_to_use)  # `datapoint` is a single-row DataFrame defined earlier in this tutorial
print("Prediction from %s model: %s" % (model_to_use, model_pred.iloc[0]))
Prediction from LightGBMXT_BAG_L1 model:  <=50K
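
Relatedly, you can ask the predictor which model it uses by default; a small sketch, assuming model_best (the property exposing the default model's name in recent AutoGluon versions):

best_model = predictor.model_best  # name of the model used when `model` is not specified
model_pred_best = predictor.predict(datapoint, model=best_model)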

We can easily access various information about the trained predictor or a particular model:

# Objects defined below are dicts of various information (not printed here as they are quite large):
predictor_information = predictor.info()  # access info about the predictor
model_info = predictor.model_info(model_to_use)  # access info about a model
model_info_alternative = predictor._trainer.load_model(model_to_use).get_info()  # load the inner model and access its info directly

The predictor also remembers what metric predictions should be evaluated with, which can be done with ground truth labels as follows:

y_pred_proba = predictor.predict_proba(X_test)
predictor.evaluate_predictions(y_true=y_test, y_pred=y_pred_proba)
{'balanced_accuracy': np.float64(0.775799676668123),
 'accuracy': 0.8504452861091207,
 'mcc': 0.5731192798230831,
 'roc_auc': np.float64(0.9036868599903031),
 'f1': 0.6678790634235053,
 'precision': 0.7059106198942816,
 'recall': 0.6337359792924935}

Since the label column remains in the test_data DataFrame, we can instead use the shorthand:

predictor.evaluate(test_data)
{'balanced_accuracy': np.float64(0.775799676668123),
 'accuracy': 0.8504452861091207,
 'mcc': 0.5731192798230831,
 'roc_auc': np.float64(0.9036868599903031),
 'f1': 0.6678790634235053,
 'precision': 0.7059106198942816,
 'recall': 0.6337359792924935}

Interpretability (feature importance)

To better understand our trained predictor, we can estimate the overall importance of each feature:

predictor.feature_importance(test_data)
Computing feature importance via permutation shuffling for 14 features using 5000 rows with 5 shuffle sets...
	193.51s	= Expected runtime (38.7s per shuffle set)
	119.75s	= Actual runtime (Completed 5 of 5 shuffle sets)
importance stddev p_value n p99_high p99_low
marital-status 0.079090 0.003301 3.633073e-07 5 0.085887 0.072293
capital-gain 0.032361 0.003726 2.072973e-05 5 0.040034 0.024689
education-num 0.029732 0.005270 1.136531e-04 5 0.040583 0.018881
occupation 0.028098 0.001764 1.852620e-06 5 0.031729 0.024467
relationship 0.022664 0.001490 2.225925e-06 5 0.025731 0.019597
age 0.019103 0.005119 5.637912e-04 5 0.029644 0.008562
hours-per-week 0.010759 0.003316 9.583770e-04 5 0.017588 0.003931
capital-loss 0.006583 0.002536 2.189878e-03 5 0.011804 0.001362
native-country 0.003769 0.001961 6.335805e-03 5 0.007807 -0.000269
race 0.001849 0.001259 1.518866e-02 5 0.004440 -0.000743
sex 0.000810 0.000901 5.742665e-02 5 0.002665 -0.001046
education 0.000187 0.005337 4.707353e-01 5 0.011176 -0.010803
workclass -0.000514 0.001805 7.206413e-01 5 0.003203 -0.004231
fnlwgt -0.000943 0.004098 6.830963e-01 5 0.007494 -0.009380

Computed via permutation shuffling, these feature importance scores quantify the drop in predictive performance (of the already trained predictor) when one column’s values are randomly shuffled across rows. The top features in this list contribute most to AutoGluon’s accuracy (here, for predicting whether a person’s income exceeds $50K). Features with non-positive importance scores hardly contribute to the predictor’s accuracy, or may even be actively harmful to include in the data (consider removing these features from your data and calling fit again, as sketched below). These scores facilitate interpretability of the predictor’s global behavior (which features it relies on for all predictions). To get local explanations regarding which features influence a particular prediction, check out the example notebooks for explaining particular AutoGluon predictions using Shapley values.
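
As a starting point for such pruning, here is a minimal sketch that drops every feature whose estimated importance is non-positive and prepares a reduced training set (train_data_pruned and predictor_pruned are hypothetical names; whether refitting on the pruned data actually helps should be verified on validation data):

fi = predictor.feature_importance(test_data)  # DataFrame indexed by feature, with an 'importance' column
cols_to_drop = fi[fi['importance'] <= 0].index.tolist()  # e.g. workclass and fnlwgt in the run above
train_data_pruned = train_data.drop(columns=cols_to_drop)
# predictor_pruned = TabularPredictor(label=label, eval_metric=metric).fit(train_data_pruned)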

Before judging whether AutoGluon is more or less interpretable than another solution, we recommend reading The Mythos of Model Interpretability by Zachary Lipton, which covers why often-claimed interpretable models such as trees and linear models are rarely meaningfully more interpretable than more advanced models.

Accelerating inference

We describe multiple ways to reduce the time it takes for AutoGluon to produce predictions.

Before providing code examples, it is important to understand that there are several ways to accelerate inference in AutoGluon. The table below lists the options in order of priority.

| Optimization | Inference Speedup | Cost | Notes |
| --- | --- | --- | --- |
| refit_full | 8x (requires bagging) | -Quality | Only provides a speedup when bagging is enabled. |
| persist | Up to 10x for online inference | +Memory Usage | If memory is insufficient to persist the models, no speedup is gained. Most effective for online inference; not relevant for batch inference. |
| infer_limit | Up to 50x+ | -Quality (relative to speedup) | Best when combined with refit_full. |
| feature pruning | Typically at most 1.5x; more if willing to lower quality significantly | -Quality? | Depends on the existence of unimportant features in the data. Call predictor.feature_importance(test_data) to gauge which features could be removed. |
| use faster hardware | Usually at most 3x; depends on hardware (ignoring GPU) | +Hardware | As an example, an EC2 c6i.2xlarge is ~1.6x faster than an m5.2xlarge for a similar price. Laptops in particular might be slow compared to cloud instances. |
| manual hyperparameter adjustment | Usually at most 2x, assuming infer_limit is already specified | Manual Effort | Can be very complicated and is not recommended. One way to get a speedup is to reduce the number of trees in LightGBM, XGBoost, and CatBoost. |
| manual data preprocessing | Usually at most 1.2x | Manual Effort | Only relevant for online inference. Not recommended, as AutoGluon’s default preprocessing is highly optimized. |

For most scenarios, the order of inference optimizations should be:

  1. refit_full

  2. persist

  3. infer_limit

If following these recommendations does not lead to a sufficiently fast model, you may consider the more advanced options in the table.
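
As a minimal sketch of steps 1 and 2 applied to the bagged predictor trained above (left commented out here, since refit_full would alter the predictor used in the rest of this tutorial; refit_full collapses each bagged model into a single model refit on all of the training data, trading a little quality for a large speedup):

# refit_model_map = predictor.refit_full()  # maps each model name to its faster "_FULL" variant
# predictor.persist()                       # then keep the inference models in memory
# predictor.predict(test_data)              # predictions now use the faster refit models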

persist: Keeping models in memory

By default, AutoGluon loads models into memory one at a time and only when they are needed for prediction. This strategy is robust for large stacked/bagged ensembles, but leads to slower prediction times. If you plan to repeatedly make predictions (e.g. on new datapoints one at a time rather than one large test dataset), you can first specify that all models required for inference should be loaded into memory as follows:

import time
import numpy as np

num_test = 20
preds = np.array(['']*num_test, dtype='object')

# Time per-row predictions while models are loaded from disk on every call
time_start = time.time()
for i in range(num_test):
    datapoint = X_test.iloc[[i]]
    pred_numpy = predictor.predict(datapoint, as_pandas=False)
    preds[i] = pred_numpy[0]
time_end = time.time()
time_without_persist = (time_end - time_start) / num_test

# Load all models required for inference into memory, then repeat the timing
predictor.persist()

preds = np.array(['']*num_test, dtype='object')
time_start = time.time()
for i in range(num_test):
    datapoint = X_test.iloc[[i]]
    pred_numpy = predictor.predict(datapoint, as_pandas=False)
    preds[i] = pred_numpy[0]
time_end = time.time()
time_with_persist = (time_end - time_start) / num_test

predictor.unpersist()  # free memory by clearing models, future predict() calls will load models from disk
Persisting 14 models in memory. Models will require 0.25% of memory.
Unpersisted 14 models: ['CatBoost_BAG_L1', 'ExtraTreesEntr_BAG_L2', 'NeuralNetTorch_BAG_L1', 'ExtraTreesEntr_BAG_L1', 'LightGBM_BAG_L1', 'WeightedEnsemble_L3', 'LightGBMXT_BAG_L1', 'NeuralNetTorch_BAG_L2', 'NeuralNetFastAI_BAG_L1', 'XGBoost_BAG_L1', 'RandomForestGini_BAG_L1', 'XGBoost_BAG_L2', 'ExtraTreesGini_BAG_L1', 'RandomForestEntr_BAG_L1']
['CatBoost_BAG_L1',
 'ExtraTreesEntr_BAG_L2',
 'NeuralNetTorch_BAG_L1',
 'ExtraTreesEntr_BAG_L1',
 'LightGBM_BAG_L1',
 'WeightedEnsemble_L3',
 'LightGBMXT_BAG_L1',
 'NeuralNetTorch_BAG_L2',
 'NeuralNetFastAI_BAG_L1',
 'XGBoost_BAG_L1',
 'RandomForestGini_BAG_L1',
 'XGBoost_BAG_L2',
 'ExtraTreesGini_BAG_L1',
 'RandomForestEntr_BAG_L1']
print(f"Inference time unoptimized:  {time_without_persist:.3f}s")
print(f"Inference time with persist: {time_with_persist:.3f}s")
Inference time unoptimized:  0.882s
Inference time with persist: 0.604s

You can alternatively specify a particular model to persist via the models argument of persist(), or simply set models='all' to simultaneously load every single model that was trained during fit.
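
For example, a small sketch of those options:

predictor.persist(models='best')   # persist only the best model plus the ancestors it requires
# predictor.persist(models='all')  # alternatively, load every model trained during fit
predictor.unpersist()              # free the memory again when finished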

infer_limit: Inference speed as a fit constraint

If you know your latency constraint prior to fitting the predictor, you can specify it explicitly as a fit argument. AutoGluon will then automatically train models in a fashion that attempts to satisfy the constraint.

This constraint has two components: infer_limit and infer_limit_batch_size:

  • infer_limit is the time in seconds to predict 1 row of data. For example, infer_limit=0.05 means 50 ms per row of data, or 20 rows / second throughput.

  • infer_limit_batch_size is the number of rows passed to predict at once when calculating the per-row speed. This is very important because infer_limit_batch_size=1 (online inference) is highly suboptimal, as various operations have a fixed cost overhead regardless of data size. If you can pass your test data in bulk, you should specify infer_limit_batch_size=10000.

# At most 0.05 ms per row (20000 rows per second throughput)
infer_limit = 0.00005
# adhere to infer_limit with batches of size 10000 (batch-inference, easier to satisfy infer_limit)
infer_limit_batch_size = 10000
# adhere to infer_limit with batches of size 1 (online-inference, much harder to satisfy infer_limit)
# infer_limit_batch_size = 1  # Note that infer_limit<0.02 when infer_limit_batch_size=1 can be difficult to satisfy.
predictor_infer_limit = TabularPredictor(label=label, eval_metric=metric).fit(
    train_data=train_data,
    time_limit=30,
    infer_limit=infer_limit,
    infer_limit_batch_size=infer_limit_batch_size,
)

# NOTE: If bagging was enabled, it is important to call refit_full at this stage.
#  infer_limit assumes that the user will call refit_full after fit.
# predictor_infer_limit.refit_full()

# NOTE: To align with inference speed calculated during fit, models must be persisted.
predictor_infer_limit.persist()
# Below is an optimized version that only persists the minimum required models for prediction.
# predictor_infer_limit.persist('best')

predictor_infer_limit.leaderboard()
No path specified. Models will be saved in: "AutogluonModels/ag-20251219_142353"
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version:  1.5.0b20251219
Python Version:     3.12.10
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.9.1+cu128
CUDA Version:       12.8
GPU Memory:         GPU 0: 14.57/14.57 GB
Total GPU Memory:   Free: 14.57 GB, Allocated: 0.00 GB, Total: 14.57 GB
GPU Count:          1
Memory Avail:       27.11 GB / 30.95 GB (87.6%)
Disk Space Avail:   202.42 GB / 255.99 GB (79.1%)
===================================================
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets. Defaulting to `'medium'`...
	Recommended Presets (For more details refer to https://auto.gluon.ai/stable/tutorials/tabular/tabular-essentials.html#presets):
	presets='extreme'  : New in v1.5: The state-of-the-art for tabular data. Massively better than 'best' on datasets <100000 samples by using new Tabular Foundation Models (TFMs) meta-learned on https://tabarena.ai: TabPFNv2, TabICL, Mitra, TabDPT, and TabM. Absolute best accuracy. Requires a GPU. Recommended 64 GB CPU memory and 32+ GB GPU memory.
	presets='best'     : Maximize accuracy. Recommended for most users. Use in competitions and benchmarks.
	presets='best_v150': New in v1.5: Better quality than 'best' and 5x+ faster to train. Give it a try!
	presets='high'     : Strong accuracy with fast inference speed.
	presets='high_v150': New in v1.5: Better quality than 'high' and 5x+ faster to train. Give it a try!
	presets='good'     : Good accuracy with very fast inference speed.
	presets='medium'   : Fast training time, ideal for initial prototyping.
Using hyperparameters preset: hyperparameters='default'
Beginning AutoGluon training ... Time limit = 30s
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20251219_142353"
Train Data Rows:    1000
Train Data Columns: 14
Label Column:       class
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
	2 unique label values:  [' >50K', ' <=50K']
	If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
Problem Type:       binary
Preprocessing data ...
Selected class <--> label mapping:  class 1 =  >50K, class 0 =  <=50K
	Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
	To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
	Available Memory:                    27765.10 MB
	Train Data (Original)  Memory Usage: 0.50 MB (0.0% of available memory)
	Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
	Stage 1 Generators:
		Fitting AsTypeFeatureGenerator...
			Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
	Stage 2 Generators:
		Fitting FillNaFeatureGenerator...
	Stage 3 Generators:
		Fitting IdentityFeatureGenerator...
		Fitting CategoryFeatureGenerator...
			Fitting CategoryMemoryMinimizeFeatureGenerator...
	Stage 4 Generators:
		Fitting DropUniqueFeatureGenerator...
	Stage 5 Generators:
		Fitting DropDuplicatesFeatureGenerator...
	Types of features in original data (raw dtype, special dtypes):
		('int', [])    : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
		('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
	Types of features in processed data (raw dtype, special dtypes):
		('category', [])  : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
		('int', [])       : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
		('int', ['bool']) : 1 | ['sex']
	0.1s = Fit runtime
	14 features in original data used to generate 14 features in processed data.
	Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
	1.709μs	= Feature Preprocessing Time (1 row | 10000 batch size)
		Feature Preprocessing requires 3.42% of the overall inference constraint (0.05ms)
		0.048ms inference time budget remaining for models...
Data preprocessing and feature engineering runtime = 0.27s ...
AutoGluon will gauge predictive performance using evaluation metric: 'balanced_accuracy'
	To change this, specify the eval_metric parameter of Predictor()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 800, Val Rows: 200
User-specified model hyperparameters to be fit:
{
	'NN_TORCH': [{}],
	'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, {'learning_rate': 0.03, 'num_leaves': 128, 'feature_fraction': 0.9, 'min_data_in_leaf': 3, 'ag_args': {'name_suffix': 'Large', 'priority': 0, 'hyperparameter_tune_kwargs': None}}],
	'CAT': [{}],
	'XGB': [{}],
	'FASTAI': [{}],
	'RF': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],
	'XT': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],
}
Fitting 11 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBMXT ... Training model for up to 29.73s of the 29.72s of remaining time.
	Fitting with cpus=4, gpus=0, mem=0.0/27.1 GB
	0.7575	 = Validation score   (balanced_accuracy)
	0.71s	 = Training   runtime
	0.0s	 = Validation runtime
	2.589μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
	2.589μs	 = Validation runtime (1 row | 10000 batch size)
	2.589μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
	2.589μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: LightGBM ... Training model for up to 29.01s of the 29.01s of remaining time.
	Fitting with cpus=4, gpus=0, mem=0.0/27.1 GB
	0.7314	 = Validation score   (balanced_accuracy)
	0.84s	 = Training   runtime
	0.0s	 = Validation runtime
	1.476μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
	1.476μs	 = Validation runtime (1 row | 10000 batch size)
	1.476μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
	1.476μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: RandomForestGini ... Training model for up to 28.15s of the 28.15s of remaining time.
	Fitting with cpus=8, gpus=0
	0.7314	 = Validation score   (balanced_accuracy)
	0.74s	 = Training   runtime
	0.06s	 = Validation runtime
	8.839μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
	8.839μs	 = Validation runtime (1 row | 10000 batch size)
	8.839μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
	8.839μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: RandomForestEntr ... Training model for up to 27.33s of the 27.33s of remaining time.
	Fitting with cpus=8, gpus=0
	0.7281	 = Validation score   (balanced_accuracy)
	0.74s	 = Training   runtime
	0.05s	 = Validation runtime
	8.765μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
	8.765μs	 = Validation runtime (1 row | 10000 batch size)
	8.765μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
	8.765μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: CatBoost ... Training model for up to 26.53s of the 26.53s of remaining time.
	Fitting with cpus=4, gpus=0
	0.7771	 = Validation score   (balanced_accuracy)
	2.22s	 = Training   runtime
	0.0s	 = Validation runtime
	0.904μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
	0.904μs	 = Validation runtime (1 row | 10000 batch size)
	0.904μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
	0.904μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: ExtraTreesGini ... Training model for up to 24.29s of the 24.29s of remaining time.
	Fitting with cpus=8, gpus=0
	0.6824	 = Validation score   (balanced_accuracy)
	0.74s	 = Training   runtime
	0.06s	 = Validation runtime
	8.786μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
	8.786μs	 = Validation runtime (1 row | 10000 batch size)
	8.786μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
	8.786μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: ExtraTreesEntr ... Training model for up to 23.47s of the 23.47s of remaining time.
	Fitting with cpus=8, gpus=0
	0.6922	 = Validation score   (balanced_accuracy)
	0.74s	 = Training   runtime
	0.05s	 = Validation runtime
	8.837μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
	8.837μs	 = Validation runtime (1 row | 10000 batch size)
	8.837μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
	8.837μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: NeuralNetFastAI ... Training model for up to 22.66s of the 22.66s of remaining time.
	Fitting with cpus=4, gpus=0, mem=0.0/27.1 GB
Metric balanced_accuracy is not supported by this model - using log_loss instead
No improvement since epoch 7: early stopping
	0.751	 = Validation score   (balanced_accuracy)
	1.24s	 = Training   runtime
	0.01s	 = Validation runtime
	0.013ms	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
	0.013ms	 = Validation runtime (1 row | 10000 batch size)
	0.013ms	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
	0.013ms	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: XGBoost ... Training model for up to 21.40s of the 21.40s of remaining time.
	Fitting with cpus=4, gpus=0
	0.7446	 = Validation score   (balanced_accuracy)
	0.69s	 = Training   runtime
	0.01s	 = Validation runtime
	2.858μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
	2.858μs	 = Validation runtime (1 row | 10000 batch size)
	2.858μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
	2.858μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: NeuralNetTorch ... Training model for up to 20.69s of the 20.69s of remaining time.
	Fitting with cpus=4, gpus=0, mem=0.0/27.1 GB
/home/ci/opt/venv/lib/python3.12/site-packages/sklearn/compose/_column_transformer.py:975: FutureWarning: The parameter `force_int_remainder_cols` is deprecated and will be removed in 1.9. It has no effect. Leave it to its default value to avoid this warning.
  warnings.warn(
	0.799	 = Validation score   (balanced_accuracy)
	4.05s	 = Training   runtime
	0.01s	 = Validation runtime
	4.576μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
	4.576μs	 = Validation runtime (1 row | 10000 batch size)
	4.576μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
	4.576μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Fitting model: LightGBMLarge ... Training model for up to 16.62s of the 16.62s of remaining time.
	Fitting with cpus=4, gpus=0, mem=0.1/27.1 GB
	0.685	 = Validation score   (balanced_accuracy)
	1.25s	 = Training   runtime
	0.0s	 = Validation runtime
	2.033μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
	2.033μs	 = Validation runtime (1 row | 10000 batch size)
	2.033μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
	2.033μs	 = Validation runtime (1 row | 10000 batch size | REFIT)
Removing 3/11 base models to satisfy inference constraint (constraint=0.046ms) ...
	0.063ms	-> 0.054ms	(ExtraTreesGini)
	0.054ms	-> 0.052ms	(LightGBMLarge)
	0.052ms	-> 0.043ms	(ExtraTreesEntr)
Fitting model: WeightedEnsemble_L2 ... Training model for up to 29.73s of the 15.35s of remaining time.
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch': 0.667, 'RandomForestGini': 0.333}
	0.7995	 = Validation score   (balanced_accuracy)
	0.1s	 = Training   runtime
	0.0s	 = Validation runtime
	0.081μs	 = Validation runtime (1 row | 10000 batch size | MARGINAL)
	0.013ms	 = Validation runtime (1 row | 10000 batch size)
	0.081μs	 = Validation runtime (1 row | 10000 batch size | REFIT | MARGINAL)
	0.013ms	 = Validation runtime (1 row | 10000 batch size | REFIT)
AutoGluon training complete, total runtime = 14.76s ... Best model: WeightedEnsemble_L2 | Estimated inference throughput: 2910.3 rows/s (200 batch size)
Enabling decision threshold calibration (calibrate_decision_threshold='auto', metric is valid, problem_type is 'binary')
Calibrating decision threshold to optimize metric balanced_accuracy | Checking 51 thresholds...
Calibrating decision threshold via fine-grained search | Checking 38 thresholds...
	Base Threshold: 0.500	| val: 0.7995
	Best Threshold: 0.507	| val: 0.8029
Updating predictor.decision_threshold from 0.5 -> 0.507
	This will impact how prediction probabilities are converted to predictions in binary classification.
	Prediction probabilities of the positive class >0.507 will be predicted as the positive class ( >50K). This can significantly impact metric scores.
	You can update this value via `predictor.set_decision_threshold`.
	You can calculate an optimal decision threshold on the validation data via `predictor.calibrate_decision_threshold()`.
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20251219_142353")
Persisting 3 models in memory. Models will require 0.02% of memory.
model score_val eval_metric pred_time_val fit_time pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order
0 WeightedEnsemble_L2 0.799513 balanced_accuracy 0.068722 4.882390 0.001014 0.095961 2 True 12
1 NeuralNetTorch 0.798987 balanced_accuracy 0.010645 4.048794 0.010645 4.048794 1 True 10
2 CatBoost 0.777076 balanced_accuracy 0.004531 2.221984 0.004531 2.221984 1 True 5
3 LightGBMXT 0.757468 balanced_accuracy 0.004174 0.706533 0.004174 0.706533 1 True 1
4 NeuralNetFastAI 0.751020 balanced_accuracy 0.009032 1.236252 0.009032 1.236252 1 True 8
5 XGBoost 0.744572 balanced_accuracy 0.006338 0.687824 0.006338 0.687824 1 True 9
6 LightGBM 0.731412 balanced_accuracy 0.003638 0.844523 0.003638 0.844523 1 True 2
7 RandomForestGini 0.731412 balanced_accuracy 0.057063 0.737635 0.057063 0.737635 1 True 3
8 RandomForestEntr 0.728056 balanced_accuracy 0.048098 0.739677 0.048098 0.739677 1 True 4
9 ExtraTreesEntr 0.692196 balanced_accuracy 0.048649 0.744643 0.048649 0.744643 1 True 7
10 LightGBMLarge 0.684959 balanced_accuracy 0.003728 1.249449 0.003728 1.249449 1 True 11
11 ExtraTreesGini 0.682392 balanced_accuracy 0.057365 0.740134 0.057365 0.740134 1 True 6

Now we can test the inference speed of the final model and check if it satisfies the inference constraints.

test_data_batch = test_data.sample(infer_limit_batch_size, replace=True, ignore_index=True)

import time
time_start = time.time()
predictor_infer_limit.predict(test_data_batch)
time_end = time.time()

infer_time_per_row = (time_end - time_start) / len(test_data_batch)
rows_per_second = 1 / infer_time_per_row
infer_time_per_row_ratio = infer_time_per_row / infer_limit
is_constraint_satisfied = infer_time_per_row_ratio <= 1

print(f'Model is able to predict {round(rows_per_second, 1)} rows per second. (User-specified Throughput = {1 / infer_limit})')
print(f'Model uses {round(infer_time_per_row_ratio * 100, 1)}% of infer_limit time per row.')
print(f'Model satisfies inference constraint: {is_constraint_satisfied}')
Model is able to predict 61050.9 rows per second. (User-specified Throughput = 20000.0)
Model uses 32.8% of infer_limit time per row.
Model satisfies inference constraint: True

Using smaller ensembles for prediction

Without having to retrain any models, you can construct alternative ensembles that aggregate individual models’ predictions with different weighting schemes. These ensembles are smaller (and hence faster for prediction) when they assign nonzero weight to fewer models. You can produce a wide variety of ensembles with different accuracy-speed tradeoffs like this:

additional_ensembles = predictor.fit_weighted_ensemble(expand_pareto_frontier=True)
print("Alternative ensembles you can use for prediction:", additional_ensembles)

predictor.leaderboard(only_pareto_frontier=True)
Alternative ensembles you can use for prediction: ['WeightedEnsemble_L2Best_Pareto1', 'WeightedEnsemble_L2Best_Pareto2', 'WeightedEnsemble_L2Best_Pareto3', 'WeightedEnsemble_L2Best_Pareto4', 'WeightedEnsemble_L2Best_Pareto5', 'WeightedEnsemble_L2Best_Pareto6', 'WeightedEnsemble_L2Best_Pareto7', 'WeightedEnsemble_L2Best_Pareto8', 'WeightedEnsemble_L2Best_Pareto9', 'WeightedEnsemble_L2Best_Pareto10', 'WeightedEnsemble_L3Best_Pareto11', 'WeightedEnsemble_L3Best_Pareto12', 'WeightedEnsemble_L3Best_Pareto13', 'WeightedEnsemble_L3Best_Pareto14', 'WeightedEnsemble_L3Best_Pareto15', 'WeightedEnsemble_L3Best_Pareto16', 'WeightedEnsemble_L3Best_Pareto17', 'WeightedEnsemble_L3Best_Pareto18', 'WeightedEnsemble_L3Best_Pareto19', 'WeightedEnsemble_L3Best_Pareto20', 'WeightedEnsemble_L3Best']
Fitting model: WeightedEnsemble_L2Best_Pareto1 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'LightGBM_BAG_L1': 0.75, 'CatBoost_BAG_L1': 0.25}
	0.7798	 = Validation score   (balanced_accuracy)
	0.04s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L2Best_Pareto2 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'LightGBM_BAG_L1': 0.75, 'CatBoost_BAG_L1': 0.25}
	0.7798	 = Validation score   (balanced_accuracy)
	0.05s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L2Best_Pareto3 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'LightGBM_BAG_L1': 0.75, 'CatBoost_BAG_L1': 0.25}
	0.7798	 = Validation score   (balanced_accuracy)
	0.05s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L2Best_Pareto4 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'XGBoost_BAG_L1': 1.0}
	0.7862	 = Validation score   (balanced_accuracy)
	0.07s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L2Best_Pareto5 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'XGBoost_BAG_L1': 1.0}
	0.7862	 = Validation score   (balanced_accuracy)
	0.08s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L2Best_Pareto6 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'XGBoost_BAG_L1': 1.0}
	0.7862	 = Validation score   (balanced_accuracy)
	0.1s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L2Best_Pareto7 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 1.0}
	0.8095	 = Validation score   (balanced_accuracy)
	0.11s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L2Best_Pareto8 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 1.0}
	0.8095	 = Validation score   (balanced_accuracy)
	0.13s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L2Best_Pareto9 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 1.0}
	0.8095	 = Validation score   (balanced_accuracy)
	0.14s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L2Best_Pareto10 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 1.0}
	0.8095	 = Validation score   (balanced_accuracy)
	0.16s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L3Best_Pareto11 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 1.0}
	0.8095	 = Validation score   (balanced_accuracy)
	0.17s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L3Best_Pareto12 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 0.765, 'CatBoost_BAG_L2': 0.118, 'RandomForestEntr_BAG_L1': 0.059, 'LightGBM_BAG_L2': 0.059}
	0.8103	 = Validation score   (balanced_accuracy)
	0.17s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L3Best_Pareto13 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 0.765, 'CatBoost_BAG_L2': 0.118, 'RandomForestEntr_BAG_L1': 0.059, 'LightGBM_BAG_L2': 0.059}
	0.8103	 = Validation score   (balanced_accuracy)
	0.17s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L3Best_Pareto14 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 0.765, 'CatBoost_BAG_L2': 0.118, 'RandomForestEntr_BAG_L1': 0.059, 'LightGBM_BAG_L2': 0.059}
	0.8103	 = Validation score   (balanced_accuracy)
	0.17s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L3Best_Pareto15 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 0.75, 'XGBoost_BAG_L2': 0.125, 'RandomForestEntr_BAG_L1': 0.083, 'ExtraTreesGini_BAG_L1': 0.042}
	0.8123	 = Validation score   (balanced_accuracy)
	0.19s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L3Best_Pareto16 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 0.75, 'XGBoost_BAG_L2': 0.125, 'RandomForestEntr_BAG_L1': 0.083, 'ExtraTreesGini_BAG_L1': 0.042}
	0.8123	 = Validation score   (balanced_accuracy)
	0.18s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L3Best_Pareto17 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 0.75, 'XGBoost_BAG_L2': 0.25}
	0.811	 = Validation score   (balanced_accuracy)
	0.18s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L3Best_Pareto18 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 0.75, 'XGBoost_BAG_L2': 0.25}
	0.811	 = Validation score   (balanced_accuracy)
	0.18s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L3Best_Pareto19 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 0.739, 'XGBoost_BAG_L2': 0.13, 'ExtraTreesEntr_BAG_L2': 0.087, 'NeuralNetFastAI_BAG_L1': 0.043}
	0.8129	 = Validation score   (balanced_accuracy)
	0.18s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L3Best_Pareto20 ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 0.778, 'XGBoost_BAG_L2': 0.111, 'ExtraTreesEntr_BAG_L2': 0.111}
	0.8123	 = Validation score   (balanced_accuracy)
	0.2s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L3Best ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Ensemble Weights: {'NeuralNetTorch_BAG_L2': 0.538, 'XGBoost_BAG_L2': 0.308, 'NeuralNetTorch_BAG_L1': 0.077, 'ExtraTreesEntr_BAG_L2': 0.077}
	0.8151	 = Validation score   (balanced_accuracy)
	0.21s	 = Training   runtime
	0.0s	 = Validation runtime
model score_val eval_metric pred_time_val fit_time pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order
0 WeightedEnsemble_L3 0.815083 balanced_accuracy 1.314348 65.195856 0.000917 0.248854 3 True 24
1 WeightedEnsemble_L3Best_Pareto19 0.812941 balanced_accuracy 1.152461 38.411389 0.000929 0.183355 3 True 43
2 NeuralNetTorch_BAG_L2 0.812563 balanced_accuracy 1.109388 61.752962 0.161899 26.718968 2 True 22
3 WeightedEnsemble_L3Best_Pareto15 0.812332 balanced_accuracy 1.024754 37.786500 0.000945 0.185296 3 True 39
4 WeightedEnsemble_L3Best_Pareto17 0.810988 balanced_accuracy 1.024751 37.783705 0.000942 0.182501 3 True 41
5 NeuralNetTorch_BAG_L1 0.809518 balanced_accuracy 0.120358 14.543094 0.120358 14.543094 1 True 10
6 XGBoost_BAG_L1 0.786227 balanced_accuracy 0.077999 1.575752 0.077999 1.575752 1 True 9
7 LightGBM_BAG_L1 0.779696 balanced_accuracy 0.053448 1.778871 0.053448 1.778871 1 True 2
8 CatBoost_BAG_L1 0.768712 balanced_accuracy 0.042109 7.078955 0.042109 7.078955 1 True 5

The resulting leaderboard contains, for each inference-latency budget, the most accurate ensemble found. You can select whichever model meets your latency requirements from the leaderboard and use it for prediction.

model_for_prediction = additional_ensembles[0]
predictions = predictor.predict(test_data, model=model_for_prediction)
predictor.delete_models(models_to_delete=additional_ensembles)  # delete these extra models so they don't affect the rest of the tutorial
Deleting model WeightedEnsemble_L2Best_Pareto1. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L2Best_Pareto1 will be removed.
Deleting model WeightedEnsemble_L2Best_Pareto2. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L2Best_Pareto2 will be removed.
Deleting model WeightedEnsemble_L2Best_Pareto3. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L2Best_Pareto3 will be removed.
Deleting model WeightedEnsemble_L2Best_Pareto4. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L2Best_Pareto4 will be removed.
Deleting model WeightedEnsemble_L2Best_Pareto5. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L2Best_Pareto5 will be removed.
Deleting model WeightedEnsemble_L2Best_Pareto6. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L2Best_Pareto6 will be removed.
Deleting model WeightedEnsemble_L2Best_Pareto7. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L2Best_Pareto7 will be removed.
Deleting model WeightedEnsemble_L2Best_Pareto8. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L2Best_Pareto8 will be removed.
Deleting model WeightedEnsemble_L2Best_Pareto9. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L2Best_Pareto9 will be removed.
Deleting model WeightedEnsemble_L2Best_Pareto10. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L2Best_Pareto10 will be removed.
Deleting model WeightedEnsemble_L3Best_Pareto11. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L3Best_Pareto11 will be removed.
Deleting model WeightedEnsemble_L3Best_Pareto12. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L3Best_Pareto12 will be removed.
Deleting model WeightedEnsemble_L3Best_Pareto13. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L3Best_Pareto13 will be removed.
Deleting model WeightedEnsemble_L3Best_Pareto14. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L3Best_Pareto14 will be removed.
Deleting model WeightedEnsemble_L3Best_Pareto15. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L3Best_Pareto15 will be removed.
Deleting model WeightedEnsemble_L3Best_Pareto16. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L3Best_Pareto16 will be removed.
Deleting model WeightedEnsemble_L3Best_Pareto17. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L3Best_Pareto17 will be removed.
Deleting model WeightedEnsemble_L3Best_Pareto18. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L3Best_Pareto18 will be removed.
Deleting model WeightedEnsemble_L3Best_Pareto19. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L3Best_Pareto19 will be removed.
Deleting model WeightedEnsemble_L3Best_Pareto20. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L3Best_Pareto20 will be removed.
Deleting model WeightedEnsemble_L3Best. All files under /home/ci/autogluon/docs/tutorials/tabular/agModels-predictClass/models/WeightedEnsemble_L3Best will be removed.

Collapsing bagged ensembles via refit_full

For an ensemble predictor trained with bagging (as done above), recall there are 8 bagged copies of each individual model (one per train/validation fold, since we set num_bag_folds=8). We can collapse this bag of models into a single model fit to the full dataset, which can greatly reduce its memory/latency requirements (but may also reduce accuracy). Below we refit such a model for each original model, but you can alternatively do this for just a particular model by specifying the model argument of refit_full().

refit_model_map = predictor.refit_full()
print("Name of each refit-full model corresponding to a previous bagged ensemble:")
print(refit_model_map)
predictor.leaderboard(test_data)
Name of each refit-full model corresponding to a previous bagged ensemble:
{'LightGBMXT_BAG_L1': 'LightGBMXT_BAG_L1_FULL', 'LightGBM_BAG_L1': 'LightGBM_BAG_L1_FULL', 'RandomForestGini_BAG_L1': 'RandomForestGini_BAG_L1_FULL', 'RandomForestEntr_BAG_L1': 'RandomForestEntr_BAG_L1_FULL', 'CatBoost_BAG_L1': 'CatBoost_BAG_L1_FULL', 'ExtraTreesGini_BAG_L1': 'ExtraTreesGini_BAG_L1_FULL', 'ExtraTreesEntr_BAG_L1': 'ExtraTreesEntr_BAG_L1_FULL', 'NeuralNetFastAI_BAG_L1': 'NeuralNetFastAI_BAG_L1_FULL', 'XGBoost_BAG_L1': 'XGBoost_BAG_L1_FULL', 'NeuralNetTorch_BAG_L1': 'NeuralNetTorch_BAG_L1_FULL', 'LightGBMLarge_BAG_L1': 'LightGBMLarge_BAG_L1_FULL', 'WeightedEnsemble_L2': 'WeightedEnsemble_L2_FULL', 'LightGBMXT_BAG_L2': 'LightGBMXT_BAG_L2_FULL', 'LightGBM_BAG_L2': 'LightGBM_BAG_L2_FULL', 'RandomForestGini_BAG_L2': 'RandomForestGini_BAG_L2_FULL', 'RandomForestEntr_BAG_L2': 'RandomForestEntr_BAG_L2_FULL', 'CatBoost_BAG_L2': 'CatBoost_BAG_L2_FULL', 'ExtraTreesGini_BAG_L2': 'ExtraTreesGini_BAG_L2_FULL', 'ExtraTreesEntr_BAG_L2': 'ExtraTreesEntr_BAG_L2_FULL', 'NeuralNetFastAI_BAG_L2': 'NeuralNetFastAI_BAG_L2_FULL', 'XGBoost_BAG_L2': 'XGBoost_BAG_L2_FULL', 'NeuralNetTorch_BAG_L2': 'NeuralNetTorch_BAG_L2_FULL', 'LightGBMLarge_BAG_L2': 'LightGBMLarge_BAG_L2_FULL', 'WeightedEnsemble_L3': 'WeightedEnsemble_L3_FULL'}
Refitting models via `predictor.refit_full` using all of the data (combined train and validation)...
	Models trained in this way will have the suffix "_FULL" and have NaN validation score.
	This process is not bound by time_limit, but should take less time than the original `predictor.fit` call.
	To learn more, refer to the `.refit_full` method docstring which explains how "_FULL" models differ from normal models.
Fitting 1 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBMXT_BAG_L1_FULL ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	0.36s	 = Training   runtime
Fitting 1 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBM_BAG_L1_FULL ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	0.32s	 = Training   runtime
Fitting model: RandomForestGini_BAG_L1_FULL | Skipping fit via cloning parent ...
	0.71s	 = Training   runtime
	0.12s	 = Validation runtime
Fitting model: RandomForestEntr_BAG_L1_FULL | Skipping fit via cloning parent ...
	0.64s	 = Training   runtime
	0.12s	 = Validation runtime
Fitting 1 L1 models, fit_strategy="sequential" ...
Fitting model: CatBoost_BAG_L1_FULL ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0
	0.21s	 = Training   runtime
Fitting model: ExtraTreesGini_BAG_L1_FULL | Skipping fit via cloning parent ...
	0.65s	 = Training   runtime
	0.12s	 = Validation runtime
Fitting model: ExtraTreesEntr_BAG_L1_FULL | Skipping fit via cloning parent ...
	0.64s	 = Training   runtime
	0.12s	 = Validation runtime
Fitting 1 L1 models, fit_strategy="sequential" ...
Fitting model: NeuralNetFastAI_BAG_L1_FULL ...
	Fitting 1 model on all data | Fitting with cpus=4, gpus=0, mem=0.0/27.1 GB
Metric balanced_accuracy is not supported by this model - using log_loss instead
	Stopping at the best epoch learned earlier - 9.
	0.43s	 = Training   runtime
Fitting 1 L1 models, fit_strategy="sequential" ...
Fitting model: XGBoost_BAG_L1_FULL ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0
	0.06s	 = Training   runtime
Fitting 1 L1 models, fit_strategy="sequential" ...
Fitting model: NeuralNetTorch_BAG_L1_FULL ...
	Fitting 1 model on all data | Fitting with cpus=4, gpus=0, mem=0.0/27.1 GB
/home/ci/opt/venv/lib/python3.12/site-packages/sklearn/compose/_column_transformer.py:975: FutureWarning: The parameter `force_int_remainder_cols` is deprecated and will be removed in 1.9. It has no effect. Leave it to its default value to avoid this warning.
  warnings.warn(
	1.56s	 = Training   runtime
Fitting 1 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBMLarge_BAG_L1_FULL ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	0.62s	 = Training   runtime
Fitting model: WeightedEnsemble_L2_FULL | Skipping fit via cloning parent ...
	Ensemble Weights: {'NeuralNetTorch_BAG_L1': 1.0}
	0.24s	 = Training   runtime
Fitting 1 L2 models, fit_strategy="sequential" ...
Fitting model: LightGBMXT_BAG_L2_FULL ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	0.29s	 = Training   runtime
Fitting 1 L2 models, fit_strategy="sequential" ...
Fitting model: LightGBM_BAG_L2_FULL ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	0.34s	 = Training   runtime
Fitting model: RandomForestGini_BAG_L2_FULL | Skipping fit via cloning parent ...
	0.66s	 = Training   runtime
	0.12s	 = Validation runtime
Fitting model: RandomForestEntr_BAG_L2_FULL | Skipping fit via cloning parent ...
	0.65s	 = Training   runtime
	0.12s	 = Validation runtime
Fitting 1 L2 models, fit_strategy="sequential" ...
Fitting model: CatBoost_BAG_L2_FULL ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0
	0.25s	 = Training   runtime
Fitting model: ExtraTreesGini_BAG_L2_FULL | Skipping fit via cloning parent ...
	0.66s	 = Training   runtime
	0.12s	 = Validation runtime
Fitting model: ExtraTreesEntr_BAG_L2_FULL | Skipping fit via cloning parent ...
	0.63s	 = Training   runtime
	0.13s	 = Validation runtime
Fitting 1 L2 models, fit_strategy="sequential" ...
Fitting model: NeuralNetFastAI_BAG_L2_FULL ...
	Fitting 1 model on all data | Fitting with cpus=4, gpus=0, mem=0.0/27.1 GB
Metric balanced_accuracy is not supported by this model - using log_loss instead
	Stopping at the best epoch learned earlier - 9.
	0.45s	 = Training   runtime
Fitting 1 L2 models, fit_strategy="sequential" ...
Fitting model: XGBoost_BAG_L2_FULL ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0
	0.06s	 = Training   runtime
Fitting 1 L2 models, fit_strategy="sequential" ...
Fitting model: NeuralNetTorch_BAG_L2_FULL ...
	Fitting 1 model on all data | Fitting with cpus=4, gpus=0, mem=0.0/27.1 GB
/home/ci/opt/venv/lib/python3.12/site-packages/sklearn/compose/_column_transformer.py:975: FutureWarning: The parameter `force_int_remainder_cols` is deprecated and will be removed in 1.9. It has no effect. Leave it to its default value to avoid this warning.
  warnings.warn(
	4.03s	 = Training   runtime
Fitting 1 L2 models, fit_strategy="sequential" ...
Fitting model: LightGBMLarge_BAG_L2_FULL ...
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	0.81s	 = Training   runtime
Fitting model: WeightedEnsemble_L3_FULL | Skipping fit via cloning parent ...
	Ensemble Weights: {'NeuralNetTorch_BAG_L2': 0.538, 'XGBoost_BAG_L2': 0.308, 'NeuralNetTorch_BAG_L1': 0.077, 'ExtraTreesEntr_BAG_L2': 0.077}
	0.25s	 = Training   runtime
Updated best model to "WeightedEnsemble_L3_FULL" (Previously "WeightedEnsemble_L3"). AutoGluon will default to using "WeightedEnsemble_L3_FULL" for predict() and predict_proba().
Refit complete, total runtime = 10.71s ... Best model: "WeightedEnsemble_L3_FULL"
model score_test score_val eval_metric pred_time_test pred_time_val fit_time pred_time_test_marginal pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order
0 NeuralNetTorch_BAG_L2 0.786935 0.812563 balanced_accuracy 3.563705 1.109388 61.752962 0.597828 0.161899 26.718968 2 True 22
1 WeightedEnsemble_L3 0.775800 0.815083 balanced_accuracy 3.928262 1.314348 65.195856 0.003051 0.000917 0.248854 3 True 24
2 NeuralNetTorch_BAG_L1 0.767081 0.809518 balanced_accuracy 0.420649 0.120358 14.543094 0.420649 0.120358 14.543094 1 True 10
3 WeightedEnsemble_L2 0.767081 0.809518 balanced_accuracy 0.422477 0.121753 14.781031 0.001828 0.001396 0.237937 2 True 12
4 LightGBMXT_BAG_L1_FULL 0.765134 NaN balanced_accuracy 0.049157 NaN 0.358391 0.049157 NaN 0.358391 1 True 25
5 CatBoost_BAG_L2 0.759693 0.793914 balanced_accuracy 3.011664 0.996059 41.823847 0.045787 0.048570 6.789853 2 True 17
6 LightGBMXT_BAG_L1 0.758639 0.778835 balanced_accuracy 0.387074 0.060408 1.634013 0.387074 0.060408 1.634013 1 True 1
7 XGBoost_BAG_L1_FULL 0.756592 NaN balanced_accuracy 0.028754 NaN 0.060360 0.028754 NaN 0.060360 1 True 33
8 CatBoost_BAG_L1 0.755940 0.768712 balanced_accuracy 0.054019 0.042109 7.078955 0.054019 0.042109 7.078955 1 True 5
9 CatBoost_BAG_L1_FULL 0.755101 NaN balanced_accuracy 0.008105 NaN 0.213901 0.008105 NaN 0.213901 1 True 29
10 XGBoost_BAG_L2 0.754565 0.797148 balanced_accuracy 3.224366 1.023809 37.601204 0.258489 0.076320 2.567210 2 True 21
11 ExtraTreesEntr_BAG_L2_FULL 0.753615 NaN balanced_accuracy 0.840917 NaN 6.194455 0.103344 0.127723 0.626830 2 True 43
12 LightGBMXT_BAG_L2_FULL 0.752470 NaN balanced_accuracy 0.763807 NaN 5.862479 0.026234 NaN 0.294854 2 True 37
13 LightGBM_BAG_L2 0.751866 0.801663 balanced_accuracy 3.193531 1.002546 37.509482 0.227654 0.055056 2.475489 2 True 14
14 LightGBM_BAG_L1_FULL 0.750874 NaN balanced_accuracy 0.034246 NaN 0.315531 0.034246 NaN 0.315531 1 True 26
15 ExtraTreesGini_BAG_L2_FULL 0.749594 NaN balanced_accuracy 0.840671 NaN 6.225716 0.103097 0.124954 0.658091 2 True 42
16 CatBoost_BAG_L2_FULL 0.749363 NaN balanced_accuracy 0.746714 NaN 5.815238 0.009141 NaN 0.247613 2 True 41
17 ExtraTreesEntr_BAG_L2 0.749119 0.758758 balanced_accuracy 3.068895 1.075213 35.660824 0.103018 0.127723 0.626830 2 True 19
18 LightGBMXT_BAG_L2 0.748702 0.786227 balanced_accuracy 3.169814 1.005201 36.667126 0.203937 0.057711 1.633132 2 True 13
19 ExtraTreesGini_BAG_L2 0.748127 0.748257 balanced_accuracy 3.069562 1.072443 35.692085 0.103685 0.124954 0.658091 2 True 18
20 XGBoost_BAG_L1 0.747715 0.786227 balanced_accuracy 0.266386 0.077999 1.575752 0.266386 0.077999 1.575752 1 True 9
21 XGBoost_BAG_L2_FULL 0.747710 NaN balanced_accuracy 0.767184 NaN 5.631575 0.029611 NaN 0.063951 2 True 45
22 RandomForestGini_BAG_L2_FULL 0.746674 NaN balanced_accuracy 0.838880 NaN 6.231582 0.101306 0.121586 0.663958 2 True 39
23 LightGBM_BAG_L2_FULL 0.746411 NaN balanced_accuracy 0.766051 NaN 5.908091 0.028478 NaN 0.340466 2 True 38
24 RandomForestEntr_BAG_L2 0.746406 0.776336 balanced_accuracy 3.067230 1.068860 35.682957 0.101353 0.121370 0.648963 2 True 16
25 RandomForestGini_BAG_L2 0.746051 0.765961 balanced_accuracy 3.067363 1.069075 35.697951 0.101486 0.121586 0.663958 2 True 15
26 LightGBMLarge_BAG_L2 0.745884 0.783413 balanced_accuracy 3.369588 1.017043 39.151877 0.403710 0.069554 4.117883 2 True 23
27 RandomForestEntr_BAG_L2_FULL 0.745864 NaN balanced_accuracy 0.839602 NaN 6.216588 0.102029 0.121370 0.648963 2 True 40
28 RandomForestGini_BAG_L1_FULL 0.745745 NaN balanced_accuracy 0.102112 0.121088 0.708103 0.102112 0.121088 0.708103 1 True 27
29 RandomForestGini_BAG_L1 0.745745 0.748194 balanced_accuracy 0.103528 0.121088 0.708103 0.103528 0.121088 0.708103 1 True 3
30 WeightedEnsemble_L3_FULL 0.745726 NaN balanced_accuracy 0.951661 NaN 10.540162 0.003060 NaN 0.248854 3 True 48
31 LightGBM_BAG_L1 0.745232 0.779696 balanced_accuracy 0.278008 0.053448 1.778871 0.278008 0.053448 1.778871 1 True 2
32 NeuralNetTorch_BAG_L2_FULL 0.744657 NaN balanced_accuracy 0.815646 NaN 9.600528 0.078073 NaN 4.032903 2 True 46
33 LightGBMLarge_BAG_L2_FULL 0.743592 NaN balanced_accuracy 0.782913 NaN 6.382608 0.045340 NaN 0.814983 2 True 47
34 RandomForestEntr_BAG_L1_FULL 0.743132 NaN balanced_accuracy 0.101153 0.121509 0.635930 0.101153 0.121509 0.635930 1 True 28
35 RandomForestEntr_BAG_L1 0.743132 0.756678 balanced_accuracy 0.101367 0.121509 0.635930 0.101367 0.121509 0.635930 1 True 4
36 NeuralNetFastAI_BAG_L2 0.741699 0.748866 balanced_accuracy 4.131179 1.076216 40.729774 1.165302 0.128727 5.695780 2 True 20
37 LightGBMLarge_BAG_L1_FULL 0.737927 NaN balanced_accuracy 0.045058 NaN 0.617012 0.045058 NaN 0.617012 1 True 35
38 LightGBMLarge_BAG_L1 0.735861 0.747333 balanced_accuracy 0.377338 0.060283 2.592734 0.377338 0.060283 2.592734 1 True 11
39 ExtraTreesGini_BAG_L1_FULL 0.731460 NaN balanced_accuracy 0.104100 0.120046 0.646260 0.104100 0.120046 0.646260 1 True 30
40 ExtraTreesGini_BAG_L1 0.731460 0.751491 balanced_accuracy 0.104226 0.120046 0.646260 0.104226 0.120046 0.646260 1 True 6
41 NeuralNetFastAI_BAG_L2_FULL 0.731153 NaN balanced_accuracy 0.882268 NaN 6.022053 0.144695 NaN 0.454428 2 True 44
42 NeuralNetFastAI_BAG_L1 0.728689 0.725071 balanced_accuracy 1.146785 0.108622 5.791547 1.146785 0.108622 5.791547 1 True 8
43 ExtraTreesEntr_BAG_L1 0.725517 0.730553 balanced_accuracy 0.103835 0.121904 0.641470 0.103835 0.121904 0.641470 1 True 7
44 ExtraTreesEntr_BAG_L1_FULL 0.725517 NaN balanced_accuracy 0.104333 0.121904 0.641470 0.104333 0.121904 0.641470 1 True 31
45 NeuralNetFastAI_BAG_L1_FULL 0.709449 NaN balanced_accuracy 0.151972 NaN 0.428511 0.151972 NaN 0.428511 1 True 32
46 NeuralNetTorch_BAG_L1_FULL 0.705997 NaN balanced_accuracy 0.053641 NaN 1.559169 0.053641 NaN 1.559169 1 True 34
47 WeightedEnsemble_L2_FULL 0.705997 NaN balanced_accuracy 0.055459 NaN 1.797106 0.001818 NaN 0.237937 2 True 36

This adds the refit-full models to the leaderboard, and we can opt to use any of them for prediction just like any other model. Note that pred_time_test and pred_time_val list the time (in seconds) each model takes to produce predictions on the test/validation data. Since the refit-full models were trained on all of the data, there is no internal validation score (score_val) available for them. You can also call refit_full() with non-bagged models to refit the same models to your full dataset (there won’t be memory/latency gains in this case, but test accuracy may improve).
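
To predict with a specific refit-full model, pass its name via the model argument; a minimal sketch using the refit_model_map returned above:

model_refit = refit_model_map['NeuralNetTorch_BAG_L1']  # maps to 'NeuralNetTorch_BAG_L1_FULL'
predictions_refit = predictor.predict(test_data, model=model_refit)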

Model distillation

While computationally favorable, individual models usually have lower accuracy than weighted/stacked/bagged ensembles. Model distillation offers one way to retain the computational benefits of a single model while enjoying some of the accuracy boost that comes with ensembling. The idea is to train the individual model (the student) to mimic the predictions of the full stack ensemble (the teacher). Like refit_full(), the distill() function produces additional models we can opt to use for prediction.

student_models = predictor.distill(time_limit=120)
predictor.leaderboard(test_data)
Distilling with teacher='WeightedEnsemble_L3_FULL', teacher_preds=soft, augment_method=spunge ...
SPUNGE: Augmenting training data with 4000 synthetic samples for distillation...
Distilling with each of these student models: ['LightGBM_DSTL', 'CatBoost_DSTL', 'RandomForestMSE_DSTL', 'NeuralNetTorch_DSTL']
Fitting 4 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBM_DSTL ... Training model for up to 120.00s of the 120.00s of remaining time.
	Fitting with cpus=4, gpus=0, mem=0.0/27.1 GB
	Note: model has different eval_metric than default.
	-0.1206	 = Validation score   (-mean_squared_error)
	0.55s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: CatBoost_DSTL ... Training model for up to 119.44s of the 119.43s of remaining time.
	Fitting with cpus=4, gpus=0
	Note: model has different eval_metric than default.
	-0.1111	 = Validation score   (-mean_squared_error)
	2.96s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: RandomForestMSE_DSTL ... Training model for up to 116.46s of the 116.46s of remaining time.
	Fitting with cpus=8, gpus=0, mem=0.0/27.1 GB
	Note: model has different eval_metric than default.
	-0.1347	 = Validation score   (-mean_squared_error)
	1.38s	 = Training   runtime
	0.06s	 = Validation runtime
Fitting model: NeuralNetTorch_DSTL ... Training model for up to 114.92s of the 114.92s of remaining time.
	Fitting with cpus=4, gpus=0, mem=0.0/26.9 GB
/home/ci/opt/venv/lib/python3.12/site-packages/sklearn/compose/_column_transformer.py:975: FutureWarning: The parameter `force_int_remainder_cols` is deprecated and will be removed in 1.9. It has no effect. Leave it to its default value to avoid this warning.
  warnings.warn(
	Note: model has different eval_metric than default.
	-0.1354	 = Validation score   (-mean_squared_error)
	7.77s	 = Training   runtime
	0.01s	 = Validation runtime
Distilling with each of these student models: ['WeightedEnsemble_L2_DSTL']
Fitting model: WeightedEnsemble_L2_DSTL ... Training model for up to 120.00s of the 107.13s of remaining time.
	Fitting 1 model on all data | Fitting with cpus=8, gpus=0, mem=0.0/27.0 GB
	Ensemble Weights: {'CatBoost_DSTL': 0.75, 'NeuralNetTorch_DSTL': 0.25}
	Note: model has different eval_metric than default.
	-0.1082	 = Validation score   (-mean_squared_error)
	0.04s	 = Training   runtime
	0.0s	 = Validation runtime
Distilled model leaderboard:
                      model  score_val         eval_metric  pred_time_val   fit_time  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0             CatBoost_DSTL   0.728320  mean_squared_error       0.004893   2.960839                0.004893           2.960839            1       True         50
1  WeightedEnsemble_L2_DSTL   0.715423  mean_squared_error       0.016390  10.764567                0.000546           0.036450            2       True         53
2       NeuralNetTorch_DSTL   0.708975  mean_squared_error       0.010951   7.767278                0.010951           7.767278            1       True         52
3      RandomForestMSE_DSTL   0.708185  mean_squared_error       0.058316   1.381613                0.058316           1.381613            1       True         51
4             LightGBM_DSTL   0.659692  mean_squared_error       0.003706   0.553645                0.003706           0.553645            1       True         49
model score_test score_val eval_metric pred_time_test pred_time_val fit_time pred_time_test_marginal pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order
0 NeuralNetTorch_BAG_L2 0.786935 0.812563 balanced_accuracy 3.489274 1.109388 61.752962 0.592900 0.161899 26.718968 2 True 22
1 WeightedEnsemble_L3 0.775800 0.815083 balanced_accuracy 3.842423 1.314348 65.195856 0.002634 0.000917 0.248854 3 True 24
2 NeuralNetTorch_BAG_L1 0.767081 0.809518 balanced_accuracy 0.414217 0.120358 14.543094 0.414217 0.120358 14.543094 1 True 10
3 WeightedEnsemble_L2 0.767081 0.809518 balanced_accuracy 0.416103 0.121753 14.781031 0.001885 0.001396 0.237937 2 True 12
4 LightGBMXT_BAG_L1_FULL 0.765134 NaN balanced_accuracy 0.047482 NaN 0.358391 0.047482 NaN 0.358391 1 True 25
5 CatBoost_BAG_L2 0.759693 0.793914 balanced_accuracy 2.942487 0.996059 41.823847 0.046113 0.048570 6.789853 2 True 17
6 LightGBMXT_BAG_L1 0.758639 0.778835 balanced_accuracy 0.376565 0.060408 1.634013 0.376565 0.060408 1.634013 1 True 1
7 XGBoost_BAG_L1_FULL 0.756592 NaN balanced_accuracy 0.028663 NaN 0.060360 0.028663 NaN 0.060360 1 True 33
8 CatBoost_DSTL 0.756372 0.728320 mean_squared_error 0.012627 0.004893 2.960839 0.012627 0.004893 2.960839 1 True 50
9 CatBoost_BAG_L1 0.755940 0.768712 balanced_accuracy 0.052611 0.042109 7.078955 0.052611 0.042109 7.078955 1 True 5
10 CatBoost_BAG_L1_FULL 0.755101 NaN balanced_accuracy 0.007994 NaN 0.213901 0.007994 NaN 0.213901 1 True 29
11 XGBoost_BAG_L2 0.754565 0.797148 balanced_accuracy 3.143944 1.023809 37.601204 0.247570 0.076320 2.567210 2 True 21
12 ExtraTreesEntr_BAG_L2_FULL 0.753615 NaN balanced_accuracy 0.816323 NaN 6.194455 0.102739 0.127723 0.626830 2 True 43
13 LightGBMXT_BAG_L2_FULL 0.752470 NaN balanced_accuracy 0.739030 NaN 5.862479 0.025446 NaN 0.294854 2 True 37
14 LightGBM_BAG_L2 0.751866 0.801663 balanced_accuracy 3.121256 1.002546 37.509482 0.224882 0.055056 2.475489 2 True 14
15 LightGBM_BAG_L1_FULL 0.750874 NaN balanced_accuracy 0.033995 NaN 0.315531 0.033995 NaN 0.315531 1 True 26
16 ExtraTreesGini_BAG_L2_FULL 0.749594 NaN balanced_accuracy 0.817636 NaN 6.225716 0.104052 0.124954 0.658091 2 True 42
17 CatBoost_BAG_L2_FULL 0.749363 NaN balanced_accuracy 0.722560 NaN 5.815238 0.008976 NaN 0.247613 2 True 41
18 ExtraTreesEntr_BAG_L2 0.749119 0.758758 balanced_accuracy 2.999320 1.075213 35.660824 0.102946 0.127723 0.626830 2 True 19
19 LightGBMXT_BAG_L2 0.748702 0.786227 balanced_accuracy 3.098956 1.005201 36.667126 0.202582 0.057711 1.633132 2 True 13
20 ExtraTreesGini_BAG_L2 0.748127 0.748257 balanced_accuracy 3.000082 1.072443 35.692085 0.103708 0.124954 0.658091 2 True 18
21 XGBoost_BAG_L1 0.747715 0.786227 balanced_accuracy 0.272109 0.077999 1.575752 0.272109 0.077999 1.575752 1 True 9
22 XGBoost_BAG_L2_FULL 0.747710 NaN balanced_accuracy 0.741525 NaN 5.631575 0.027941 NaN 0.063951 2 True 45
23 RandomForestGini_BAG_L2_FULL 0.746674 NaN balanced_accuracy 0.815067 NaN 6.231582 0.101483 0.121586 0.663958 2 True 39
24 LightGBM_BAG_L2_FULL 0.746411 NaN balanced_accuracy 0.740890 NaN 5.908091 0.027306 NaN 0.340466 2 True 38
25 RandomForestEntr_BAG_L2 0.746406 0.776336 balanced_accuracy 2.997384 1.068860 35.682957 0.101010 0.121370 0.648963 2 True 16
26 RandomForestGini_BAG_L2 0.746051 0.765961 balanced_accuracy 2.997029 1.069075 35.697951 0.100655 0.121586 0.663958 2 True 15
27 LightGBMLarge_BAG_L2 0.745884 0.783413 balanced_accuracy 3.298753 1.017043 39.151877 0.402379 0.069554 4.117883 2 True 23
28 RandomForestEntr_BAG_L2_FULL 0.745864 NaN balanced_accuracy 0.814719 NaN 6.216588 0.101135 0.121370 0.648963 2 True 40
29 RandomForestGini_BAG_L1_FULL 0.745745 NaN balanced_accuracy 0.101097 0.121088 0.708103 0.101097 0.121088 0.708103 1 True 27
30 RandomForestGini_BAG_L1 0.745745 0.748194 balanced_accuracy 0.103122 0.121088 0.708103 0.103122 0.121088 0.708103 1 True 3
31 WeightedEnsemble_L3_FULL 0.745726 NaN balanced_accuracy 0.922953 NaN 10.540162 0.002804 NaN 0.248854 3 True 48
32 LightGBM_BAG_L1 0.745232 0.779696 balanced_accuracy 0.282235 0.053448 1.778871 0.282235 0.053448 1.778871 1 True 2
33 NeuralNetTorch_BAG_L2_FULL 0.744657 NaN balanced_accuracy 0.789469 NaN 9.600528 0.075886 NaN 4.032903 2 True 46
34 LightGBMLarge_BAG_L2_FULL 0.743592 NaN balanced_accuracy 0.757672 NaN 6.382608 0.044088 NaN 0.814983 2 True 47
35 RandomForestEntr_BAG_L1_FULL 0.743132 NaN balanced_accuracy 0.100496 0.121509 0.635930 0.100496 0.121509 0.635930 1 True 28
36 RandomForestEntr_BAG_L1 0.743132 0.756678 balanced_accuracy 0.101837 0.121509 0.635930 0.101837 0.121509 0.635930 1 True 4
37 NeuralNetFastAI_BAG_L2 0.741699 0.748866 balanced_accuracy 3.976853 1.076216 40.729774 1.080479 0.128727 5.695780 2 True 20
38 LightGBMLarge_BAG_L1_FULL 0.737927 NaN balanced_accuracy 0.045490 NaN 0.617012 0.045490 NaN 0.617012 1 True 35
39 LightGBMLarge_BAG_L1 0.735861 0.747333 balanced_accuracy 0.376378 0.060283 2.592734 0.376378 0.060283 2.592734 1 True 11
40 WeightedEnsemble_L2_DSTL 0.731613 0.715423 mean_squared_error 0.078718 0.016390 10.764567 0.010504 0.000546 0.036450 2 True 53
41 ExtraTreesGini_BAG_L1_FULL 0.731460 NaN balanced_accuracy 0.103751 0.120046 0.646260 0.103751 0.120046 0.646260 1 True 30
42 ExtraTreesGini_BAG_L1 0.731460 0.751491 balanced_accuracy 0.104640 0.120046 0.646260 0.104640 0.120046 0.646260 1 True 6
43 NeuralNetFastAI_BAG_L2_FULL 0.731153 NaN balanced_accuracy 0.846759 NaN 6.022053 0.133175 NaN 0.454428 2 True 44
44 NeuralNetFastAI_BAG_L1 0.728689 0.725071 balanced_accuracy 1.084603 0.108622 5.791547 1.084603 0.108622 5.791547 1 True 8
45 ExtraTreesEntr_BAG_L1_FULL 0.725517 NaN balanced_accuracy 0.102796 0.121904 0.641470 0.102796 0.121904 0.641470 1 True 31
46 ExtraTreesEntr_BAG_L1 0.725517 0.730553 balanced_accuracy 0.104435 0.121904 0.641470 0.104435 0.121904 0.641470 1 True 7
47 LightGBM_DSTL 0.716672 0.659692 mean_squared_error 0.015232 0.003706 0.553645 0.015232 0.003706 0.553645 1 True 49
48 RandomForestMSE_DSTL 0.715786 0.708185 mean_squared_error 0.224787 0.058316 1.381613 0.224787 0.058316 1.381613 1 True 51
49 NeuralNetFastAI_BAG_L1_FULL 0.709449 NaN balanced_accuracy 0.136764 NaN 0.428511 0.136764 NaN 0.428511 1 True 32
50 NeuralNetTorch_BAG_L1_FULL 0.705997 NaN balanced_accuracy 0.050546 NaN 1.559169 0.050546 NaN 1.559169 1 True 34
51 WeightedEnsemble_L2_FULL 0.705997 NaN balanced_accuracy 0.052323 NaN 1.797106 0.001777 NaN 0.237937 2 True 36
52 NeuralNetTorch_DSTL 0.670008 0.708975 mean_squared_error 0.055588 0.010951 7.767278 0.055588 0.010951 7.767278 1 True 52
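
Any distilled student model can be used for prediction just like the other models. A minimal sketch using the student_models list returned by distill() above:

# Predict with the first distilled student model
preds_student = predictor.predict(test_data, model=student_models[0])
print(f"Predictions from {student_models[0]}:")
print(preds_student.head())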

While distillation can produce efficient models, we recommend first trying refit_full(), infer_limit, and persist() to satisfy your requirements, since distillation may come with a significant drop in accuracy.

Faster presets or hyperparameters

If you know at the outset that inference latency or memory will be an issue, you can adjust the training process accordingly to ensure fit() does not produce unwieldy models, rather than trying to speed up a cumbersome model after it has been trained.

One option is to specify more lightweight presets:

presets = ['good_quality', 'optimize_for_deployment']
predictor_light = TabularPredictor(label=label, eval_metric=metric).fit(train_data, presets=presets, time_limit=30)

You may also exclude specific unwieldy models from being trained at all. Below we exclude models that tend to be slower (K Nearest Neighbors, Neural Networks):

excluded_model_types = ['KNN', 'NN_TORCH']
predictor_light = TabularPredictor(label=label, eval_metric=metric).fit(train_data, excluded_model_types=excluded_model_types, time_limit=30)

(Advanced) Cache preprocessed data

If you are repeatedly predicting on the same data, you can cache the preprocessed version of the data and pass the preprocessed data directly to predictor.predict() for faster inference:

test_data_preprocessed = predictor.transform_features(test_data)

# The following call will be faster than a normal predict call because we are skipping the preprocessing stage.
predictions = predictor.predict(test_data_preprocessed, transform_features=False)

Note that this is only useful in situations where you are repeatedly predicting on the same data. If this significantly speeds up your use-case, consider whether your current approach makes sense, or whether caching the predictions themselves is the better solution, as sketched below.
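
If caching predictions fits your use-case, a minimal, hypothetical sketch could look like the following; cached_predict and _pred_cache are illustrative names, not AutoGluon APIs:

import hashlib
import pandas as pd

_pred_cache = {}  # illustrative in-memory cache, keyed by a hash of the input rows

def cached_predict(predictor, df: pd.DataFrame):
    # Hash the row contents (ignoring the index) to build a cache key
    key = hashlib.sha256(pd.util.hash_pandas_object(df, index=False).values.tobytes()).hexdigest()
    if key not in _pred_cache:
        _pred_cache[key] = predictor.predict(df)
    return _pred_cache[key]

predictions = cached_predict(predictor, test_data)  # a second call with the same data skips prediction entirely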

(Advanced) Disable preprocessing

If you would rather do data preprocessing outside of TabularPredictor, you can disable TabularPredictor’s preprocessing entirely via:

predictor.fit(..., feature_generator=None, feature_metadata=YOUR_CUSTOM_FEATURE_METADATA)

Be warned that this removes ALL guardrails on data sanitization. It is very likely that you will run into errors doing this unless you are very familiar with AutoGluon.

One instance where this can be helpful is when many tasks reuse exactly the same data and features. If you had 30 tasks that reuse the same features, you could fit an autogluon.features feature generator once on the data, preprocess the data only once, and then send the preprocessed data to all 30 predictors. A sketch of this pattern follows.
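
A minimal sketch of this pattern, assuming the same train_data and label as earlier in this tutorial, and assuming shared_generator.feature_metadata describes the generator's output features after fitting:

from autogluon.features.generators import AutoMLPipelineFeatureGenerator

# Fit the shared feature generator once on the common features
shared_generator = AutoMLPipelineFeatureGenerator()
X_transformed = shared_generator.fit_transform(train_data.drop(columns=[label]))

# Each task then trains on the already-preprocessed features, skipping internal preprocessing
train_task = X_transformed.join(train_data[label])
predictor_task = TabularPredictor(label=label).fit(
    train_task,
    feature_generator=None,
    feature_metadata=shared_generator.feature_metadata,
)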

If you encounter memory issues

To reduce memory usage during training, you may try each of the following strategies individually or in combination (note that they may harm accuracy):

  • In fit(), set excluded_model_types = ['KNN', 'XT' ,'RF'] (or some subset of these models).

  • Try different presets in fit().

  • Text fields in your table require substantial memory for N-gram featurization. To mitigate this in fit(), you can either: (1) add 'ignore_text' to your presets list (to ignore text features), or (2) specify the argument:

import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from autogluon.features.generators import AutoMLPipelineFeatureGenerator

MAX_NGRAM = 1000  # cap on the number of N-gram features per text field
feature_generator = AutoMLPipelineFeatureGenerator(vectorizer=CountVectorizer(min_df=30, ngram_range=(1, 3), max_features=MAX_NGRAM, dtype=np.uint8))
predictor = TabularPredictor(...).fit(..., feature_generator=feature_generator)

Here MAX_NGRAM = 1000 caps the number of N-gram features used to represent each text field; try various values under 10000 to trade memory for accuracy.

In addition to reducing memory usage, many of the above strategies can also be used to reduce training times.

To reduce memory usage during inference:

  • If trying to produce predictions for a large test dataset, break the test data into smaller chunks as demonstrated in FAQ.

  • If models have been previously persisted in memory but inference-speed is not a major concern, call predictor.unpersist().

  • If models have been previously persisted in memory, bagging was used in fit(), and inference-speed is a concern: call predictor.refit_full() and use one of the refit-full models for prediction, ensuring it is the only model persisted in memory (see the sketch after this list).
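
A minimal sketch combining these ideas, assuming the refit_model_map from the refit_full() call earlier, and assuming predictor.persist() (the counterpart of predictor.unpersist() above) accepts a list of model names:

import numpy as np
import pandas as pd

# Keep only a single refit-full model (and its ancestors) in memory
predictor.unpersist()  # drop all persisted models from memory
model_refit = refit_model_map['WeightedEnsemble_L3']  # 'WeightedEnsemble_L3_FULL'
predictor.persist([model_refit])

# Predict on a large test set in smaller chunks to bound peak memory
predictions = pd.concat([predictor.predict(chunk, model=model_refit)
                         for chunk in np.array_split(test_data, 10)])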

If you encounter disk space issues

To reduce disk usage, you may try each of the following strategies individually or combinations of them:

  • Make sure to delete all predictor.path folders from previous fit() runs! These can eat up your free space if you call fit() many times. If you didn’t specify path, AutoGluon automatically saves its models to a folder named “AutogluonModels/ag-[TIMESTAMP]”, where TIMESTAMP records when fit() was called; make sure to also delete these folders if you run low on free space.

  • Call predictor.save_space() to delete auxiliary files produced during fit().

  • Call predictor.delete_models(models_to_keep='best') if you only intend to use this predictor for inference going forward (this deletes files required for non-prediction functionality such as fit_summary); see the sketch after this list.

  • In fit(), you can add 'optimize_for_deployment' to the presets list, which will automatically invoke the previous two strategies after training.

  • Most of the above strategies to reduce memory usage will also reduce disk usage (but may harm accuracy).
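
A minimal sketch of the two cleanup calls mentioned above; dry_run=False is passed explicitly in case your AutoGluon version defaults delete_models() to a dry run:

# Shrink the predictor's disk footprint after training
predictor.save_space()  # delete auxiliary files produced during fit()
predictor.delete_models(models_to_keep='best', dry_run=False)  # keep only the best model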

References

The following paper describes how AutoGluon internally operates on tabular data:

Erickson et al. AutoGluon-Tabular: Robust and Accurate AutoML for Structured Data. arXiv, 2020.

Next Steps

If you are interested in deployment optimization, refer to the Predicting Columns in a Table - Deployment Optimization tutorial.