.. _sec_tabularcustommodel:

Adding a custom model to AutoGluon
==================================

**Tip**: If you are new to AutoGluon, review :ref:`sec_tabularquick` to learn the basics of the AutoGluon API.

This tutorial describes how to add a custom model to AutoGluon that can be trained, hyperparameter-tuned, and ensembled alongside the default models (`default model documentation <../../api/autogluon.tabular.models.html#module-autogluon.tabular.models>`__).

In this example, we create a custom Random Forest model for use in AutoGluon. All models in AutoGluon inherit from the AbstractModel class (`AbstractModel source code <../../_modules/autogluon/core/models/abstract/abstract_model.html>`__) and must follow its API to work alongside other models.

Note that while this tutorial provides a basic model implementation, it does not cover many aspects that most of the built-in models implement. To best understand how to implement more advanced functionality, refer to the `source code <../../api/autogluon.tabular.models.html#module-autogluon.tabular.models>`__ of the following models:

=================================================== ================================================================================
Functionality                                        Reference Implementation
=================================================== ================================================================================
Respecting time limit / early stopping logic         `LGBModel <../../_modules/autogluon/tabular/models/lgb/lgb_model.html#LGBModel>`__ and `RFModel <../../_modules/autogluon/tabular/models/rf/rf_model.html#RFModel>`__
Respecting memory usage limit                        LGBModel and RFModel
Sample weight support                                LGBModel
Validation data and eval_metric usage                LGBModel
GPU training support                                 LGBModel
Save / load logic of non-serializable models         `NNFastAiTabularModel <../../_modules/autogluon/tabular/models/fastainn/tabular_nn_fastai.html#NNFastAiTabularModel>`__
Advanced problem type support (Softclass, Quantile)  RFModel
Text feature type support                            `TextPredictorModel <../../_modules/autogluon/tabular/models/text_prediction/text_prediction_v1_model.html#TextPredictorModel>`__
Image feature type support                           `ImagePredictorModel <../../_modules/autogluon/tabular/models/image_prediction/image_predictor.html#ImagePredictorModel>`__
Lazy import of package dependencies                  LGBModel
Custom HPO logic                                     LGBModel
=================================================== ================================================================================

Implementing a custom model
---------------------------

Here we define the custom model we will be working with for the rest of the tutorial.

The most important methods that must be implemented are ``_fit`` and ``_preprocess``.

To compare with the official AutoGluon Random Forest implementation, see the `RFModel <../../_modules/autogluon/tabular/models/rf/rf_model.html#RFModel>`__ source code.

Follow along with the code comments to better understand how the code works.

.. code:: python

    import numpy as np
    import pandas as pd

    from autogluon.core.models import AbstractModel
    from autogluon.features.generators import LabelEncoderFeatureGenerator

    class CustomRandomForestModel(AbstractModel):
        def __init__(self, **kwargs):
            # Simply pass along kwargs to parent, and init our internal `_feature_generator` variable to None
            super().__init__(**kwargs)
            self._feature_generator = None

        # The `_preprocess` method takes the input data and transforms it to the internal representation usable by the model.
        # `_preprocess` is called by `preprocess` and is used during model fit and model inference.
        def _preprocess(self, X: pd.DataFrame, is_train=False, **kwargs) -> np.ndarray:
            print(f'Entering the `_preprocess` method: {len(X)} rows of data (is_train={is_train})')
            X = super()._preprocess(X, **kwargs)

            if is_train:
                # X will be the training data.
                self._feature_generator = LabelEncoderFeatureGenerator(verbosity=0)
                self._feature_generator.fit(X=X)
            if self._feature_generator.features_in:
                # This converts categorical features to numeric via stateful label encoding.
                X = X.copy()
                X[self._feature_generator.features_in] = self._feature_generator.transform(X=X)
            # Add a fillna call to handle missing values.
            # Some algorithms will be able to handle NaN values internally (LightGBM).
            # In those cases, you can simply pass the NaN values into the inner model.
            # Finally, convert to numpy for optimized memory usage and because sklearn RF works with raw numpy input.
            return X.fillna(0).to_numpy(dtype=np.float32)

        # The `_fit` method takes the input training data (and optionally the validation data) and trains the model.
        def _fit(self,
                 X: pd.DataFrame,  # training data
                 y: pd.Series,  # training labels
                 # X_val=None,  # val data (unused in RF model)
                 # y_val=None,  # val labels (unused in RF model)
                 # time_limit=None,  # time limit in seconds (ignored in tutorial)
                 **kwargs):  # kwargs includes many other potential inputs, refer to AbstractModel documentation for details
            print('Entering the `_fit` method')

            # First we import the required dependencies for the model. Note that we do not import them outside of the method.
            # This enables AutoGluon to be highly extensible and modular.
            # For an example of best practices when importing model dependencies, refer to LGBModel.
            from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

            # Valid self.problem_type values include ['binary', 'multiclass', 'regression', 'quantile', 'softclass']
            if self.problem_type in ['regression', 'softclass']:
                model_cls = RandomForestRegressor
            else:
                model_cls = RandomForestClassifier

            # Make sure to call preprocess on X near the start of `_fit`.
            # This is necessary because the data is converted via preprocess during predict, and needs to be in the same format as during fit.
            X = self.preprocess(X, is_train=True)
            # This fetches the user-specified (and default) hyperparameters for the model.
            params = self._get_model_params()
            print(f'Hyperparameters: {params}')
            # self.model should be set to the trained inner model, so that internally during predict we can call `self.model.predict(...)`
            self.model = model_cls(**params)
            self.model.fit(X, y)
            print('Exiting the `_fit` method')

        # The `_set_default_params` method defines the default hyperparameters of the model.
        # User-specified parameters will override these values on a key-by-key basis.
        def _set_default_params(self):
            default_params = {
                'n_estimators': 300,
                'n_jobs': -1,
                'random_state': 0,
            }
            for param, val in default_params.items():
                self._set_default_param_value(param, val)

        # The `_get_default_auxiliary_params` method defines various model-agnostic parameters such as maximum memory usage and valid input column dtypes.
        # Most users who build custom models will only need to specify the valid/invalid dtypes to the model here.
        def _get_default_auxiliary_params(self) -> dict:
            default_auxiliary_params = super()._get_default_auxiliary_params()
            extra_auxiliary_params = dict(
                # the total set of raw dtypes are: ['int', 'float', 'category', 'object', 'datetime']
                # object feature dtypes include raw text and image paths, which should only be handled by specialized models
                # datetime raw dtypes are generally converted to int in upstream pre-processing,
                # so models generally shouldn't need to explicitly support datetime dtypes.
                valid_raw_types=['int', 'float', 'category'],
                # Other options include `valid_special_types`, `ignored_type_group_raw`, and `ignored_type_group_special`.
                # Refer to AbstractModel for more details on available options.
            )
            default_auxiliary_params.update(extra_auxiliary_params)
            return default_auxiliary_params

Loading the data
----------------

Next we will load the data. For this tutorial we will use the adult income dataset because it has a mix of integer and categorical (string) features.

.. code:: python

    from autogluon.tabular import TabularDataset

    train_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')  # can be local CSV file as well, returns Pandas DataFrame
    test_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')  # another Pandas DataFrame
    label = 'class'  # specifies which column we want to predict
    train_data = train_data.sample(n=1000, random_state=0)  # subsample for faster demo

    train_data.head(5)

.. raw:: html
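
    <div>
    <table border="1" class="dataframe">
      <thead>
        <tr style="text-align: right;"><th></th><th>age</th><th>workclass</th><th>fnlwgt</th><th>education</th><th>education-num</th><th>marital-status</th><th>occupation</th><th>relationship</th><th>race</th><th>sex</th><th>capital-gain</th><th>capital-loss</th><th>hours-per-week</th><th>native-country</th><th>class</th></tr>
      </thead>
      <tbody>
        <tr><th>6118</th><td>51</td><td>Private</td><td>39264</td><td>Some-college</td><td>10</td><td>Married-civ-spouse</td><td>Exec-managerial</td><td>Wife</td><td>White</td><td>Female</td><td>0</td><td>0</td><td>40</td><td>United-States</td><td>&gt;50K</td></tr>
        <tr><th>23204</th><td>58</td><td>Private</td><td>51662</td><td>10th</td><td>6</td><td>Married-civ-spouse</td><td>Other-service</td><td>Wife</td><td>White</td><td>Female</td><td>0</td><td>0</td><td>8</td><td>United-States</td><td>&lt;=50K</td></tr>
        <tr><th>29590</th><td>40</td><td>Private</td><td>326310</td><td>Some-college</td><td>10</td><td>Married-civ-spouse</td><td>Craft-repair</td><td>Husband</td><td>White</td><td>Male</td><td>0</td><td>0</td><td>44</td><td>United-States</td><td>&lt;=50K</td></tr>
        <tr><th>18116</th><td>37</td><td>Private</td><td>222450</td><td>HS-grad</td><td>9</td><td>Never-married</td><td>Sales</td><td>Not-in-family</td><td>White</td><td>Male</td><td>0</td><td>2339</td><td>40</td><td>El-Salvador</td><td>&lt;=50K</td></tr>
        <tr><th>33964</th><td>62</td><td>Private</td><td>109190</td><td>Bachelors</td><td>13</td><td>Married-civ-spouse</td><td>Exec-managerial</td><td>Husband</td><td>White</td><td>Male</td><td>15024</td><td>0</td><td>40</td><td>United-States</td><td>&gt;50K</td></tr>
      </tbody>
    </table>
    </div>
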
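The raw frame above mixes integer columns with string columns such as ``workclass``. To illustrate roughly what the custom model's ``_preprocess`` does with such columns, the following sketch re-implements the idea with plain pandas. Category codes are learned on the training data and reused for new data, and missing values are filled with 0 before the frame is converted to a ``float32`` array. ``ToyLabelEncoder`` is a hypothetical helper for illustration only. It is not AutoGluon's ``LabelEncoderFeatureGenerator``, whose handling of unseen categories may differ.

.. code:: python

    import numpy as np
    import pandas as pd


    class ToyLabelEncoder:
        """Hypothetical stand-in for the label-encoding step of `_preprocess`."""

        def __init__(self):
            self.categories_ = {}

        def fit(self, X: pd.DataFrame) -> "ToyLabelEncoder":
            for col in X.columns:
                if not pd.api.types.is_numeric_dtype(X[col]):
                    # Remember the (sorted) categories seen during training: stateful encoding.
                    self.categories_[col] = pd.Categorical(X[col]).categories
            return self

        def transform(self, X: pd.DataFrame) -> np.ndarray:
            X = X.copy()
            for col, cats in self.categories_.items():
                # Values not seen during fit (and missing values) get code -1.
                codes = pd.Categorical(X[col], categories=cats).codes.astype(float)
                codes[codes < 0] = np.nan
                X[col] = codes
            # Same final step as the tutorial's `_preprocess`: fill NaN with 0, convert to float32.
            return X.fillna(0).to_numpy(dtype=np.float32)


    train = pd.DataFrame({'age': [51, 58, None], 'workclass': ['Private', 'Self-emp', 'Private']})
    test = pd.DataFrame({'age': [40.0], 'workclass': ['Never-seen']})
    encoder = ToyLabelEncoder().fit(train)
    X_train_np = encoder.transform(train)  # categories learned on train only
    X_test_np = encoder.transform(test)    # unseen 'Never-seen' becomes NaN, then 0

In this sketch, filling with 0 makes an unseen or missing category indistinguishable from the first known category. This is a deliberate shortcut that mirrors the tutorial's ``fillna(0)`` call, not a recommendation.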
Training a custom model without TabularPredictor
------------------------------------------------

Below we will demonstrate how to train the model outside `TabularPredictor <../../api/autogluon.predictor.html#module-0>`__. This is useful for debugging and for minimizing the amount of code you need to understand while implementing the model.

This process is similar to what happens internally when calling fit on ``TabularPredictor``, but is simplified and minimal.

If the data were already clean (all numeric), we could call fit on it directly, but the adult dataset is not.

Clean labels
~~~~~~~~~~~~

The first step in turning the data into valid input for the model is to clean the labels.

Currently, they are strings, but we need to convert them to numeric values (0 and 1) for binary classification.

Luckily, AutoGluon already implements both the logic to detect that this is binary classification (``infer_problem_type``) and a converter that maps the labels to 0 and 1 (``LabelCleaner``):

.. code:: python

    # Separate features and labels
    X = train_data.drop(columns=[label])
    y = train_data[label]
    X_test = test_data.drop(columns=[label])
    y_test = test_data[label]

    from autogluon.core.data import LabelCleaner
    from autogluon.core.utils import infer_problem_type
    # Construct a LabelCleaner to neatly convert labels to float/integers during model training/inference, can also use to inverse_transform back to original.
    problem_type = infer_problem_type(y=y)  # Infer problem type (or else specify directly)
    label_cleaner = LabelCleaner.construct(problem_type=problem_type, y=y)
    y_clean = label_cleaner.transform(y)

    print(f'Labels cleaned: {label_cleaner.inv_map}')
    print(f'inferred problem type as: {problem_type}')
    print('Cleaned label values:')
    y_clean.head(5)


.. parsed-literal::
    :class: output

    Labels cleaned: {' <=50K': 0, ' >50K': 1}
    inferred problem type as: binary
    Cleaned label values:


.. parsed-literal::
    :class: output

    6118     1
    23204    0
    29590    0
    18116    0
    33964    1
    Name: class, dtype: uint8


Clean features
~~~~~~~~~~~~~~

Next, we need to clean the features. Currently, features like ‘workclass’ are object dtypes (strings), but we actually want to use them as categorical features. Most models won’t accept string inputs, so we need to convert the strings to numbers.

AutoGluon contains an entire module dedicated to cleaning, transforming, and generating features called `autogluon.features <../../api/autogluon.features.html>`__. Here we will use the same feature generator that ``TabularPredictor`` uses internally to convert the object dtypes to categorical and minimize memory usage.

.. code:: python

    from autogluon.common.utils.log_utils import set_logger_verbosity
    from autogluon.features.generators import AutoMLPipelineFeatureGenerator
    set_logger_verbosity(2)  # Set logger so more detailed logging is shown for tutorial

    feature_generator = AutoMLPipelineFeatureGenerator()
    X_clean = feature_generator.fit_transform(X)

    X_clean.head(5)


.. parsed-literal::
    :class: output

    Fitting AutoMLPipelineFeatureGenerator...
        Available Memory:                    31594.43 MB
        Train Data (Original)  Memory Usage: 0.59 MB (0.0% of available memory)
        Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
        Stage 1 Generators:
            Fitting AsTypeFeatureGenerator...
                Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
        Stage 2 Generators:
            Fitting FillNaFeatureGenerator...
        Stage 3 Generators:
            Fitting IdentityFeatureGenerator...
            Fitting CategoryFeatureGenerator...
                Fitting CategoryMemoryMinimizeFeatureGenerator...
        Stage 4 Generators:
            Fitting DropUniqueFeatureGenerator...
        Types of features in original data (raw dtype, special dtypes):
            ('int', [])    : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
            ('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
        Types of features in processed data (raw dtype, special dtypes):
            ('category', [])  : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
            ('int', [])       : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
            ('int', ['bool']) : 1 | ['sex']
        0.1s = Fit runtime
        14 features in original data used to generate 14 features in processed data.
        Train Data (Processed) Memory Usage: 0.07 MB (0.0% of available memory)

.. raw:: html
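
    <div>
    <table border="1" class="dataframe">
      <thead>
        <tr style="text-align: right;"><th></th><th>age</th><th>fnlwgt</th><th>education-num</th><th>sex</th><th>capital-gain</th><th>capital-loss</th><th>hours-per-week</th><th>workclass</th><th>education</th><th>marital-status</th><th>occupation</th><th>relationship</th><th>race</th><th>native-country</th></tr>
      </thead>
      <tbody>
        <tr><th>6118</th><td>51</td><td>39264</td><td>10</td><td>0</td><td>0</td><td>0</td><td>40</td><td>3</td><td>14</td><td>1</td><td>4</td><td>5</td><td>4</td><td>24</td></tr>
        <tr><th>23204</th><td>58</td><td>51662</td><td>6</td><td>0</td><td>0</td><td>0</td><td>8</td><td>3</td><td>0</td><td>1</td><td>8</td><td>5</td><td>4</td><td>24</td></tr>
        <tr><th>29590</th><td>40</td><td>326310</td><td>10</td><td>1</td><td>0</td><td>0</td><td>44</td><td>3</td><td>14</td><td>1</td><td>3</td><td>0</td><td>4</td><td>24</td></tr>
        <tr><th>18116</th><td>37</td><td>222450</td><td>9</td><td>1</td><td>0</td><td>2339</td><td>40</td><td>3</td><td>11</td><td>3</td><td>12</td><td>1</td><td>4</td><td>6</td></tr>
        <tr><th>33964</th><td>62</td><td>109190</td><td>13</td><td>1</td><td>15024</td><td>0</td><td>40</td><td>3</td><td>9</td><td>1</td><td>4</td><td>0</td><td>4</td><td>24</td></tr>
      </tbody>
    </table>
    </div>
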
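In the processed frame above, ``sex`` has become a 0/1 integer column. The other string columns now appear as small integers, consistent with the log's note that they were converted to the ``category`` dtype. The sketch below approximates those two conversions on a toy frame with plain pandas. It rests on simplifying assumptions: the hypothetical ``toy_clean_features`` helper treats every non-numeric two-valued column as boolean and every other non-numeric column as categorical. It is an illustration, not how ``AutoMLPipelineFeatureGenerator`` is implemented.

.. code:: python

    import pandas as pd


    def toy_clean_features(X: pd.DataFrame) -> pd.DataFrame:
        """Hypothetical, simplified version of the object -> bool/category conversion."""
        X = X.copy()
        for col in X.columns:
            if pd.api.types.is_numeric_dtype(X[col]):
                continue  # numeric columns are passed through unchanged
            if X[col].nunique(dropna=True) == 2:
                # Two-valued string column -> 0/1 integers (first value in sorted order maps to 0).
                first = sorted(X[col].dropna().unique())[0]
                X[col] = (X[col] != first).astype(int)
            else:
                # Any other string column -> pandas categorical dtype.
                X[col] = X[col].astype('category')
        return X


    raw = pd.DataFrame({
        'age': [51, 58, 40],
        'sex': ['Female', 'Female', 'Male'],
        'workclass': ['Private', 'Self-emp', 'State-gov'],
    })
    cleaned = toy_clean_features(raw)
    print(cleaned.dtypes)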
`AutoMLPipelineFeatureGenerator <../../api/autogluon.features.html#automlpipelinefeaturegenerator>`__ does not fill missing values for numeric features, rescale numeric features, or one-hot encode categoricals. If a model requires these operations, you’ll need to add them to your ``_preprocess`` method, where some FeatureGenerator classes may be useful.

Fit model
~~~~~~~~~

We are now ready to fit the model with the cleaned features and labels.

.. code:: python

    custom_model = CustomRandomForestModel()
    # We could also specify hyperparameters to override defaults
    # custom_model = CustomRandomForestModel(hyperparameters={'max_depth': 10})
    custom_model.fit(X=X_clean, y=y_clean)  # Fit custom model

    # To save to disk and load the model, do the following:
    # load_path = custom_model.path
    # custom_model.save()
    # del custom_model
    # custom_model = CustomRandomForestModel.load(path=load_path)


.. parsed-literal::
    :class: output

    Warning: No name was specified for model, defaulting to class name: CustomRandomForestModel
    No path specified. Models will be saved in: "AutogluonModels/ag-20221213_014228/CustomRandomForestModel/"
    Warning: No path was specified for model, defaulting to: AutogluonModels/ag-20221213_014228/
    AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
        2 unique label values: [1, 0]
        If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    Selected class <--> label mapping: class 1 = 1, class 0 = 0
    Model CustomRandomForestModel's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.


.. parsed-literal::
    :class: output

    Entering the `_fit` method
    Entering the `_preprocess` method: 1000 rows of data (is_train=True)
    Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
    Exiting the `_fit` method


.. parsed-literal::
    :class: output

    <__main__.CustomRandomForestModel at 0x7f90e2983a00>


Predict with trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~

Now that the model is fit, we can make predictions on new data. Remember that we need to apply the same data and label transformations to the new data as we did to the training data.

.. code:: python

    # Prepare test data
    X_test_clean = feature_generator.transform(X_test)
    y_test_clean = label_cleaner.transform(y_test)

    X_test.head(5)

.. raw:: html
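
    <div>
    <table border="1" class="dataframe">
      <thead>
        <tr style="text-align: right;"><th></th><th>age</th><th>workclass</th><th>fnlwgt</th><th>education</th><th>education-num</th><th>marital-status</th><th>occupation</th><th>relationship</th><th>race</th><th>sex</th><th>capital-gain</th><th>capital-loss</th><th>hours-per-week</th><th>native-country</th></tr>
      </thead>
      <tbody>
        <tr><th>0</th><td>31</td><td>Private</td><td>169085</td><td>11th</td><td>7</td><td>Married-civ-spouse</td><td>Sales</td><td>Wife</td><td>White</td><td>Female</td><td>0</td><td>0</td><td>20</td><td>United-States</td></tr>
        <tr><th>1</th><td>17</td><td>Self-emp-not-inc</td><td>226203</td><td>12th</td><td>8</td><td>Never-married</td><td>Sales</td><td>Own-child</td><td>White</td><td>Male</td><td>0</td><td>0</td><td>45</td><td>United-States</td></tr>
        <tr><th>2</th><td>47</td><td>Private</td><td>54260</td><td>Assoc-voc</td><td>11</td><td>Married-civ-spouse</td><td>Exec-managerial</td><td>Husband</td><td>White</td><td>Male</td><td>0</td><td>1887</td><td>60</td><td>United-States</td></tr>
        <tr><th>3</th><td>21</td><td>Private</td><td>176262</td><td>Some-college</td><td>10</td><td>Never-married</td><td>Exec-managerial</td><td>Own-child</td><td>White</td><td>Female</td><td>0</td><td>0</td><td>30</td><td>United-States</td></tr>
        <tr><th>4</th><td>17</td><td>Private</td><td>241185</td><td>12th</td><td>8</td><td>Never-married</td><td>Prof-specialty</td><td>Own-child</td><td>White</td><td>Male</td><td>0</td><td>0</td><td>20</td><td>United-States</td></tr>
      </tbody>
    </table>
    </div>
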
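The tutorial maps labels to 0/1 with ``label_cleaner.transform`` and back to the original strings with ``label_cleaner.inverse_transform``. Conceptually, a binary label cleaner is a pair of dictionaries. The following sketch builds them with a hypothetical ``build_binary_label_maps`` helper, which sorts the two label values and assigns 0 and 1 in that order. This ordering rule is an assumption of the sketch, and ``LabelCleaner`` may choose the positive class differently.

.. code:: python

    import pandas as pd


    def build_binary_label_maps(y: pd.Series):
        """Hypothetical helper: forward (label -> int) and inverse (int -> label) maps."""
        classes = sorted(y.dropna().unique())
        if len(classes) != 2:
            raise ValueError(f'expected 2 classes, got {len(classes)}')
        to_int = {label: i for i, label in enumerate(classes)}
        to_label = {i: label for label, i in to_int.items()}
        return to_int, to_label


    y_train = pd.Series([' >50K', ' <=50K', ' <=50K'])
    to_int, to_label = build_binary_label_maps(y_train)
    y_train_clean = y_train.map(to_int)  # strings -> 0/1 for training
    y_pred_int = pd.Series([0, 0, 1])     # e.g. raw model predictions
    y_pred_labels = y_pred_int.map(to_label)  # 0/1 -> original strings
    print(to_int)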
Get raw predictions from the test data:

.. code:: python

    y_pred = custom_model.predict(X_test_clean)
    print(y_pred[:5])


.. parsed-literal::
    :class: output

    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    [0, 0, 1, 0, 0]


Note that these predictions are in the cleaned label space: 1 denotes the positive class (whichever original label was mapped to 1) and 0 the negative class. To convert them back to the original labels, do the following:

.. code:: python

    y_pred_orig = label_cleaner.inverse_transform(y_pred)
    y_pred_orig.head(5)


.. parsed-literal::
    :class: output

    0     <=50K
    1     <=50K
    2      >50K
    3     <=50K
    4     <=50K
    dtype: object


Score with trained model
~~~~~~~~~~~~~~~~~~~~~~~~

By default, the model has an eval_metric specific to the problem_type. For binary classification, it uses accuracy.

We can get the accuracy score of the model by doing the following:

.. code:: python

    score = custom_model.score(X_test_clean, y_test_clean)
    print(f'Test score ({custom_model.eval_metric.name}) = {score}')


.. parsed-literal::
    :class: output

    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    Test score (accuracy) = 0.8424608455317842


Training a bagged custom model without TabularPredictor
-------------------------------------------------------

Some of the more advanced functionality in AutoGluon, such as bagging, can be applied very easily to any model that inherits from AbstractModel.

You can even bag your custom model in a few lines of code. This is a quick way to get quality improvements on nearly any model:

.. code:: python

    from autogluon.core.models import BaggedEnsembleModel
    bagged_custom_model = BaggedEnsembleModel(CustomRandomForestModel())
    # Parallel folding currently doesn't work with a class not defined in a separate module because of an underlying pickle serialization issue.
    # You don't need the following line if you put your custom model in a separate file and import it.
    bagged_custom_model.params['fold_fitting_strategy'] = 'sequential_local'
    bagged_custom_model.fit(X=X_clean, y=y_clean, k_fold=10)  # Perform 10-fold bagging
    bagged_score = bagged_custom_model.score(X_test_clean, y_test_clean)
    print(f'Test score ({bagged_custom_model.eval_metric.name}) = {bagged_score} (bagged)')
    print(f'Bagging increased model accuracy by {round((bagged_score - score) * 100, 2)}%!')


.. parsed-literal::
    :class: output

    Warning: No name was specified for model, defaulting to class name: CustomRandomForestModel
    No path specified. Models will be saved in: "AutogluonModels/ag-20221213_014230/CustomRandomForestModel/"
    Warning: No path was specified for model, defaulting to: AutogluonModels/ag-20221213_014230/
    Warning: No name was specified for model, defaulting to class name: BaggedEnsembleModel
    No path specified. Models will be saved in: "AutogluonModels/ag-20221213_014230/BaggedEnsembleModel/"
    Warning: No path was specified for model, defaulting to: AutogluonModels/ag-20221213_014230/
    AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
        2 unique label values: [1, 0]
        If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    Selected class <--> label mapping: class 1 = 1, class 0 = 0
    AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
        2 unique label values: [1, 0]
        If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    Selected class <--> label mapping: class 1 = 1, class 0 = 0
    Model CustomRandomForestModel's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
    AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
        2 unique label values: [1, 0]
        If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    Model 's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
    Fitting 10 child models (S1F1 - S1F10) | Fitting with SequentialLocalFoldFittingStrategy
    AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
        2 unique label values: [1, 0]
        If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    Model S1F1's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.


.. parsed-literal::
    :class: output

    Entering the `_fit` method
    Entering the `_preprocess` method: 900 rows of data (is_train=True)
    Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}


.. parsed-literal::
    :class: output

    AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
        2 unique label values: [1, 0]
        If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    Model S1F2's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.


.. parsed-literal::
    :class: output

    Exiting the `_fit` method
    Entering the `_preprocess` method: 100 rows of data (is_train=False)
    Entering the `_fit` method
    Entering the `_preprocess` method: 900 rows of data (is_train=True)
    Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}


.. parsed-literal::
    :class: output

    AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
        2 unique label values: [1, 0]
        If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    Model S1F3's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.


.. parsed-literal::
    :class: output

    Exiting the `_fit` method
    Entering the `_preprocess` method: 100 rows of data (is_train=False)
    Entering the `_fit` method
    Entering the `_preprocess` method: 900 rows of data (is_train=True)
    Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}


.. parsed-literal::
    :class: output

    AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
        2 unique label values: [1, 0]
        If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    Model S1F4's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.


.. parsed-literal::
    :class: output

    Exiting the `_fit` method
    Entering the `_preprocess` method: 100 rows of data (is_train=False)
    Entering the `_fit` method
    Entering the `_preprocess` method: 900 rows of data (is_train=True)
    Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}


.. parsed-literal::
    :class: output

    AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
        2 unique label values: [1, 0]
        If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    Model S1F5's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.


.. parsed-literal::
    :class: output

    Exiting the `_fit` method
    Entering the `_preprocess` method: 100 rows of data (is_train=False)
    Entering the `_fit` method
    Entering the `_preprocess` method: 900 rows of data (is_train=True)
    Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}


.. parsed-literal::
    :class: output

    AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
        2 unique label values: [1, 0]
        If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    Model S1F6's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.


.. parsed-literal::
    :class: output

    Exiting the `_fit` method
    Entering the `_preprocess` method: 100 rows of data (is_train=False)
    Entering the `_fit` method
    Entering the `_preprocess` method: 900 rows of data (is_train=True)
    Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}


.. parsed-literal::
    :class: output

    AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
        2 unique label values: [0, 1]
        If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    Model S1F7's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.


.. parsed-literal::
    :class: output

    Exiting the `_fit` method
    Entering the `_preprocess` method: 100 rows of data (is_train=False)
    Entering the `_fit` method
    Entering the `_preprocess` method: 900 rows of data (is_train=True)
    Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}


.. parsed-literal::
    :class: output

    AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
        2 unique label values: [1, 0]
        If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    Model S1F8's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.


.. parsed-literal::
    :class: output

    Exiting the `_fit` method
    Entering the `_preprocess` method: 100 rows of data (is_train=False)
    Entering the `_fit` method
    Entering the `_preprocess` method: 900 rows of data (is_train=True)
    Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}


.. parsed-literal::
    :class: output

    AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
        2 unique label values: [1, 0]
        If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    Model S1F9's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.


.. parsed-literal::
    :class: output

    Exiting the `_fit` method
    Entering the `_preprocess` method: 100 rows of data (is_train=False)
    Entering the `_fit` method
    Entering the `_preprocess` method: 900 rows of data (is_train=True)
    Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}


.. parsed-literal::
    :class: output

    AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
        2 unique label values: [1, 0]
        If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    Model S1F10's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.


.. parsed-literal::
    :class: output

    Exiting the `_fit` method
    Entering the `_preprocess` method: 100 rows of data (is_train=False)
    Entering the `_fit` method
    Entering the `_preprocess` method: 900 rows of data (is_train=True)
    Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
    Exiting the `_fit` method
    Entering the `_preprocess` method: 100 rows of data (is_train=False)
    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    Test score (accuracy) = 0.8436892210052206 (bagged)
    Bagging increased model accuracy by 0.12%!

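In k-fold bagging, each child model is fit on k-1 folds, and at inference time the children's predictions (here, class probabilities) are averaged. The NumPy sketch below shows only that fit-per-fold and averaging logic. ``fit_nearest_centroid`` is a hypothetical toy model standing in for ``CustomRandomForestModel``. The sketch also skips the out-of-fold step: in the log above, the 100-row ``_preprocess`` calls appear to be each child predicting on its held-out fold, which is omitted here.

.. code:: python

    import numpy as np


    def fit_nearest_centroid(X, y):
        """Hypothetical toy classifier: softmax over negative squared distances to class centroids."""
        classes = np.unique(y)  # assumes every class is present in each training split
        centroids = np.stack([X[y == c].mean(axis=0) for c in classes])

        def predict_proba(X_new):
            d = ((X_new[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
            logits = -d - (-d).max(axis=1, keepdims=True)  # shift for numerical stability
            p = np.exp(logits)
            return p / p.sum(axis=1, keepdims=True)

        return predict_proba


    def bagged_predict_proba(X, y, X_test, fit_fn, k_fold=10, seed=0):
        """Fit one child per fold (trained on the other k-1 folds) and average their probabilities."""
        rng = np.random.default_rng(seed)
        folds = np.array_split(rng.permutation(len(X)), k_fold)
        probas = []
        for held_out in folds:
            train_idx = np.setdiff1d(np.arange(len(X)), held_out)
            child_predict_proba = fit_fn(X[train_idx], y[train_idx])
            probas.append(child_predict_proba(X_test))
        return np.mean(probas, axis=0)


    rng = np.random.default_rng(0)
    X = np.vstack([rng.normal(0, 1, (50, 2)), rng.normal(5, 1, (50, 2))])
    y = np.array([0] * 50 + [1] * 50)
    X_test = np.array([[0.0, 0.0], [5.0, 5.0]])
    proba = bagged_predict_proba(X, y, X_test, fit_nearest_centroid, k_fold=10)
    print(proba.argmax(axis=1))  # -> [0 1]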
Note that the bagged model trained 10 CustomRandomForestModel instances on different splits of the training data. When making a prediction, the bagged model averages the predictions of these 10 models.

Training a custom model with TabularPredictor
---------------------------------------------

Working without `TabularPredictor <../../api/autogluon.predictor.html#module-0>`__ reduces the amount of code we need to worry about while developing and debugging our model, but eventually we want to leverage TabularPredictor to get the most out of our model.

Training the model from the raw data is very simple with TabularPredictor: there is no need to specify a LabelCleaner, a FeatureGenerator, or a validation set, because all of that is handled internally.

Here we train three CustomRandomForestModel instances with different hyperparameters.

.. code:: python

    from autogluon.tabular import TabularPredictor

    # custom_hyperparameters = {CustomRandomForestModel: {}}  # Train 1 CustomRandomForestModel with default hyperparameters
    custom_hyperparameters = {CustomRandomForestModel: [{}, {'max_depth': 10}, {'max_features': 0.9, 'max_depth': 20}]}  # Train 3 CustomRandomForestModels with different hyperparameters
    predictor = TabularPredictor(label=label).fit(train_data, hyperparameters=custom_hyperparameters)


.. parsed-literal::
    :class: output

    No path specified. Models will be saved in: "AutogluonModels/ag-20221213_014239/"
    Beginning AutoGluon training ...
    AutoGluon will save models to "AutogluonModels/ag-20221213_014239/"
    AutoGluon Version: 0.6.1b20221213
    Python Version: 3.8.10
    Operating System: Linux
    Platform Machine: x86_64
    Platform Version: #1 SMP Tue Nov 30 00:17:50 UTC 2021
    Train Data Rows: 1000
    Train Data Columns: 14
    Label Column: class
    Preprocessing data ...
    AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
    2 unique label values: [' >50K', ' <=50K']
    If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
    Selected class <--> label mapping: class 1 = >50K, class 0 = <=50K
    Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
    To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
    Using Feature Generators to preprocess the data ...
    Fitting AutoMLPipelineFeatureGenerator...
    Available Memory: 31555.11 MB
    Train Data (Original) Memory Usage: 0.59 MB (0.0% of available memory)
    Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
    Stage 1 Generators:
    Fitting AsTypeFeatureGenerator...
    Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
    Stage 2 Generators:
    Fitting FillNaFeatureGenerator...
    Stage 3 Generators:
    Fitting IdentityFeatureGenerator...
    Fitting CategoryFeatureGenerator...
    Fitting CategoryMemoryMinimizeFeatureGenerator...
    Stage 4 Generators:
    Fitting DropUniqueFeatureGenerator...
    Types of features in original data (raw dtype, special dtypes):
    ('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
    ('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
    Types of features in processed data (raw dtype, special dtypes):
    ('category', []) : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
    ('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
    ('int', ['bool']) : 1 | ['sex']
    0.1s = Fit runtime
    14 features in original data used to generate 14 features in processed data.
    Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
    Data preprocessing and feature engineering runtime = 0.09s ...
    AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
    To change this, specify the eval_metric parameter of Predictor()
    Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 800, Val Rows: 200
    Custom Model Type Detected:
    Custom Model Type Detected:
    Custom Model Type Detected:
    Fitting 3 L1 models ...
    Fitting model: CustomRandomForestModel ...


.. parsed-literal::
    :class: output

    Entering the `_fit` method
    Entering the `_preprocess` method: 800 rows of data (is_train=True)
    Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}


.. parsed-literal::
    :class: output

    0.84 = Validation score (accuracy)
    1.2s = Training runtime
    0.06s = Validation runtime
    Fitting model: CustomRandomForestModel_2 ...


.. parsed-literal::
    :class: output

    Exiting the `_fit` method
    Entering the `_preprocess` method: 200 rows of data (is_train=False)
    Entering the `_fit` method
    Entering the `_preprocess` method: 800 rows of data (is_train=True)
    Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 10}


.. parsed-literal::
    :class: output

    0.845 = Validation score (accuracy)
    1.19s = Training runtime
    0.06s = Validation runtime
    Fitting model: CustomRandomForestModel_3 ...


.. parsed-literal::
    :class: output

    Exiting the `_fit` method
    Entering the `_preprocess` method: 200 rows of data (is_train=False)
    Entering the `_fit` method
    Entering the `_preprocess` method: 800 rows of data (is_train=True)
    Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_features': 0.9, 'max_depth': 20}


.. parsed-literal::
    :class: output

    0.835 = Validation score (accuracy)
    1.19s = Training runtime
    0.06s = Validation runtime
    Fitting model: WeightedEnsemble_L2 ...


.. parsed-literal::
    :class: output

    Exiting the `_fit` method
    Entering the `_preprocess` method: 200 rows of data (is_train=False)


.. parsed-literal::
    :class: output

    0.855 = Validation score (accuracy)
    0.11s = Training runtime
    0.0s = Validation runtime
    AutoGluon training complete, total runtime = 4.65s ... Best model: "WeightedEnsemble_L2"
    TabularPredictor saved. To load, use: predictor = TabularPredictor.load("AutogluonModels/ag-20221213_014239/")


Predictor leaderboard
~~~~~~~~~~~~~~~~~~~~~

Here we show the statistics of each trained model. Notice that a WeightedEnsemble model was also trained. This model combines the predictions of the other models to try to achieve a better validation score through ensembling.

.. code:: python

    predictor.leaderboard(test_data, silent=True)


.. parsed-literal::
    :class: output

    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    Entering the `_preprocess` method: 9769 rows of data (is_train=False)


.. csv-table::
    :header: "", "model", "score_test", "score_val", "pred_time_test", "pred_time_val", "fit_time", "pred_time_test_marginal", "pred_time_val_marginal", "fit_time_marginal", "stack_level", "can_infer", "fit_order"

    0, CustomRandomForestModel_2, 0.846044, 0.845, 0.138547, 0.062872, 1.193779, 0.138547, 0.062872, 1.193779, 1, True, 2
    1, CustomRandomForestModel, 0.840414, 0.840, 0.141398, 0.061999, 1.199242, 0.141398, 0.061999, 1.199242, 1, True, 1
    2, WeightedEnsemble_L2, 0.839390, 0.855, 0.421377, 0.188749, 3.698206, 0.003341, 0.000844, 0.113758, 2, True, 4
    3, CustomRandomForestModel_3, 0.828744, 0.835, 0.138091, 0.063035, 1.191428, 0.138091, 0.063035, 1.191428, 1, True, 3
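In this run, the model with the highest validation score (WeightedEnsemble_L2, score_val 0.855) is not the one with the highest test score (CustomRandomForestModel_2, score_test 0.846044). The short pandas sketch below ranks a hand-copied subset of the leaderboard both ways to make that distinction explicit; only the numbers come from this tutorial, and the DataFrame is built by hand rather than returned by ``predictor.leaderboard``.

.. code:: python

    import pandas as pd

    # Subset of the leaderboard columns shown above (values copied from this run)
    leaderboard = pd.DataFrame({
        'model': ['CustomRandomForestModel_2', 'CustomRandomForestModel',
                  'WeightedEnsemble_L2', 'CustomRandomForestModel_3'],
        'score_test': [0.846044, 0.840414, 0.839390, 0.828744],
        'score_val': [0.845, 0.840, 0.855, 0.835],
    })

    # The predictor picks its best model from the validation split, so rank by score_val
    best_by_val = leaderboard.loc[leaderboard['score_val'].idxmax(), 'model']
    # Test data is never seen during fit, so score_test plays no role in that choice
    best_by_test = leaderboard.loc[leaderboard['score_test'].idxmax(), 'model']

    print(best_by_val)   # WeightedEnsemble_L2
    print(best_by_test)  # CustomRandomForestModel_2

With 200 validation rows, one row is worth 0.005 accuracy, so the 0.01 gap in score_val between WeightedEnsemble_L2 and CustomRandomForestModel_2 amounts to just two validation rows.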
Predict with fit predictor
~~~~~~~~~~~~~~~~~~~~~~~~~~

Here we predict with the fitted predictor. By default, this uses the best model (the one with the highest score_val).

.. code:: python

    y_pred = predictor.predict(test_data)
    # y_pred = predictor.predict(test_data, model='CustomRandomForestModel_3')  # If we want a specific model to predict
    y_pred.head(5)


.. parsed-literal::
    :class: output

    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    Entering the `_preprocess` method: 9769 rows of data (is_train=False)
    Entering the `_preprocess` method: 9769 rows of data (is_train=False)


.. parsed-literal::
    :class: output

    0 <=50K
    1 <=50K
    2 >50K
    3 <=50K
    4 <=50K
    Name: class, dtype: object


Hyperparameter tuning a custom model with TabularPredictor
----------------------------------------------------------

We can easily hyperparameter-tune custom models by specifying a hyperparameter search space in place of exact values. Here we tune the custom model for 20 seconds:

.. code:: python

    from autogluon.core.space import Categorical, Int, Real

    custom_hyperparameters_hpo = {CustomRandomForestModel: {
        'max_depth': Int(lower=5, upper=30),
        'max_features': Real(lower=0.1, upper=1.0),
        'criterion': Categorical('gini', 'entropy'),
    }}
    # Hyperparameter tune CustomRandomForestModel for 20 seconds
    predictor = TabularPredictor(label=label).fit(train_data,
                                                  hyperparameters=custom_hyperparameters_hpo,
                                                  hyperparameter_tune_kwargs='auto',  # enables HPO
                                                  time_limit=20)

.. parsed-literal:: :class: output No path specified. Models will be saved in: "AutogluonModels/ag-20221213_014245/" Warning: hyperparameter tuning is currently experimental and may cause the process to hang. Beginning AutoGluon training ...
Time limit = 20s AutoGluon will save models to "AutogluonModels/ag-20221213_014245/" AutoGluon Version: 0.6.1b20221213 Python Version: 3.8.10 Operating System: Linux Platform Machine: x86_64 Platform Version: #1 SMP Tue Nov 30 00:17:50 UTC 2021 Train Data Rows: 1000 Train Data Columns: 14 Label Column: class Preprocessing data ... AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed). 2 unique label values: [' >50K', ' <=50K'] If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression']) Selected class <--> label mapping: class 1 = >50K, class 0 = <=50K Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class. To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init. Using Feature Generators to preprocess the data ... Fitting AutoMLPipelineFeatureGenerator... Available Memory: 31553.64 MB Train Data (Original) Memory Usage: 0.59 MB (0.0% of available memory) Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features. Stage 1 Generators: Fitting AsTypeFeatureGenerator... Note: Converting 1 features to boolean dtype as they only contain 2 unique values. Stage 2 Generators: Fitting FillNaFeatureGenerator... Stage 3 Generators: Fitting IdentityFeatureGenerator... Fitting CategoryFeatureGenerator... Fitting CategoryMemoryMinimizeFeatureGenerator... Stage 4 Generators: Fitting DropUniqueFeatureGenerator... Types of features in original data (raw dtype, special dtypes): ('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...] ('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...] 
Types of features in processed data (raw dtype, special dtypes): ('category', []) : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...] ('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...] ('int', ['bool']) : 1 | ['sex'] 0.1s = Fit runtime 14 features in original data used to generate 14 features in processed data. Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory) Data preprocessing and feature engineering runtime = 0.1s ... AutoGluon will gauge predictive performance using evaluation metric: 'accuracy' To change this, specify the eval_metric parameter of Predictor() Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 800, Val Rows: 200 Custom Model Type Detected: Fitting 1 L1 models ... Hyperparameter tuning model: CustomRandomForestModel ... Tuning model for up to 17.91s of the 19.9s of remaining time. .. parsed-literal:: :class: output Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 5, 'max_features': 0.1, 'criterion': 'gini'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 20, 'max_features': 0.7436704297351775, 'criterion': 'gini'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 8, 'max_features': 0.8625265649057129, 'criterion': 'entropy'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method 
Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 11, 'max_features': 0.15104167958569886, 'criterion': 'entropy'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 6, 'max_features': 0.8125525342743981, 'criterion': 'gini'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 19, 'max_features': 0.6112401049845391, 'criterion': 'entropy'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 30, 'max_features': 0.16393245237809825, 'criterion': 'entropy'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 25, 'max_features': 0.11819655769629316, 'criterion': 'entropy'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data 
(is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 10, 'max_features': 0.8003410758548655, 'criterion': 'entropy'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 5, 'max_features': 0.9807565080094875, 'criterion': 'entropy'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 22, 'max_features': 0.5153314260276387, 'criterion': 'gini'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 24, 'max_features': 0.20644698328203992, 'criterion': 'entropy'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 6, 'max_features': 0.22901795866814179, 'criterion': 'gini'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 5, 'max_features': 0.5696634895750645, 'criterion': 'entropy'} Exiting the `_fit` method Entering the 
`_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 28, 'max_features': 0.3381000508941643, 'criterion': 'gini'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 23, 'max_features': 0.5105352989948937, 'criterion': 'entropy'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 5, 'max_features': 0.11691082039271963, 'criterion': 'gini'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 10, 'max_features': 0.6508861504501793, 'criterion': 'entropy'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 22, 'max_features': 0.9493732706631618, 'criterion': 'gini'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 15, 'max_features': 0.42355711051640743, 'criterion': 'entropy'} Exiting the 
`_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 6, 'max_features': 0.7278680763345383, 'criterion': 'gini'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 30, 'max_features': 0.7000900439011009, 'criterion': 'entropy'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 16, 'max_features': 0.2893443049664568, 'criterion': 'entropy'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_fit` method Entering the `_preprocess` method: 800 rows of data (is_train=True) Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 5, 'max_features': 0.38388551583176544, 'criterion': 'entropy'} Exiting the `_fit` method Entering the `_preprocess` method: 200 rows of data (is_train=False) .. parsed-literal:: :class: output Stopping HPO to satisfy time limit... Fitted model: CustomRandomForestModel/T1 ... 0.805 = Validation score (accuracy) 0.57s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T2 ... 0.835 = Validation score (accuracy) 0.59s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T3 ... 0.825 = Validation score (accuracy) 0.59s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T4 ... 
0.855 = Validation score (accuracy) 0.59s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T5 ... 0.835 = Validation score (accuracy) 0.6s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T6 ... 0.83 = Validation score (accuracy) 0.57s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T7 ... 0.845 = Validation score (accuracy) 0.59s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T8 ... 0.845 = Validation score (accuracy) 0.61s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T9 ... 0.84 = Validation score (accuracy) 0.61s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T10 ... 0.845 = Validation score (accuracy) 0.58s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T11 ... 0.85 = Validation score (accuracy) 0.58s = Training runtime 0.07s = Validation runtime Fitted model: CustomRandomForestModel/T12 ... 0.835 = Validation score (accuracy) 0.58s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T13 ... 0.84 = Validation score (accuracy) 0.6s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T14 ... 0.835 = Validation score (accuracy) 0.58s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T15 ... 0.845 = Validation score (accuracy) 0.57s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T16 ... 0.85 = Validation score (accuracy) 0.59s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T17 ... 0.85 = Validation score (accuracy) 0.58s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T18 ... 0.805 = Validation score (accuracy) 0.57s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T19 ... 
0.845 = Validation score (accuracy) 0.57s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T20 ... 0.835 = Validation score (accuracy) 0.59s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T21 ... 0.85 = Validation score (accuracy) 0.59s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T22 ... 0.83 = Validation score (accuracy) 0.58s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T23 ... 0.845 = Validation score (accuracy) 0.59s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T24 ... 0.845 = Validation score (accuracy) 0.59s = Training runtime 0.06s = Validation runtime Fitted model: CustomRandomForestModel/T25 ... 0.845 = Validation score (accuracy) 0.57s = Training runtime 0.06s = Validation runtime .. parsed-literal:: :class: output Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 
rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False) Entering the `_preprocess` method: 200 rows of data (is_train=False)


.. parsed-literal::
    :class: output

    Fitting model: WeightedEnsemble_L2 ...
    Training model for up to 19.9s of the -0.74s of remaining time.


.. parsed-literal::
    :class: output

    Entering the `_preprocess` method: 200 rows of data (is_train=False)


.. parsed-literal::
    :class: output

    0.86 = Validation score (accuracy)
    0.16s = Training runtime
    0.0s = Validation runtime
    AutoGluon training complete, total runtime = 21.52s ... Best model: "WeightedEnsemble_L2"
    TabularPredictor saved. To load, use: predictor = TabularPredictor.load("AutogluonModels/ag-20221213_014245/")


Predictor leaderboard (HPO)
~~~~~~~~~~~~~~~~~~~~~~~~~~~

The leaderboard for the HPO run shows models with the suffix ``'/Tx'`` in their names, where ``x`` is the number of the HPO trial that produced the model.

.. code:: python

    leaderboard_hpo = predictor.leaderboard(silent=True)
    leaderboard_hpo


.. csv-table::
    :header: "", "model", "score_val", "pred_time_val", "fit_time", "pred_time_val_marginal", "fit_time_marginal", "stack_level", "can_infer", "fit_order"

    0, WeightedEnsemble_L2, 0.860, 0.124356, 1.326905, 0.000749, 0.157713, 2, True, 26
    1, CustomRandomForestModel/T4, 0.855, 0.060825, 0.591918, 0.060825, 0.591918, 1, True, 4
    2, CustomRandomForestModel/T21, 0.850, 0.061973, 0.589022, 0.061973, 0.589022, 1, True, 21
    3, CustomRandomForestModel/T16, 0.850, 0.062684, 0.590520, 0.062684, 0.590520, 1, True, 16
    4, CustomRandomForestModel/T17, 0.850, 0.062782, 0.577274, 0.062782, 0.577274, 1, True, 17
    5, CustomRandomForestModel/T11, 0.850, 0.065084, 0.581048, 0.065084, 0.581048, 1, True, 11
    6, CustomRandomForestModel/T15, 0.845, 0.060246, 0.568683, 0.060246, 0.568683, 1, True, 15
    7, CustomRandomForestModel/T10, 0.845, 0.060620, 0.582077, 0.060620, 0.582077, 1, True, 10
    8, CustomRandomForestModel/T7, 0.845, 0.061639, 0.587917, 0.061639, 0.587917, 1, True, 7
    9, CustomRandomForestModel/T24, 0.845, 0.061696, 0.586394, 0.061696, 0.586394, 1, True, 24
    10, CustomRandomForestModel/T25, 0.845, 0.062128, 0.568804, 0.062128, 0.568804, 1, True, 25
    11, CustomRandomForestModel/T8, 0.845, 0.062556, 0.605307, 0.062556, 0.605307, 1, True, 8
    12, CustomRandomForestModel/T23, 0.845, 0.062659, 0.591935, 0.062659, 0.591935, 1, True, 23
    13, CustomRandomForestModel/T19, 0.845, 0.062919, 0.566890, 0.062919, 0.566890, 1, True, 19
    14, CustomRandomForestModel/T13, 0.840, 0.062589, 0.601684, 0.062589, 0.601684, 1, True, 13
    15, CustomRandomForestModel/T9, 0.840, 0.064390, 0.608347, 0.064390, 0.608347, 1, True, 9
    16, CustomRandomForestModel/T20, 0.835, 0.060931, 0.590263, 0.060931, 0.590263, 1, True, 20
    17, CustomRandomForestModel/T2, 0.835, 0.061561, 0.586263, 0.061561, 0.586263, 1, True, 2
    18, CustomRandomForestModel/T12, 0.835, 0.061966, 0.584415, 0.061966, 0.584415, 1, True, 12
    19, CustomRandomForestModel/T14, 0.835, 0.062724, 0.579637, 0.062724, 0.579637, 1, True, 14
    20, CustomRandomForestModel/T5, 0.835, 0.063126, 0.597193, 0.063126, 0.597193, 1, True, 5
    21, CustomRandomForestModel/T22, 0.830, 0.061154, 0.575902, 0.061154, 0.575902, 1, True, 22
    22, CustomRandomForestModel/T6, 0.830, 0.061339, 0.570987, 0.061339, 0.570987, 1, True, 6
    23, CustomRandomForestModel/T3, 0.825, 0.063802, 0.594073, 0.063802, 0.594073, 1, True, 3
    24, CustomRandomForestModel/T1, 0.805, 0.060655, 0.565793, 0.060655, 0.565793, 1, True, 1
    25, CustomRandomForestModel/T18, 0.805, 0.063372, 0.565666, 0.063372, 0.565666, 1, True, 18
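Each ``/Tx`` row in the leaderboard above corresponds to one configuration drawn from the search space defined earlier with ``Int``, ``Real``, and ``Categorical``. To illustrate what such a space describes, the following sketch performs plain random search over an equivalent hand-written description of it. The ``SEARCH_SPACE`` layout and the ``sample_config`` helper are hypothetical and do not reproduce AutoGluon's searcher; the log shows, for instance, that AutoGluon's first trial used the lower bounds and the first category (``max_depth=5``, ``max_features=0.1``, ``criterion='gini'``) rather than a random draw.

.. code:: python

    import random

    # Hypothetical plain-Python mirror of the tutorial's search space:
    #   max_depth: Int(lower=5, upper=30), max_features: Real(lower=0.1, upper=1.0),
    #   criterion: Categorical('gini', 'entropy')
    SEARCH_SPACE = {
        'max_depth': ('int', 5, 30),
        'max_features': ('real', 0.1, 1.0),
        'criterion': ('categorical', ['gini', 'entropy']),
    }

    def sample_config(space, rng):
        """Draw one configuration uniformly at random (plain random search)."""
        config = {}
        for name, spec in space.items():
            kind = spec[0]
            if kind == 'int':
                config[name] = rng.randint(spec[1], spec[2])  # both bounds inclusive
            elif kind == 'real':
                config[name] = rng.uniform(spec[1], spec[2])
            else:
                config[name] = rng.choice(spec[1])
        return config

    rng = random.Random(0)  # seeded so the sampled trials are reproducible
    trials = [sample_config(SEARCH_SPACE, rng) for _ in range(25)]
    print(trials[0])

``random.randint`` includes both endpoints, which is consistent with the log above, where ``max_depth`` values of both 5 and 30 appear.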
Getting the hyperparameters of a trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Let’s get the hyperparameters of the individual model with the highest validation score. The leaderboard above is sorted by ``score_val`` in descending order, so filtering to ``stack_level == 1`` (which excludes WeightedEnsemble_L2) and taking the first row selects the best HPO trial.

.. code:: python

    # Exclude the stack-level-2 WeightedEnsemble and take the top-ranked individual model
    best_model_name = leaderboard_hpo[leaderboard_hpo['stack_level'] == 1]['model'].iloc[0]

    predictor_info = predictor.info()
    best_model_info = predictor_info['model_info'][best_model_name]

    print(best_model_info)

    print(f'Best Model Hyperparameters ({best_model_name}):')
    print(best_model_info['hyperparameters'])


.. parsed-literal::
    :class: output

    {'name': 'CustomRandomForestModel/T4', 'model_type': 'CustomRandomForestModel', 'problem_type': 'binary', 'eval_metric': 'accuracy', 'stopping_metric': 'accuracy', 'fit_time': 0.5919175148010254, 'num_classes': 2, 'quantile_levels': None, 'predict_time': 0.06082510948181152, 'val_score': 0.855, 'hyperparameters': {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}, 'hyperparameters_fit': {}, 'hyperparameters_nondefault': ['max_depth', 'max_features', 'criterion', 'n_estimators', 'n_jobs', 'random_state'], 'ag_args_fit': {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None}, 'num_features': 14, 'features': ['age', 'fnlwgt', 'education-num', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'native-country'], 'feature_metadata': , 'memory_size': 4331673, 'compile_time': None}
    Best Model Hyperparameters (CustomRandomForestModel/T4):
    {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}
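Every ``Hyperparameters:`` line in the logs shows the user-supplied values appended to the same three defaults (``n_estimators=300``, ``n_jobs=-1``, ``random_state=0``). A minimal sketch of that override order, assuming a plain dictionary merge in which user-supplied values win, reproduces the dictionary printed for CustomRandomForestModel/T4. ``merge_hyperparameters`` is a hypothetical helper for illustration, not an AutoGluon function.

.. code:: python

    # Defaults observed in every "Hyperparameters:" log line of this tutorial
    DEFAULTS = {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}

    def merge_hyperparameters(defaults, user_params):
        """Hypothetical helper: start from the defaults, then let user values override them."""
        merged = dict(defaults)     # copy so the defaults dict is never mutated
        merged.update(user_params)  # user-supplied keys win on conflict; new keys are appended
        return merged

    # The configuration sampled in HPO trial T4
    t4_params = {'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}
    print(merge_hyperparameters(DEFAULTS, t4_params))
    # {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}

Because Python dictionaries preserve insertion order, the defaults come first and the tuned keys follow, matching the key order in the printed output above.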
Training a custom model alongside other models with TabularPredictor
--------------------------------------------------------------------

Finally, we train the custom model (with the tuned hyperparameters) alongside the default AutoGluon models. All this requires is retrieving the hyperparameter dictionary of the default models via ``get_hyperparameter_config`` and adding CustomRandomForestModel as a key.

.. code:: python

    from autogluon.tabular.configs.hyperparameter_configs import get_hyperparameter_config

    # Now we can add the custom model with tuned hyperparameters to be trained alongside the default models:
    custom_hyperparameters = get_hyperparameter_config('default')

    custom_hyperparameters[CustomRandomForestModel] = best_model_info['hyperparameters']

    print(custom_hyperparameters)


.. parsed-literal::
    :class: output

    {'NN_TORCH': {}, 'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, 'GBMLarge'], 'CAT': {}, 'XGB': {}, 'FASTAI': {}, 'RF': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}], 'XT': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}], 'KNN': [{'weights': 'uniform', 'ag_args': {'name_suffix': 'Unif'}}, {'weights': 'distance', 'ag_args': {'name_suffix': 'Dist'}}], : {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}}


.. code:: python

    # Train the default models plus a single tuned CustomRandomForestModel
    predictor = TabularPredictor(label=label).fit(train_data, hyperparameters=custom_hyperparameters)
    # We can even use the custom model in a multi-layer stack ensemble:
    # predictor = TabularPredictor(label=label).fit(train_data, hyperparameters=custom_hyperparameters, presets='best_quality')

    predictor.leaderboard(test_data, silent=True)

.. parsed-literal:: :class: output No path specified. Models will be saved in: "AutogluonModels/ag-20221213_014308/" Beginning AutoGluon training ... AutoGluon will save models to "AutogluonModels/ag-20221213_014308/" AutoGluon Version: 0.6.1b20221213 Python Version: 3.8.10 Operating System: Linux Platform Machine: x86_64 Platform Version: #1 SMP Tue Nov 30 00:17:50 UTC 2021 Train Data Rows: 1000 Train Data Columns: 14 Label Column: class Preprocessing data ... AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed). 2 unique label values: [' >50K', ' <=50K'] If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression']) Selected class <--> label mapping: class 1 = >50K, class 0 = <=50K Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class. To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init. Using Feature Generators to preprocess the data ... Fitting AutoMLPipelineFeatureGenerator... Available Memory: 31532.41 MB Train Data (Original) Memory Usage: 0.59 MB (0.0% of available memory) Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features. Stage 1 Generators: Fitting AsTypeFeatureGenerator...
                Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
        Stage 2 Generators:
            Fitting FillNaFeatureGenerator...
        Stage 3 Generators:
            Fitting IdentityFeatureGenerator...
            Fitting CategoryFeatureGenerator...
                Fitting CategoryMemoryMinimizeFeatureGenerator...
        Stage 4 Generators:
            Fitting DropUniqueFeatureGenerator...
        Types of features in original data (raw dtype, special dtypes):
            ('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
            ('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
        Types of features in processed data (raw dtype, special dtypes):
            ('category', []) : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
            ('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
            ('int', ['bool']) : 1 | ['sex']
        0.1s = Fit runtime
        14 features in original data used to generate 14 features in processed data.
        Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
    Data preprocessing and feature engineering runtime = 0.09s ...
    AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
        To change this, specify the eval_metric parameter of Predictor()
    Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 800, Val Rows: 200
    Custom Model Type Detected:
    Fitting 14 L1 models ...
    Fitting model: KNeighborsUnif ...
        0.725 = Validation score (accuracy)
        0.6s = Training runtime
        0.01s = Validation runtime
    Fitting model: KNeighborsDist ...
        0.71 = Validation score (accuracy)
        0.6s = Training runtime
        0.01s = Validation runtime
    Fitting model: LightGBMXT ...
        0.85 = Validation score (accuracy)
        1.31s = Training runtime
        0.01s = Validation runtime
    Fitting model: LightGBM ...
        0.84 = Validation score (accuracy)
        0.92s = Training runtime
        0.01s = Validation runtime
    Fitting model: RandomForestGini ...
        0.845 = Validation score (accuracy)
        1.12s = Training runtime
        0.06s = Validation runtime
    Fitting model: RandomForestEntr ...
        0.835 = Validation score (accuracy)
        1.13s = Training runtime
        0.06s = Validation runtime
    Fitting model: CatBoost ...
        0.86 = Validation score (accuracy)
        2.48s = Training runtime
        0.01s = Validation runtime
    Fitting model: ExtraTreesGini ...
        0.82 = Validation score (accuracy)
        1.12s = Training runtime
        0.06s = Validation runtime
    Fitting model: ExtraTreesEntr ...
        0.82 = Validation score (accuracy)
        1.12s = Training runtime
        0.06s = Validation runtime
    Fitting model: NeuralNetFastAI ...
        No improvement since epoch 7: early stopping
        0.86 = Validation score (accuracy)
        3.29s = Training runtime
        0.01s = Validation runtime
    Fitting model: XGBoost ...
        0.85 = Validation score (accuracy)
        0.3s = Training runtime
        0.01s = Validation runtime
    Fitting model: NeuralNetTorch ...
        0.85 = Validation score (accuracy)
        3.42s = Training runtime
        0.01s = Validation runtime
    Fitting model: LightGBMLarge ...
        0.815 = Validation score (accuracy)
        1.11s = Training runtime
        0.01s = Validation runtime
    Fitting model: CustomRandomForestModel ...


.. parsed-literal::
    :class: output

    Entering the `_fit` method
    Entering the `_preprocess` method: 800 rows of data (is_train=True)
    Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}


.. parsed-literal::
    :class: output

        0.855 = Validation score (accuracy)
        0.6s = Training runtime
        0.07s = Validation runtime
    Fitting model: WeightedEnsemble_L2 ...


.. parsed-literal::
    :class: output

    Exiting the `_fit` method
    Entering the `_preprocess` method: 200 rows of data (is_train=False)


.. parsed-literal::
    :class: output

        0.88 = Validation score (accuracy)
        0.38s = Training runtime
        0.0s = Validation runtime
    AutoGluon training complete, total runtime = 20.19s ... Best model: "WeightedEnsemble_L2"
    TabularPredictor saved.
    To load, use: predictor = TabularPredictor.load("AutogluonModels/ag-20221213_014308/")


.. parsed-literal::
    :class: output

    Entering the `_preprocess` method: 9769 rows of data (is_train=False)


.. csv-table::
    :header: "", "model", "score_test", "score_val", "pred_time_test", "pred_time_val", "fit_time", "pred_time_test_marginal", "pred_time_val_marginal", "fit_time_marginal", "stack_level", "can_infer", "fit_order"

    0, CatBoost, 0.852902, 0.860, 0.016674, 0.005907, 2.476585, 0.016674, 0.005907, 2.476585, 1, True, 7
    1, WeightedEnsemble_L2, 0.852083, 0.880, 0.369239, 0.149117, 5.801734, 0.004375, 0.000788, 0.382812, 2, True, 15
    2, LightGBMXT, 0.850752, 0.850, 0.017225, 0.005951, 1.313833, 0.017225, 0.005951, 1.313833, 1, True, 3
    3, XGBoost, 0.850036, 0.850, 0.039560, 0.007545, 0.303070, 0.039560, 0.007545, 0.303070, 1, True, 11
    4, NeuralNetFastAI, 0.841437, 0.860, 0.158014, 0.014076, 3.293159, 0.158014, 0.014076, 3.293159, 1, True, 10
    5, LightGBM, 0.841335, 0.840, 0.013739, 0.005840, 0.918098, 0.013739, 0.005840, 0.918098, 1, True, 4
    6, RandomForestGini, 0.839492, 0.845, 0.144891, 0.062550, 1.118562, 0.144891, 0.062550, 1.118562, 1, True, 5
    7, RandomForestEntr, 0.838162, 0.835, 0.142712, 0.063799, 1.126035, 0.142712, 0.063799, 1.126035, 1, True, 6
    8, NeuralNetTorch, 0.836524, 0.850, 0.059845, 0.013888, 3.421460, 0.059845, 0.013888, 3.421460, 1, True, 12
    9, CustomRandomForestModel, 0.835091, 0.855, 0.142135, 0.065328, 0.599762, 0.142135, 0.065328, 0.599762, 1, True, 14
    10, LightGBMLarge, 0.832122, 0.815, 0.069972, 0.006558, 1.109222, 0.069972, 0.006558, 1.109222, 1, True, 13
    11, ExtraTreesGini, 0.831303, 0.820, 0.145072, 0.064440, 1.119756, 0.145072, 0.064440, 1.119756, 1, True, 8
    12, ExtraTreesEntr, 0.829358, 0.820, 0.152757, 0.063709, 1.121407, 0.152757, 0.063709, 1.121407, 1, True, 9
    13, KNeighborsUnif, 0.744600, 0.725, 0.032326, 0.006918, 0.603907, 0.032326, 0.006918, 0.603907, 1, True, 1
    14, KNeighborsDist, 0.710922, 0.710, 0.035227, 0.006434, 0.602240, 0.035227, 0.006434, 0.602240, 1, True, 2
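As a rough sanity check on the output above, the sketch below re-derives two of its numbers in plain Python without importing AutoGluon. ``expand_model_configs`` is a hypothetical helper and is not part of AutoGluon's API. ``CustomRandomForestModel`` here is an empty stand-in class used only as a dictionary key. The sketch assumes three things about the hyperparameters dictionary:

- Each value is either a single config or a list of configs.
- A config may be a preset name such as ``'GBMLarge'``.
- A config whose ``ag_args`` lists ``problem_types`` is trained only for those problem types.

Under these assumptions, the dictionary printed earlier expands to 14 configs for a binary problem, which matches ``Fitting 14 L1 models`` in the log. The second part uses the ``score_test`` and ``score_val`` columns of the leaderboard to show two things. The reported best model (``WeightedEnsemble_L2``) is the one with the highest ``score_val``. The leaderboard itself is ordered by ``score_test``, which puts ``CatBoost`` first.

.. code:: python

    # Hypothetical sketch, not AutoGluon's API: re-derive two numbers from the run above.


    class CustomRandomForestModel:
        """Empty stand-in for the tutorial's custom model class; used only as a dict key."""


    def expand_model_configs(hyperparameters, problem_type):
        """Return the (model_key, config) pairs kept for `problem_type` under our assumptions.

        Assumptions: each value is one config or a list of configs (a config may also be
        a preset name string such as 'GBMLarge'), and a dict config whose ag_args lists
        'problem_types' is kept only if `problem_type` is in that list.
        """
        configs = []
        for key, value in hyperparameters.items():
            entries = value if isinstance(value, list) else [value]
            for entry in entries:
                if isinstance(entry, dict):
                    allowed = entry.get('ag_args', {}).get('problem_types')
                    if allowed is not None and problem_type not in allowed:
                        continue  # this variant does not apply to the current problem type
                configs.append((key, entry))
        return configs


    # Mirrors the dictionary printed earlier: the default config plus the tuned custom model.
    clf_types = ['binary', 'multiclass']
    reg_types = ['regression', 'quantile']
    tree_variants = [
        {'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': clf_types}},
        {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': clf_types}},
        {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': reg_types}},
    ]
    custom_hyperparameters = {
        'NN_TORCH': {},
        'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, 'GBMLarge'],
        'CAT': {},
        'XGB': {},
        'FASTAI': {},
        'RF': tree_variants,
        'XT': tree_variants,
        'KNN': [{'weights': 'uniform', 'ag_args': {'name_suffix': 'Unif'}},
                {'weights': 'distance', 'ag_args': {'name_suffix': 'Dist'}}],
        CustomRandomForestModel: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0,
                                  'max_depth': 26, 'max_features': 0.4459435365634299,
                                  'criterion': 'entropy'},
    }
    print(len(expand_model_configs(custom_hyperparameters, 'binary')))  # 14

    # (model, score_test, score_val) columns copied from the leaderboard above.
    leaderboard_rows = [
        ('CatBoost', 0.852902, 0.860),
        ('WeightedEnsemble_L2', 0.852083, 0.880),
        ('LightGBMXT', 0.850752, 0.850),
        ('XGBoost', 0.850036, 0.850),
        ('NeuralNetFastAI', 0.841437, 0.860),
        ('LightGBM', 0.841335, 0.840),
        ('RandomForestGini', 0.839492, 0.845),
        ('RandomForestEntr', 0.838162, 0.835),
        ('NeuralNetTorch', 0.836524, 0.850),
        ('CustomRandomForestModel', 0.835091, 0.855),
        ('LightGBMLarge', 0.832122, 0.815),
        ('ExtraTreesGini', 0.831303, 0.820),
        ('ExtraTreesEntr', 0.829358, 0.820),
        ('KNeighborsUnif', 0.744600, 0.725),
        ('KNeighborsDist', 0.710922, 0.710),
    ]
    best_by_val = max(leaderboard_rows, key=lambda row: row[2])[0]
    best_by_test = max(leaderboard_rows, key=lambda row: row[1])[0]
    print(best_by_val, best_by_test)  # WeightedEnsemble_L2 CatBoost

The gap between these two answers is expected. Model selection can only use the validation data available at fit time, whereas ``score_test`` is computed afterwards on held-out data passed to ``leaderboard``.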
Wrapping up
-----------

That’s all it takes to add a custom model to AutoGluon. If you create a custom model, consider `submitting a PR `__ so that we can add it officially to AutoGluon!

For more tutorials, refer to :ref:`sec_tabularquick` and :ref:`sec_tabularadvanced`.

For a tutorial on advanced custom models, refer to :ref:`sec_tabularcustommodeladvanced`.