Predicting Multiple Columns in a Table (Multi-Label Prediction)

In multi-label prediction, we wish to predict multiple columns of a table (i.e. labels) based on the values in the remaining columns. Here we present a simple strategy for doing this with AutoGluon: maintain a separate TabularPredictor object for each column being predicted. Correlations between labels can be accounted for by imposing an order on the labels and letting the TabularPredictor for each label condition on the predicted values of labels that appear earlier in the order.
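To make the ordering idea concrete, here is a minimal, library-free sketch of the chained (auto-regressive) prediction scheme. It assumes that each predictor is just a Python function mapping a feature dict to a value. `predict_chain` and the toy rules in the usage example are hypothetical stand-ins for trained TabularPredictor objects. They are not AutoGluon API.

```python
from typing import Callable, Dict, List

Row = Dict[str, object]
Predictor = Callable[[Row], object]  # toy stand-in for a trained TabularPredictor


def predict_chain(
    row: Row,
    labels: List[str],
    predictors: Dict[str, Predictor],
    consider_labels_correlation: bool = True,
) -> Row:
    """Predict each label in order, optionally feeding earlier predictions forward."""
    features = dict(row)  # copy so the caller's row is left unchanged
    predictions: Row = {}
    for label in labels:
        value = predictors[label](features)
        predictions[label] = value
        if consider_labels_correlation:
            # Later predictors see this predicted value as an ordinary input feature.
            features[label] = value
    return predictions


# Hypothetical toy rules, used only to show how information flows along the order.
toy_predictors: Dict[str, Predictor] = {
    "education-num": lambda r: 13 if r["occupation"] == "Exec-managerial" else 9,
    "education": lambda r: "Bachelors" if r.get("education-num", 0) >= 13 else "HS-grad",
    "class": lambda r: ">50K" if r.get("education") == "Bachelors" else "<=50K",
}
labels = ["education-num", "education", "class"]
row = {"occupation": "Exec-managerial"}

chained = predict_chain(row, labels, toy_predictors)
independent = predict_chain(row, labels, toy_predictors, consider_labels_correlation=False)
print(chained)      # {'education-num': 13, 'education': 'Bachelors', 'class': '>50K'}
print(independent)  # {'education-num': 13, 'education': 'HS-grad', 'class': '<=50K'}
```

With chaining, the prediction for `class` can depend on the predicted `education`. Without chaining, each rule only sees the original features.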

MultilabelPredictor Class

We start by defining a custom MultilabelPredictor class to manage a collection of TabularPredictor objects, one for each label. You can use the MultilabelPredictor similarly to an individual TabularPredictor, except it operates on multiple labels rather than one.

from autogluon.tabular import TabularDataset, TabularPredictor
from autogluon.core.utils.utils import setup_outputdir
from autogluon.core.utils.loaders import load_pkl
from autogluon.core.utils.savers import save_pkl
import os.path

class MultilabelPredictor():
    """ Tabular Predictor for predicting multiple columns in a table.
        Creates multiple TabularPredictor objects which you can also use individually.
        You can access the TabularPredictor for a particular label via: `multilabel_predictor.get_predictor(label_i)`

        Parameters
        ----------
        labels : List[str]
            The ith element of this list is the column (i.e. `label`) predicted by the ith TabularPredictor stored in this object.
        path : str
            Path to directory where models and intermediate outputs should be saved.
            If unspecified, a time-stamped folder called "AutogluonModels/ag-[TIMESTAMP]" will be created in the working directory to store all models.
            Note: To call `fit()` twice and save all results of each fit, you must either specify different `path` locations or not specify `path` at all.
            Otherwise files from the first `fit()` will be overwritten by the second `fit()`.
            Caution: when predicting many labels, this directory may grow large as it needs to store many TabularPredictors.
        problem_types : List[str]
            The ith element is the `problem_type` for the ith TabularPredictor stored in this object.
        eval_metrics : List[str]
            The ith element is the `eval_metric` for the ith TabularPredictor stored in this object.
        consider_labels_correlation : bool
            Whether the predictions of multiple labels should account for label correlations or predict each label independently of the others.
            If True, the ordering of `labels` may affect resulting accuracy as each label is predicted conditional on the previous labels appearing earlier in this list (i.e. in an auto-regressive fashion).
            Set to False if during inference you may want to individually use just the ith TabularPredictor without predicting all the other labels.
        kwargs :
            Arguments passed into the initialization of each TabularPredictor.

    """

    multi_predictor_file = 'multilabel_predictor.pkl'

    def __init__(self, labels, path=None, problem_types=None, eval_metrics=None, consider_labels_correlation=True, **kwargs):
        if len(labels) < 2:
            raise ValueError("MultilabelPredictor is only intended for predicting MULTIPLE labels (columns), use TabularPredictor for predicting one label (column).")
        self.path = setup_outputdir(path, warn_if_exist=False)
        self.labels = labels
        self.consider_labels_correlation = consider_labels_correlation
        self.predictors = {}  # key = label, value = TabularPredictor or str path to the TabularPredictor for this label
        if eval_metrics is None:
            self.eval_metrics = {}
        else:
            self.eval_metrics = {labels[i] : eval_metrics[i] for i in range(len(labels))}
        problem_type = None
        eval_metric = None
        for i in range(len(labels)):
            label = labels[i]
            path_i = self.path + "Predictor_" + label
            if problem_types is not None:
                problem_type = problem_types[i]
            if eval_metrics is not None:
                eval_metric = self.eval_metrics[label]
            self.predictors[label] = TabularPredictor(label=label, problem_type=problem_type, eval_metric=eval_metric, path=path_i, **kwargs)

    def fit(self, train_data, tuning_data=None, **kwargs):
        """ Fits a separate TabularPredictor to predict each of the labels.

            Parameters
            ----------
            train_data, tuning_data : str or autogluon.tabular.TabularDataset or pd.DataFrame
                See documentation for `TabularPredictor.fit()`.
            kwargs :
                Arguments passed into the `fit()` call for each TabularPredictor.
        """
        if isinstance(train_data, str):
            train_data = TabularDataset(train_data)
        if tuning_data is not None and isinstance(tuning_data, str):
            tuning_data = TabularDataset(tuning_data)
        train_data_og = train_data.copy()
        if tuning_data is not None:
            tuning_data_og = tuning_data.copy()
        else:
            tuning_data_og = None
        save_metrics = len(self.eval_metrics) == 0
        for i in range(len(self.labels)):
            label = self.labels[i]
            predictor = self.get_predictor(label)
            if not self.consider_labels_correlation:
                labels_to_drop = [l for l in self.labels if l != label]
            else:
                labels_to_drop = [self.labels[j] for j in range(i+1, len(self.labels))]
            train_data = train_data_og.drop(labels_to_drop, axis=1)
            if tuning_data is not None:
                tuning_data = tuning_data_og.drop(labels_to_drop, axis=1)
            print(f"Fitting TabularPredictor for label: {label} ...")
            predictor.fit(train_data=train_data, tuning_data=tuning_data, **kwargs)
            self.predictors[label] = predictor.path
            if save_metrics:
                self.eval_metrics[label] = predictor.eval_metric
        self.save()

    def predict(self, data, **kwargs):
        """ Returns DataFrame with label columns containing predictions for each label.

            Parameters
            ----------
            data : str or autogluon.tabular.TabularDataset or pd.DataFrame
                Data to make predictions for. If label columns are present in this data, they will be ignored. See documentation for `TabularPredictor.predict()`.
            kwargs :
                Arguments passed into the predict() call for each TabularPredictor.
        """
        return self._predict(data, as_proba=False, **kwargs)

    def predict_proba(self, data, **kwargs):
        """ Returns dict where each key is a label and the corresponding value is the `predict_proba()` output for just that label.

            Parameters
            ----------
            data : str or autogluon.tabular.TabularDataset or pd.DataFrame
                Data to make predictions for. See documentation for `TabularPredictor.predict()` and `TabularPredictor.predict_proba()`.
            kwargs :
                Arguments passed into the `predict_proba()` call for each TabularPredictor (also passed into a `predict()` call).
        """
        return self._predict(data, as_proba=True, **kwargs)

    def evaluate(self, data, **kwargs):
        """ Returns dict where each key is a label and the corresponding value is the `evaluate()` output for just that label.

            Parameters
            ----------
            data : str or autogluon.tabular.TabularDataset or pd.DataFrame
                Data to evaluate predictions of all labels for, must contain all labels as columns. See documentation for `TabularPredictor.evaluate()`.
            kwargs :
                Arguments passed into the `evaluate()` call for each TabularPredictor (also passed into the `predict()` call).
        """
        data = self._get_data(data)
        eval_dict = {}
        for label in self.labels:
            print(f"Evaluating TabularPredictor for label: {label} ...")
            predictor = self.get_predictor(label)
            eval_dict[label] = predictor.evaluate(data, **kwargs)
            if self.consider_labels_correlation:
                data[label] = predictor.predict(data, **kwargs)
        return eval_dict

    def save(self):
        """ Save MultilabelPredictor to disk. """
        for label in self.labels:
            if not isinstance(self.predictors[label], str):
                self.predictors[label] = self.predictors[label].path
        save_pkl.save(path=self.path+self.multi_predictor_file, object=self)
        print(f"MultilabelPredictor saved to disk. Load with: MultilabelPredictor.load('{self.path}')")

    @classmethod
    def load(cls, path):
        """ Load MultilabelPredictor from disk `path` previously specified when creating this MultilabelPredictor. """
        path = os.path.expanduser(path)
        if path[-1] != os.path.sep:
            path = path + os.path.sep
        return load_pkl.load(path=path+cls.multi_predictor_file)

    def get_predictor(self, label):
        """ Returns TabularPredictor which is used to predict this label. """
        predictor = self.predictors[label]
        if isinstance(predictor, str):
            return TabularPredictor.load(path=predictor)
        return predictor

    def _get_data(self, data):
        if isinstance(data, str):
            return TabularDataset(data)
        return data.copy()

    def _predict(self, data, as_proba=False, **kwargs):
        data = self._get_data(data)
        if as_proba:
            predproba_dict = {}
        for label in self.labels:
            print(f"Predicting with TabularPredictor for label: {label} ...")
            predictor = self.get_predictor(label)
            if as_proba:
                predproba_dict[label] = predictor.predict_proba(data, as_multiclass=True, **kwargs)
            data[label] = predictor.predict(data, **kwargs)
        if not as_proba:
            return data[self.labels]
        else:
            return predproba_dict
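The constructor above pairs `problem_types` and `eval_metrics` with `labels` purely by position. If either list is shorter than `labels`, it raises an `IndexError`; extra entries are silently ignored. Duplicate labels would also overwrite each other in `self.predictors`. The hypothetical helper below is not part of AutoGluon. It is a minimal sketch of a pre-flight check you could run before constructing a MultilabelPredictor.

```python
from typing import Dict, Optional, Sequence, Tuple


def check_label_config(
    labels: Sequence[str],
    problem_types: Optional[Sequence[str]] = None,
    eval_metrics: Optional[Sequence[str]] = None,
) -> Dict[str, Tuple[Optional[str], Optional[str]]]:
    """Hypothetical check; returns {label: (problem_type, eval_metric)}."""
    if len(labels) < 2:
        raise ValueError("need at least two labels; use TabularPredictor for one label")
    if len(set(labels)) != len(labels):
        raise ValueError("labels must be unique (they are used as dict keys and folder names)")
    for name, values in (("problem_types", problem_types), ("eval_metrics", eval_metrics)):
        # Per-label settings must line up one-to-one with `labels`.
        if values is not None and len(values) != len(labels):
            raise ValueError(f"{name} has {len(values)} entries but there are {len(labels)} labels")
    return {
        label: (
            problem_types[i] if problem_types is not None else None,
            eval_metrics[i] if eval_metrics is not None else None,
        )
        for i, label in enumerate(labels)
    }


labels = ["education-num", "education", "class"]
config = check_label_config(labels, ["regression", "multiclass", "binary"])
print(config["class"])  # ('binary', None)
```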

Training

Let’s now apply our multi-label predictor to predict multiple columns in a data table. We first train models to predict each of the labels.

train_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')
subsample_size = 500  # subsample subset of data for faster demo, try setting this to much larger values
train_data = train_data.sample(n=subsample_size, random_state=0)
train_data.head()
age workclass fnlwgt education education-num marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country class
6118 51 Private 39264 Some-college 10 Married-civ-spouse Exec-managerial Wife White Female 0 0 40 United-States >50K
23204 58 Private 51662 10th 6 Married-civ-spouse Other-service Wife White Female 0 0 8 United-States <=50K
29590 40 Private 326310 Some-college 10 Married-civ-spouse Craft-repair Husband White Male 0 0 44 United-States <=50K
18116 37 Private 222450 HS-grad 9 Never-married Sales Not-in-family White Male 0 2339 40 El-Salvador <=50K
33964 62 Private 109190 Bachelors 13 Married-civ-spouse Exec-managerial Husband White Male 15024 0 40 United-States >50K
labels = ['education-num','education','class']  # which columns to predict based on the others
problem_types = ['regression','multiclass','binary']  # type of each prediction problem
save_path = 'agModels-predictEducationClass'  # specifies folder to store trained models

time_limit = 5  # how many seconds to train the TabularPredictor for each label, set much larger in your applications!
multi_predictor = MultilabelPredictor(labels=labels, problem_types=problem_types, path=save_path)
multi_predictor.fit(train_data, time_limit=time_limit)
Fitting TabularPredictor for label: education-num ...
Beginning AutoGluon training ... Time limit = 5s
AutoGluon will save models to "agModels-predictEducationClass/Predictor_education-num/"
AutoGluon Version:  0.3.0b20210827
Train Data Rows:    500
Train Data Columns: 12
Preprocessing data ...
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
    Available Memory:                    22173.95 MB
    Train Data (Original)  Memory Usage: 0.26 MB (0.0% of available memory)
    Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
    Stage 1 Generators:
            Fitting AsTypeFeatureGenerator...
    Stage 2 Generators:
            Fitting FillNaFeatureGenerator...
    Stage 3 Generators:
            Fitting IdentityFeatureGenerator...
            Fitting CategoryFeatureGenerator...
                    Fitting CategoryMemoryMinimizeFeatureGenerator...
    Stage 4 Generators:
            Fitting DropUniqueFeatureGenerator...
    Types of features in original data (raw dtype, special dtypes):
            ('int', [])    : 5 | ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']
            ('object', []) : 7 | ['workclass', 'marital-status', 'occupation', 'relationship', 'race', ...]
    Types of features in processed data (raw dtype, special dtypes):
            ('category', []) : 7 | ['workclass', 'marital-status', 'occupation', 'relationship', 'race', ...]
            ('int', [])      : 5 | ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']
    0.1s = Fit runtime
    12 features in original data used to generate 12 features in processed data.
    Train Data (Processed) Memory Usage: 0.03 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.07s ...
AutoGluon will gauge predictive performance using evaluation metric: 'root_mean_squared_error'
    To change this, specify the eval_metric argument of fit()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 400, Val Rows: 100
Fitting 11 L1 models ...
Fitting model: KNeighborsUnif ... Training model for up to 4.93s of the 4.93s of remaining time.
    -2.703   = Validation score   (root_mean_squared_error)
    0.0s     = Training   runtime
    0.1s     = Validation runtime
Fitting model: KNeighborsDist ... Training model for up to 4.82s of the 4.82s of remaining time.
    -2.7447  = Validation score   (root_mean_squared_error)
    0.0s     = Training   runtime
    0.1s     = Validation runtime
Fitting model: LightGBMXT ... Training model for up to 4.71s of the 4.71s of remaining time.
    -2.2917  = Validation score   (root_mean_squared_error)
    0.44s    = Training   runtime
    0.01s    = Validation runtime
Fitting model: LightGBM ... Training model for up to 4.26s of the 4.26s of remaining time.
    -2.3176  = Validation score   (root_mean_squared_error)
    0.15s    = Training   runtime
    0.01s    = Validation runtime
Fitting model: RandomForestMSE ... Training model for up to 4.1s of the 4.1s of remaining time.
    -2.2527  = Validation score   (root_mean_squared_error)
    0.51s    = Training   runtime
    0.11s    = Validation runtime
Fitting model: CatBoost ... Training model for up to 3.47s of the 3.46s of remaining time.
    -2.1162  = Validation score   (root_mean_squared_error)
    0.81s    = Training   runtime
    0.01s    = Validation runtime
Fitting model: ExtraTreesMSE ... Training model for up to 2.64s of the 2.64s of remaining time.
    -2.3301  = Validation score   (root_mean_squared_error)
    0.5s     = Training   runtime
    0.11s    = Validation runtime
Fitting model: NeuralNetFastAI ... Training model for up to 2.01s of the 2.01s of remaining time.
    Ran out of time, stopping training early.
    -2.6124  = Validation score   (root_mean_squared_error)
    4.95s    = Training   runtime
    0.02s    = Validation runtime
Fitting model: WeightedEnsemble_L2 ... Training model for up to 4.93s of the -3.45s of remaining time.
    -2.1055  = Validation score   (root_mean_squared_error)
    0.22s    = Training   runtime
    0.0s     = Validation runtime
AutoGluon training complete, total runtime = 8.68s ...
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("agModels-predictEducationClass/Predictor_education-num/")
Fitting TabularPredictor for label: education ...
Beginning AutoGluon training ... Time limit = 5s
AutoGluon will save models to "agModels-predictEducationClass/Predictor_education/"
AutoGluon Version:  0.3.0b20210827
Train Data Rows:    500
Train Data Columns: 13
Preprocessing data ...
Warning: Some classes in the training set have fewer than 10 examples. AutoGluon will only keep 11 out of 15 classes for training and will not try to predict the rare classes. To keep more classes, increase the number of datapoints from these rare classes in the training data or reduce label_count_threshold.
Fraction of data from classes with at least 10 examples that will be kept for training models: 0.976
Train Data Class Count: 11
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
    Available Memory:                    18905.87 MB
    Train Data (Original)  Memory Usage: 0.25 MB (0.0% of available memory)
    Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
    Stage 1 Generators:
            Fitting AsTypeFeatureGenerator...
    Stage 2 Generators:
            Fitting FillNaFeatureGenerator...
    Stage 3 Generators:
            Fitting IdentityFeatureGenerator...
            Fitting CategoryFeatureGenerator...
                    Fitting CategoryMemoryMinimizeFeatureGenerator...
    Stage 4 Generators:
            Fitting DropUniqueFeatureGenerator...
    Types of features in original data (raw dtype, special dtypes):
            ('int', [])    : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
            ('object', []) : 7 | ['workclass', 'marital-status', 'occupation', 'relationship', 'race', ...]
    Types of features in processed data (raw dtype, special dtypes):
            ('category', []) : 7 | ['workclass', 'marital-status', 'occupation', 'relationship', 'race', ...]
            ('int', [])      : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
    0.1s = Fit runtime
    13 features in original data used to generate 13 features in processed data.
    Train Data (Processed) Memory Usage: 0.03 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.08s ...
AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
    To change this, specify the eval_metric argument of fit()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 390, Val Rows: 98
Fitting 13 L1 models ...
Fitting model: KNeighborsUnif ... Training model for up to 4.92s of the 4.92s of remaining time.
    0.2653   = Validation score   (accuracy)
    0.0s     = Training   runtime
    0.1s     = Validation runtime
Fitting model: KNeighborsDist ... Training model for up to 4.82s of the 4.81s of remaining time.
    0.2347   = Validation score   (accuracy)
    0.0s     = Training   runtime
    0.1s     = Validation runtime
Fitting model: NeuralNetFastAI ... Training model for up to 4.71s of the 4.71s of remaining time.
    0.8265   = Validation score   (accuracy)
    0.57s    = Training   runtime
    0.02s    = Validation runtime
Fitting model: LightGBMXT ... Training model for up to 4.11s of the 4.11s of remaining time.
    0.9796   = Validation score   (accuracy)
    0.62s    = Training   runtime
    0.01s    = Validation runtime
Fitting model: LightGBM ... Training model for up to 3.45s of the 3.45s of remaining time.
    1.0      = Validation score   (accuracy)
    0.43s    = Training   runtime
    0.01s    = Validation runtime
Fitting model: RandomForestGini ... Training model for up to 3.01s of the 3.0s of remaining time.
    Warning: Reducing model 'n_estimators' from 300 -> 115 due to low time. Expected time usage reduced from 7.8s -> 3.0s...
    0.9286   = Validation score   (accuracy)
    0.35s    = Training   runtime
    0.11s    = Validation runtime
Fitting model: RandomForestEntr ... Training model for up to 2.54s of the 2.54s of remaining time.
    Warning: Reducing model 'n_estimators' from 300 -> 98 due to low time. Expected time usage reduced from 7.8s -> 2.5s...
    0.8571   = Validation score   (accuracy)
    0.35s    = Training   runtime
    0.11s    = Validation runtime
Fitting model: CatBoost ... Training model for up to 2.08s of the 2.08s of remaining time.
    1.0      = Validation score   (accuracy)
    4.33s    = Training   runtime
    0.01s    = Validation runtime
Fitting model: WeightedEnsemble_L2 ... Training model for up to 4.92s of the -2.77s of remaining time.
    1.0      = Validation score   (accuracy)
    0.18s    = Training   runtime
    0.0s     = Validation runtime
AutoGluon training complete, total runtime = 7.96s ...
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("agModels-predictEducationClass/Predictor_education/")
Fitting TabularPredictor for label: class ...
Beginning AutoGluon training ... Time limit = 5s
AutoGluon will save models to "agModels-predictEducationClass/Predictor_class/"
AutoGluon Version:  0.3.0b20210827
Train Data Rows:    500
Train Data Columns: 14
Preprocessing data ...
Selected class <--> label mapping:  class 1 =  >50K, class 0 =  <=50K
    Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
    To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
    Available Memory:                    18876.19 MB
    Train Data (Original)  Memory Usage: 0.29 MB (0.0% of available memory)
    Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
    Stage 1 Generators:
            Fitting AsTypeFeatureGenerator...
    Stage 2 Generators:
            Fitting FillNaFeatureGenerator...
    Stage 3 Generators:
            Fitting IdentityFeatureGenerator...
            Fitting CategoryFeatureGenerator...
                    Fitting CategoryMemoryMinimizeFeatureGenerator...
    Stage 4 Generators:
            Fitting DropUniqueFeatureGenerator...
    Types of features in original data (raw dtype, special dtypes):
            ('int', [])    : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
            ('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
    Types of features in processed data (raw dtype, special dtypes):
            ('category', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
            ('int', [])      : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
    0.1s = Fit runtime
    14 features in original data used to generate 14 features in processed data.
    Train Data (Processed) Memory Usage: 0.03 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.08s ...
AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
    To change this, specify the eval_metric argument of fit()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 400, Val Rows: 100
Fitting 13 L1 models ...
Fitting model: KNeighborsUnif ... Training model for up to 4.92s of the 4.92s of remaining time.
    0.73     = Validation score   (accuracy)
    0.0s     = Training   runtime
    0.1s     = Validation runtime
Fitting model: KNeighborsDist ... Training model for up to 4.81s of the 4.81s of remaining time.
    0.65     = Validation score   (accuracy)
    0.0s     = Training   runtime
    0.1s     = Validation runtime
Fitting model: LightGBMXT ... Training model for up to 4.7s of the 4.7s of remaining time.
    0.83     = Validation score   (accuracy)
    0.15s    = Training   runtime
    0.01s    = Validation runtime
Fitting model: LightGBM ... Training model for up to 4.54s of the 4.54s of remaining time.
    0.85     = Validation score   (accuracy)
    0.18s    = Training   runtime
    0.01s    = Validation runtime
Fitting model: RandomForestGini ... Training model for up to 4.35s of the 4.35s of remaining time.
    0.84     = Validation score   (accuracy)
    0.51s    = Training   runtime
    0.11s    = Validation runtime
Fitting model: RandomForestEntr ... Training model for up to 3.72s of the 3.72s of remaining time.
    0.82     = Validation score   (accuracy)
    0.61s    = Training   runtime
    0.11s    = Validation runtime
Fitting model: CatBoost ... Training model for up to 2.99s of the 2.99s of remaining time.
    0.84     = Validation score   (accuracy)
    0.43s    = Training   runtime
    0.01s    = Validation runtime
Fitting model: ExtraTreesGini ... Training model for up to 2.55s of the 2.55s of remaining time.
    0.82     = Validation score   (accuracy)
    0.61s    = Training   runtime
    0.11s    = Validation runtime
Fitting model: ExtraTreesEntr ... Training model for up to 1.82s of the 1.82s of remaining time.
    0.82     = Validation score   (accuracy)
    0.61s    = Training   runtime
    0.11s    = Validation runtime
Fitting model: NeuralNetFastAI ... Training model for up to 1.1s of the 1.1s of remaining time.
    0.83     = Validation score   (accuracy)
    0.56s    = Training   runtime
    0.02s    = Validation runtime
Fitting model: XGBoost ... Training model for up to 0.51s of the 0.51s of remaining time.
    0.87     = Validation score   (accuracy)
    0.18s    = Training   runtime
    0.01s    = Validation runtime
Fitting model: NeuralNetMXNet ... Training model for up to 0.3s of the 0.3s of remaining time.
    Time limit exceeded... Skipping NeuralNetMXNet.
Fitting model: WeightedEnsemble_L2 ... Training model for up to 4.92s of the -1.65s of remaining time.
    0.87     = Validation score   (accuracy)
    0.24s    = Training   runtime
    0.0s     = Validation runtime
AutoGluon training complete, total runtime = 6.9s ...
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("agModels-predictEducationClass/Predictor_class/")
MultilabelPredictor saved to disk. Load with: MultilabelPredictor.load('agModels-predictEducationClass/')
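The row and column counts in the logs above follow from how fit() prepares each predictor's data. The sketch below reproduces them.

- `labels_to_drop` mirrors the column-dropping logic in MultilabelPredictor.fit().
- `split_sizes` assumes the validation size is rounded up, as scikit-learn's `train_test_split` does for a fractional `test_size`. That rounding rule is an assumption, but it matches the reported 400/100 and 390/98 splits.
- The 15-column count is the number of columns in the `train_data.head()` output.
- The 0.976 is the "Fraction of data ... kept" reported for the `education` label.

```python
import math
from typing import List, Tuple


def labels_to_drop(labels: List[str], i: int, consider_labels_correlation: bool) -> List[str]:
    """Other label columns removed before fitting the predictor for labels[i] (mirrors fit())."""
    label = labels[i]
    if not consider_labels_correlation:
        return [l for l in labels if l != label]
    return [labels[j] for j in range(i + 1, len(labels))]


def n_feature_columns(n_columns: int, labels: List[str], i: int, consider: bool) -> int:
    # All columns, minus the label itself, minus the dropped label columns.
    return n_columns - 1 - len(labels_to_drop(labels, i, consider))


def split_sizes(n_rows: int, holdout_frac: float = 0.2) -> Tuple[int, int]:
    """(train, val) sizes, assuming the validation size is rounded up."""
    n_val = math.ceil(holdout_frac * n_rows)
    return n_rows - n_val, n_val


labels = ["education-num", "education", "class"]
n_columns = 15

chained_counts = [n_feature_columns(n_columns, labels, i, True) for i in range(len(labels))]
independent_counts = [n_feature_columns(n_columns, labels, i, False) for i in range(len(labels))]
print(chained_counts)      # [12, 13, 14], matching "Train Data Columns" in the logs
print(independent_counts)  # [12, 12, 12]

kept = round(0.976 * 500)  # rows left after dropping rare 'education' classes
print(kept, split_sizes(500), split_sizes(kept))  # 488 (400, 100) (390, 98)
```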

Inference and Evaluation

After training, you can easily use the MultilabelPredictor to predict all labels in new data:

test_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')
test_data = test_data.sample(n=subsample_size, random_state=0)
test_data_nolab = test_data.drop(columns=labels)  # unnecessary, just to demonstrate we're not cheating here
test_data_nolab.head()
Loaded data from: https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv | Columns = 15 / 15 | Rows = 9769 -> 9769
age workclass fnlwgt marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country
5454 41 Self-emp-not-inc 408498 Married-civ-spouse Exec-managerial Husband White Male 0 0 50 United-States
6111 39 Private 746786 Married-civ-spouse Prof-specialty Husband White Male 0 0 55 United-States
5282 50 Private 62593 Married-civ-spouse Farming-fishing Husband Asian-Pac-Islander Male 0 0 40 United-States
3046 31 Private 248178 Married-civ-spouse Other-service Husband Black Male 0 0 35 United-States
2162 43 State-gov 52849 Married-civ-spouse Prof-specialty Husband White Male 0 0 40 United-States
multi_predictor = MultilabelPredictor.load(save_path)  # unnecessary, just demonstrates how to load previously-trained multilabel predictor from file

predictions = multi_predictor.predict(test_data_nolab)
print("Predictions:  \n", predictions)
Predicting with TabularPredictor for label: education-num ...
Predicting with TabularPredictor for label: education ...
Predicting with TabularPredictor for label: class ...
Predictions:
       education-num      education   class
5454      11.055421      Assoc-voc    >50K
6111      12.239116        HS-grad    >50K
5282       9.289109        HS-grad    >50K
3046       8.937429           11th   <=50K
2162      12.659792        HS-grad    >50K
...             ...            ...     ...
6965       9.956499        HS-grad    >50K
4762       8.792459           11th   <=50K
234       10.649448   Some-college   <=50K
6291      10.558071   Some-college    >50K
9575       9.509919        HS-grad    >50K

[500 rows x 3 columns]

We can also easily evaluate the performance of our predictions if our new data contain the ground truth labels:

evaluations = multi_predictor.evaluate(test_data)
print(evaluations)
print("Evaluated using metrics:", multi_predictor.eval_metrics)
Evaluating TabularPredictor for label: education-num ...
Evaluation: root_mean_squared_error on test data: -2.1659743554441215
    Note: Scores are always higher_is_better. This metric score can be multiplied by -1 to get the metric value.
Evaluations on test data:
{
    "root_mean_squared_error": -2.1659743554441215,
    "mean_squared_error": -4.691444908441578,
    "mean_absolute_error": -1.6176969652175903,
    "r2": 0.39337608971938165,
    "pearsonr": 0.641863702661109,
    "median_absolute_error": -1.2806296348571777
}
Evaluating TabularPredictor for label: education ...
Evaluation: accuracy on test data: 0.214
Evaluations on test data:
{
    "accuracy": 0.214,
    "balanced_accuracy": 0.08723911014150755,
    "mcc": 0.03847538003426249
}
Evaluating TabularPredictor for label: class ...
Evaluation: accuracy on test data: 0.814
Evaluations on test data:
{
    "accuracy": 0.814,
    "balanced_accuracy": 0.7077979063499028,
    "mcc": 0.4733035807264,
    "roc_auc": 0.8485446833406465,
    "f1": 0.5753424657534246,
    "precision": 0.7,
    "recall": 0.4883720930232558
}
{'education-num': {'root_mean_squared_error': -2.1659743554441215, 'mean_squared_error': -4.691444908441578, 'mean_absolute_error': -1.6176969652175903, 'r2': 0.39337608971938165, 'pearsonr': 0.641863702661109, 'median_absolute_error': -1.2806296348571777}, 'education': {'accuracy': 0.214, 'balanced_accuracy': 0.08723911014150755, 'mcc': 0.03847538003426249}, 'class': {'accuracy': 0.814, 'balanced_accuracy': 0.7077979063499028, 'mcc': 0.4733035807264, 'roc_auc': 0.8485446833406465, 'f1': 0.5753424657534246, 'precision': 0.7, 'recall': 0.4883720930232558}}
Evaluated using metrics: {'education-num': root_mean_squared_error, 'education': accuracy, 'class': accuracy}
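As the log notes, AutoGluon reports every score as higher_is_better, so error metrics such as RMSE appear negated. The sketch below converts the reported scores back to conventional metric values. It then sanity-checks two relationships using numbers copied from the output above: MSE equals RMSE squared, and F1 is the harmonic mean of precision and recall. The `LOWER_IS_BETTER` set is an assumption that covers only the metrics printed here; it is not an AutoGluon constant.

```python
import math

# Scores copied from the evaluation output above (higher_is_better convention).
education_num_scores = {
    "root_mean_squared_error": -2.1659743554441215,
    "mean_squared_error": -4.691444908441578,
    "mean_absolute_error": -1.6176969652175903,
}
class_scores = {"precision": 0.7, "recall": 0.4883720930232558, "f1": 0.5753424657534246}

# Assumption: error metrics from this output, which are reported with their sign flipped.
LOWER_IS_BETTER = {
    "root_mean_squared_error",
    "mean_squared_error",
    "mean_absolute_error",
    "median_absolute_error",
}


def to_metric_value(name: str, score: float) -> float:
    """Undo the higher_is_better sign flip for error metrics."""
    return -score if name in LOWER_IS_BETTER else score


rmse = to_metric_value("root_mean_squared_error", education_num_scores["root_mean_squared_error"])
mse = to_metric_value("mean_squared_error", education_num_scores["mean_squared_error"])
print(rmse, mse)  # 2.1659743554441215 4.691444908441578

# MSE should equal RMSE squared, up to floating-point rounding.
print(math.isclose(rmse ** 2, mse, rel_tol=1e-9))

p, r = class_scores["precision"], class_scores["recall"]
f1 = 2 * p * r / (p + r)  # harmonic mean of precision and recall
print(math.isclose(f1, class_scores["f1"], rel_tol=1e-9))
```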

Accessing the TabularPredictor for One Label

We can also work directly with the TabularPredictor for any one of the labels, as follows. However, if you later plan to use an individual TabularPredictor to predict just one label rather than all of the labels predicted by the MultilabelPredictor, we recommend setting consider_labels_correlation=False before training. Otherwise, each predictor after the first expects the labels that come earlier in the order as input features.

predictor_class = multi_predictor.get_predictor('class')
predictor_class.leaderboard(silent=True)
    model                score_val  pred_time_val  fit_time  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0   XGBoost              0.87       0.006115       0.181115  0.006115                0.181115           1            True       11
1   WeightedEnsemble_L2  0.87       0.006654       0.416689  0.000539                0.235574           2            True       12
2   LightGBM             0.85       0.007534       0.178237  0.007534                0.178237           1            True       4
3   CatBoost             0.84       0.009202       0.428102  0.009202                0.428102           1            True       7
4   RandomForestGini     0.84       0.106746       0.508887  0.106746                0.508887           1            True       5
5   LightGBMXT           0.83       0.007622       0.149821  0.007622                0.149821           1            True       3
6   NeuralNetFastAI      0.83       0.015927       0.560520  0.015927                0.560520           1            True       10
7   RandomForestEntr     0.82       0.106828       0.609127  0.106828                0.609127           1            True       6
8   ExtraTreesEntr       0.82       0.106994       0.606519  0.106994                0.606519           1            True       9
9   ExtraTreesGini       0.82       0.107108       0.606220  0.107108                0.606220           1            True       8
10  KNeighborsUnif       0.73       0.102892       0.003825  0.102892                0.003825           1            True       1
11  KNeighborsDist       0.65       0.103310       0.003794  0.103310                0.003794           1            True       2

Tips

To obtain the best predictions, you should generally specify the following arguments:

  1. Set eval_metrics in the MultilabelPredictor() constructor to the metrics you will use to evaluate predictions for each label.

  2. Specify presets='best_quality' in MultilabelPredictor.fit() to tell AutoGluon you care more about predictive performance than latency/memory usage. This enables stack ensembling when predicting each label.

If you find that too much memory or disk space is being used, try calling MultilabelPredictor.fit() with the additional arguments discussed under “If you encounter memory issues” and “If you encounter disk space issues” in the In Depth Tutorial.

If you find inference too slow, you can try the strategies discussed under “Accelerating Inference” in the In Depth Tutorial. In particular, try specifying the following presets in MultilabelPredictor.fit(): presets = ['good_quality_faster_inference_only_refit', 'optimize_for_deployment']