{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "390608d2", "metadata": {}, "source": [ "# Predicting Multiple Columns in a Table (Multi-Label Prediction)\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/autogluon/autogluon/blob/master/docs/tutorials/tabular/advanced/tabular-multilabel.ipynb)\n", "[![Open In SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/autogluon/autogluon/blob/master/docs/tutorials/tabular/advanced/tabular-multilabel.ipynb)\n", "\n", "\n", "\n", "In multi-label prediction, we wish to predict multiple columns of a table (i.e. labels) based on the values in the remaining columns. Here we present a simple strategy to do this with AutoGluon, which simply maintains a separate [TabularPredictor](../../../api/autogluon.tabular.TabularPredictor.rst) object for each column being predicted. Correlations between labels can be accounted for in predictions by imposing an order on the labels and allowing the `TabularPredictor` for each label to condition on the predicted values for labels that appeared earlier in the order.\n", "\n", "## MultilabelPredictor Class\n", "\n", "We start by defining a custom `MultilabelPredictor` class to manage a collection of `TabularPredictor` objects, one for each label. You can use the `MultilabelPredictor` similarly to an individual `TabularPredictor`, except it operates on multiple labels rather than one." 
] }, { "cell_type": "code", "execution_count": null, "id": "aa00faab-252f-44c9-b8f7-57131aa8251c", "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "!pip install autogluon.tabular[all]\n" ] }, { "cell_type": "code", "execution_count": null, "id": "329d8dea", "metadata": {}, "outputs": [], "source": [ "from autogluon.tabular import TabularDataset, TabularPredictor\n", "from autogluon.common.utils.utils import setup_outputdir\n", "from autogluon.core.utils.loaders import load_pkl\n", "from autogluon.core.utils.savers import save_pkl\n", "import os.path\n", "\n", "class MultilabelPredictor:\n", "    \"\"\" Tabular Predictor for predicting multiple columns in table.\n", "        Creates multiple TabularPredictor objects which you can also use individually.\n", "        You can access the TabularPredictor for a particular label via: `multilabel_predictor.get_predictor(label_i)`\n", "\n", "        Parameters\n", "        ----------\n", "        labels : List[str]\n", "            The ith element of this list is the column (i.e. 
`label`) predicted by the ith TabularPredictor stored in this object.\n", "        path : str, default = None\n", "            Path to directory where models and intermediate outputs should be saved.\n", "            If unspecified, a time-stamped folder called \"AutogluonModels/ag-[TIMESTAMP]\" will be created in the working directory to store all models.\n", "            Note: To call `fit()` twice and save all results of each fit, you must specify different `path` locations or don't specify `path` at all.\n", "            Otherwise files from first `fit()` will be overwritten by second `fit()`.\n", "            Caution: when predicting many labels, this directory may grow large as it needs to store many TabularPredictors.\n", "        problem_types : List[str], default = None\n", "            The ith element is the `problem_type` for the ith TabularPredictor stored in this object.\n", "        eval_metrics : List[str], default = None\n", "            The ith element is the `eval_metric` for the ith TabularPredictor stored in this object.\n", "        consider_labels_correlation : bool, default = True\n", "            Whether the predictions of multiple labels should account for label correlations or predict each label independently of the others.\n", "            If True, the ordering of `labels` may affect resulting accuracy as each label is predicted conditional on the previous labels appearing earlier in this list (i.e. 
in an auto-regressive fashion).\n", "            Set to False if during inference you may want to individually use just the ith TabularPredictor without predicting all the other labels.\n", "        kwargs :\n", "            Arguments passed into the initialization of each TabularPredictor.\n", "\n", "    \"\"\"\n", "\n", "    multi_predictor_file = 'multilabel_predictor.pkl'\n", "\n", "    def __init__(self, labels, path=None, problem_types=None, eval_metrics=None, consider_labels_correlation=True, **kwargs):\n", "        if len(labels) < 2:\n", "            raise ValueError(\"MultilabelPredictor is only intended for predicting MULTIPLE labels (columns), use TabularPredictor for predicting one label (column).\")\n", "        if (problem_types is not None) and (len(problem_types) != len(labels)):\n", "            raise ValueError(\"If provided, `problem_types` must have same length as `labels`\")\n", "        if (eval_metrics is not None) and (len(eval_metrics) != len(labels)):\n", "            raise ValueError(\"If provided, `eval_metrics` must have same length as `labels`\")\n", "        self.path = setup_outputdir(path, warn_if_exist=False)\n", "        self.labels = labels\n", "        self.consider_labels_correlation = consider_labels_correlation\n", "        self.predictors = {}  # key = label, value = TabularPredictor or str path to the TabularPredictor for this label\n", "        if eval_metrics is None:\n", "            self.eval_metrics = {}\n", "        else:\n", "            self.eval_metrics = {labels[i] : eval_metrics[i] for i in range(len(labels))}\n", "        problem_type = None\n", "        eval_metric = None\n", "        for i in range(len(labels)):\n", "            label = labels[i]\n", "            path_i = os.path.join(self.path, \"Predictor_\" + str(label))\n", "            if problem_types is not None:\n", "                problem_type = problem_types[i]\n", "            if eval_metrics is not None:\n", "                eval_metric = eval_metrics[i]\n", "            self.predictors[label] = TabularPredictor(label=label, problem_type=problem_type, eval_metric=eval_metric, path=path_i, **kwargs)\n", "\n", "    def fit(self, train_data, tuning_data=None, **kwargs):\n", "        \"\"\" Fits a separate 
TabularPredictor to predict each of the labels.\n", "\n", "        Parameters\n", "        ----------\n", "        train_data, tuning_data : str or pd.DataFrame\n", "            See documentation for `TabularPredictor.fit()`.\n", "        kwargs :\n", "            Arguments passed into the `fit()` call for each TabularPredictor.\n", "        \"\"\"\n", "        if isinstance(train_data, str):\n", "            train_data = TabularDataset(train_data)\n", "        if tuning_data is not None and isinstance(tuning_data, str):\n", "            tuning_data = TabularDataset(tuning_data)\n", "        train_data_og = train_data.copy()\n", "        if tuning_data is not None:\n", "            tuning_data_og = tuning_data.copy()\n", "        else:\n", "            tuning_data_og = None\n", "        save_metrics = len(self.eval_metrics) == 0\n", "        for i in range(len(self.labels)):\n", "            label = self.labels[i]\n", "            predictor = self.get_predictor(label)\n", "            if not self.consider_labels_correlation:\n", "                labels_to_drop = [l for l in self.labels if l != label]\n", "            else:\n", "                labels_to_drop = [self.labels[j] for j in range(i+1, len(self.labels))]\n", "            train_data = train_data_og.drop(labels_to_drop, axis=1)\n", "            if tuning_data is not None:\n", "                tuning_data = tuning_data_og.drop(labels_to_drop, axis=1)\n", "            print(f\"Fitting TabularPredictor for label: {label} ...\")\n", "            predictor.fit(train_data=train_data, tuning_data=tuning_data, **kwargs)\n", "            self.predictors[label] = predictor.path\n", "            if save_metrics:\n", "                self.eval_metrics[label] = predictor.eval_metric\n", "        self.save()\n", "\n", "    def predict(self, data, **kwargs):\n", "        \"\"\" Returns DataFrame with label columns containing predictions for each label.\n", "\n", "        Parameters\n", "        ----------\n", "        data : str or autogluon.tabular.TabularDataset or pd.DataFrame\n", "            Data to make predictions for. If label columns are present in this data, they will be ignored. 
See documentation for `TabularPredictor.predict()`.\n", "        kwargs :\n", "            Arguments passed into the predict() call for each TabularPredictor.\n", "        \"\"\"\n", "        return self._predict(data, as_proba=False, **kwargs)\n", "\n", "    def predict_proba(self, data, **kwargs):\n", "        \"\"\" Returns dict where each key is a label and the corresponding value is the `predict_proba()` output for just that label.\n", "\n", "        Parameters\n", "        ----------\n", "        data : str or autogluon.tabular.TabularDataset or pd.DataFrame\n", "            Data to make predictions for. See documentation for `TabularPredictor.predict()` and `TabularPredictor.predict_proba()`.\n", "        kwargs :\n", "            Arguments passed into the `predict_proba()` call for each TabularPredictor (also passed into a `predict()` call).\n", "        \"\"\"\n", "        return self._predict(data, as_proba=True, **kwargs)\n", "\n", "    def evaluate(self, data, **kwargs):\n", "        \"\"\" Returns dict where each key is a label and the corresponding value is the `evaluate()` output for just that label.\n", "\n", "        Parameters\n", "        ----------\n", "        data : str or autogluon.tabular.TabularDataset or pd.DataFrame\n", "            Data to evaluate predictions of all labels for, must contain all labels as columns. See documentation for `TabularPredictor.evaluate()`.\n", "        kwargs :\n", "            Arguments passed into the `evaluate()` call for each TabularPredictor (also passed into the `predict()` call).\n", "        \"\"\"\n", "        data = self._get_data(data)\n", "        eval_dict = {}\n", "        for label in self.labels:\n", "            print(f\"Evaluating TabularPredictor for label: {label} ...\")\n", "            predictor = self.get_predictor(label)\n", "            eval_dict[label] = predictor.evaluate(data, **kwargs)\n", "            if self.consider_labels_correlation:\n", "                data[label] = predictor.predict(data, **kwargs)\n", "        return eval_dict\n", "\n", "    def save(self):\n", "        \"\"\" Save MultilabelPredictor to disk. \"\"\"\n", "        for label in self.labels:\n", "            if not isinstance(self.predictors[label], str):\n", "                self.predictors[label] = self.predictors[label].path\n", "        save_pkl.save(path=os.path.join(self.path, self.multi_predictor_file), object=self)\n", "        print(f\"MultilabelPredictor saved to disk. Load with: MultilabelPredictor.load('{self.path}')\")\n", "\n", "    @classmethod\n", "    def load(cls, path):\n", "        \"\"\" Load MultilabelPredictor from disk `path` previously specified when creating this MultilabelPredictor. \"\"\"\n", "        path = os.path.expanduser(path)\n", "        return load_pkl.load(path=os.path.join(path, cls.multi_predictor_file))\n", "\n", "    def get_predictor(self, label):\n", "        \"\"\" Returns TabularPredictor which is used to predict this label. \"\"\"\n", "        predictor = self.predictors[label]\n", "        if isinstance(predictor, str):\n", "            return TabularPredictor.load(path=predictor)\n", "        return predictor\n", "\n", "    def _get_data(self, data):\n", "        if isinstance(data, str):\n", "            return TabularDataset(data)\n", "        return data.copy()\n", "\n", "    def _predict(self, data, as_proba=False, **kwargs):\n", "        data = self._get_data(data)\n", "        if as_proba:\n", "            predproba_dict = {}\n", "        for label in self.labels:\n", "            print(f\"Predicting with TabularPredictor for label: {label} ...\")\n", "            predictor = self.get_predictor(label)\n", "            if as_proba:\n", "                predproba_dict[label] = predictor.predict_proba(data, as_multiclass=True, **kwargs)\n", "            data[label] = predictor.predict(data, **kwargs)\n", "        if not as_proba:\n", "            return data[self.labels]\n", "        else:\n", "            return predproba_dict" ] }, { "cell_type": "markdown", "id": "f117bbb5", "metadata": {}, "source": [ "## Training\n", "\n", "Let's now apply our multi-label predictor to predict multiple columns in a data table. We first train models to predict each of the labels."
] }, { "cell_type": "code", "execution_count": null, "id": "c3ea2dfc", "metadata": {}, "outputs": [], "source": [ "train_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')\n", "subsample_size = 500  # subsample the data for a faster demo, try setting this to much larger values\n", "train_data = train_data.sample(n=subsample_size, random_state=0)\n", "train_data.head()" ] }, { "cell_type": "code", "execution_count": null, "id": "9192e870", "metadata": {}, "outputs": [], "source": [ "labels = ['education-num','education','class']  # which columns to predict based on the others\n", "problem_types = ['regression','multiclass','binary']  # type of each prediction problem (optional)\n", "eval_metrics = ['mean_absolute_error','accuracy','accuracy']  # metrics used to evaluate predictions for each label (optional)\n", "save_path = 'agModels-predictEducationClass'  # specifies folder to store trained models (optional)\n", "\n", "time_limit = 5  # how many seconds to train the TabularPredictor for each label, set much larger in your applications!"
] }, { "cell_type": "code", "execution_count": null, "id": "9968b70e", "metadata": {}, "outputs": [], "source": [ "multi_predictor = MultilabelPredictor(labels=labels, problem_types=problem_types, eval_metrics=eval_metrics, path=save_path)\n", "multi_predictor.fit(train_data, time_limit=time_limit)" ] }, { "cell_type": "markdown", "id": "a6b541c6", "metadata": {}, "source": [ "## Inference and Evaluation\n", "\n", "After training, you can easily use the `MultilabelPredictor` to predict all labels in new data:" ] }, { "cell_type": "code", "execution_count": null, "id": "7a5f5d9b", "metadata": {}, "outputs": [], "source": [ "test_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')\n", "test_data = test_data.sample(n=subsample_size, random_state=0)\n", "test_data_nolab = test_data.drop(columns=labels) # unnecessary, just to demonstrate we're not cheating here\n", "test_data_nolab.head()" ] }, { "cell_type": "code", "execution_count": null, "id": "d2ef3acb", "metadata": {}, "outputs": [], "source": [ "multi_predictor = MultilabelPredictor.load(save_path) # unnecessary, just demonstrates how to load previously-trained multilabel predictor from file\n", "\n", "predictions = multi_predictor.predict(test_data_nolab)\n", "print(\"Predictions: \\n\", predictions)" ] }, { "cell_type": "markdown", "id": "4a845bab", "metadata": {}, "source": [ "We can also easily evaluate the performance of our predictions if our new data contain the ground truth labels:" ] }, { "cell_type": "code", "execution_count": null, "id": "ac52d82b", "metadata": {}, "outputs": [], "source": [ "evaluations = multi_predictor.evaluate(test_data)\n", "print(evaluations)\n", "print(\"Evaluated using metrics:\", multi_predictor.eval_metrics)" ] }, { "cell_type": "markdown", "id": "3f30569d", "metadata": {}, "source": [ "## Accessing the TabularPredictor for One Label\n", "\n", "We can also directly work with the `TabularPredictor` for any one of the labels as follows. 
However, we recommend you set `consider_labels_correlation=False` before training if you later plan to use an individual `TabularPredictor` to predict just one label rather than all of the labels predicted by the `MultilabelPredictor`. With the default `consider_labels_correlation=True`, each `TabularPredictor` is trained with the labels that appear earlier in the order as input features, so it expects those columns at inference time." ] }, { "cell_type": "code", "execution_count": null, "id": "e796708d", "metadata": {}, "outputs": [], "source": [ "predictor_class = multi_predictor.get_predictor('class')\n", "predictor_class.leaderboard()" ] }, { "cell_type": "markdown", "id": "2eac71ef", "metadata": {}, "source": [ "## Tips\n", "\n", "To obtain the best predictions, you should generally do the following:\n", "\n", "1) Specify `eval_metrics` when constructing the `MultilabelPredictor`, set to the metrics you will use to evaluate predictions for each label.\n", "\n", "2) Specify `presets='best_quality'` in `MultilabelPredictor.fit()` to tell AutoGluon you care about predictive performance more than latency/memory usage, which will utilize stack ensembling when predicting each label.\n", "\n", "\n", "If you find that too much memory/disk is being used, try calling `MultilabelPredictor.fit()` with additional arguments discussed under [\"If you encounter memory issues\" in the In Depth Tutorial](../tabular-indepth.ipynb) or [\"If you encounter disk space issues\"](../tabular-indepth.ipynb).\n", "\n", "If you find inference too slow, you can try the strategies discussed under [\"Accelerating Inference\" in the In Depth Tutorial](../tabular-indepth.ipynb).\n", "In particular, simply try specifying the following preset in `MultilabelPredictor.fit()`: `presets = ['good_quality', 'optimize_for_deployment']`" ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 }