Automated Quick Model Fit

The purpose of this feature is to provide a quick and easy way to obtain a preliminary understanding of the relationships between the target variable and the independent variables in a dataset.

This functionality automatically splits the training data, fits a simple regression or classification model to it, and reports model performance metrics, feature importance, and insights into individual predictions.
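The split step can be pictured as holding out a random fraction of rows for evaluation. Below is a minimal pandas sketch, assuming a 30% holdout and a fixed seed; split_holdout is a hypothetical helper, and quick_fit may use a different split size or strategy.

```python
import pandas as pd


def split_holdout(df: pd.DataFrame, holdout_frac: float = 0.3, seed: int = 0):
    """Hypothetical helper: randomly split rows into a fit set and a holdout set."""
    # Sample the holdout rows without replacement; the fixed seed makes it reproducible.
    holdout = df.sample(frac=holdout_frac, random_state=seed)
    # Every row that was not sampled is used to fit the model.
    fit = df.drop(index=holdout.index)
    return fit, holdout


# Usage on a small synthetic frame (10 rows -> 7 for fitting, 3 held out).
demo = pd.DataFrame({"x": range(10), "y": [0, 1] * 5})
fit_part, holdout_part = split_holdout(demo)
print(len(fit_part), len(holdout_part))
```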

To inspect prediction quality, the output shows a confusion matrix for classification problems and a scatter plot for regression problems. Both representations let you compare actual and predicted values.
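For classification, a confusion matrix is a cross-tabulation of actual against predicted labels. A minimal sketch with pandas.crosstab, using made-up labels rather than the Titanic predictions:

```python
import pandas as pd

# Hypothetical actual and predicted labels for six rows.
actual = pd.Series([0, 0, 1, 1, 1, 0], name="actual")
predicted = pd.Series([0, 1, 1, 0, 1, 0], name="predicted")

# Rows: actual class; columns: predicted class; cells: number of rows.
confusion = pd.crosstab(actual, predicted)
print(confusion)

# Diagonal cells are correct predictions; accuracy is their share of all rows.
accuracy = confusion.values.trace() / confusion.values.sum()
```

Off-diagonal cells are the mistakes; each one corresponds to rows that would show up among the highest-error predictions.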

The insights highlight two subsets of the model predictions:

  • predictions with the largest prediction error. Rows listed in this section are candidates for investigating why the model made the mistake.

  • (classification only) predictions with the least distance from the other class. Rows in this category are the most ‘undecided’: they are examples of data close to the decision boundary between the classes, and the model would likely benefit from more data for similar cases.
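Both rankings can be reproduced from predicted class probabilities. In the Titanic tables of this tutorial, the error column equals one minus the predicted probability of the true class, and score_diff equals the gap between the two class probabilities (last-digit differences likely come from rounding of the displayed probabilities). A minimal sketch assuming those definitions; rank_prediction_insights is a hypothetical helper, not part of the AutoGluon API:

```python
import numpy as np
import pandas as pd


def rank_prediction_insights(df, target, proba, top_n=10):
    """Hypothetical helper: rank rows by prediction error and by closeness to the boundary.

    proba holds one column per class label with the predicted probabilities.
    """
    proba = proba.loc[df.index]  # align probability rows with the data rows
    out = pd.concat([df, proba], axis=1)
    # Error: one minus the predicted probability of the true class.
    p_true = np.array([proba.at[idx, label] for idx, label in df[target].items()])
    out["error"] = 1.0 - p_true
    # score_diff: gap between the two highest class probabilities (small = undecided).
    top_two = np.sort(proba.to_numpy(), axis=1)[:, -2:]
    out["score_diff"] = top_two[:, 1] - top_two[:, 0]
    highest_error = out.sort_values("error", ascending=False).head(top_n)
    least_distance = out.sort_values("score_diff").head(top_n)
    return highest_error, least_distance


# Labels and class probabilities for three rows of the Titanic tables in this tutorial.
labels = pd.DataFrame({"Survived": [0, 0, 1]}, index=[498, 182, 347])
proba = pd.DataFrame(
    {0: [0.046788, 0.503872, 0.510786], 1: [0.953212, 0.496128, 0.489214]},
    index=[498, 182, 347],
)
highest_error, least_distance = rank_prediction_insights(labels, "Survived", proba, top_n=3)
```

Taking the top two sorted probabilities keeps the same code valid for more than two classes.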

Classification Example

We start by loading the Titanic dataset and running a one-line quick fit to get an overview.

import pandas as pd
import autogluon.eda.auto as auto

df_train = pd.read_csv('https://autogluon.s3.amazonaws.com/datasets/titanic/train.csv')
df_test = pd.read_csv('https://autogluon.s3.amazonaws.com/datasets/titanic/test.csv')
target_col = 'Survived'

auto.quick_fit(df_train, target_col, show_feature_importance_barplots=True)
No path specified. Models will be saved in: "AutogluonModels/ag-20230302_162334/"

Model Prediction for Survived

[Figure: confusion matrix of actual vs. predicted Survived values]

Model Leaderboard

| | model | score_test | score_val | pred_time_test | pred_time_val | fit_time | pred_time_test_marginal | pred_time_val_marginal | fit_time_marginal | stack_level | can_infer | fit_order |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | LightGBMXT | 0.809701 | 0.856 | 0.004417 | 0.004099 | 0.183184 | 0.004417 | 0.004099 | 0.183184 | 1 | True | 1 |

Feature Importance for Trained Model

| | importance | stddev | p_value | n | p99_high | p99_low |
|---|---|---|---|---|---|---|
| Sex | 0.112687 | 0.013033 | 0.000021 | 5 | 0.139522 | 0.085851 |
| Name | 0.055970 | 0.009140 | 0.000082 | 5 | 0.074789 | 0.037151 |
| SibSp | 0.026119 | 0.010554 | 0.002605 | 5 | 0.047850 | 0.004389 |
| Fare | 0.012687 | 0.009730 | 0.021720 | 5 | 0.032721 | -0.007348 |
| Embarked | 0.011194 | 0.006981 | 0.011525 | 5 | 0.025567 | -0.003179 |
| Age | 0.010448 | 0.003122 | 0.000853 | 5 | 0.016876 | 0.004020 |
| PassengerId | 0.008955 | 0.005659 | 0.012022 | 5 | 0.020607 | -0.002696 |
| Cabin | 0.002985 | 0.006675 | 0.186950 | 5 | 0.016729 | -0.010758 |
| Pclass | 0.002239 | 0.005659 | 0.213159 | 5 | 0.013890 | -0.009413 |
| Parch | 0.001493 | 0.002044 | 0.088904 | 5 | 0.005701 | -0.002716 |
| Ticket | 0.000000 | 0.000000 | 0.500000 | 5 | 0.000000 | 0.000000 |
[Figure: feature importance bar plot]
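The p99_high and p99_low columns are consistent with a 99% Student's t confidence interval around the mean importance over the n = 5 shuffles: mean ± t(0.995, n − 1) · stddev / √n. A minimal sketch assuming that formula, which reproduces the Sex row up to rounding; the t quantile is hard-coded for 4 degrees of freedom, and p99_interval is a hypothetical helper:

```python
import math

# 99.5th percentile of Student's t with 4 degrees of freedom (two-sided 99% interval).
T_0995_DF4 = 4.604094871


def p99_interval(mean: float, stddev: float, n: int = 5):
    """Return (low, high) of a 99% t-interval for the mean importance (n = 5 only)."""
    if n != 5:
        raise ValueError("this sketch hard-codes the t quantile for n = 5")
    half_width = T_0995_DF4 * stddev / math.sqrt(n)
    return mean - half_width, mean + half_width


# Sex row from the table above: importance 0.112687, stddev 0.013033.
low, high = p99_interval(0.112687, 0.013033)
```

A lower bound below zero, as for Fare or Cabin, means the data cannot rule out that the feature has no effect at the 99% level.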

Rows with the highest prediction error

Rows in this category are worth inspecting for the causes of the error

| | PassengerId | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | Survived | 0 | 1 | error |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 498 | 499 | 1 | Allison, Mrs. Hudson J C (Bessie Waldo Daniels) | female | 25.0 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | 0 | 0.046788 | 0.953212 | 0.953212 |
| 267 | 268 | 3 | Persson, Mr. Ernst Ulrik | male | 25.0 | 1 | 0 | 347083 | 7.7750 | NaN | S | 1 | 0.932024 | 0.067976 | 0.932024 |
| 569 | 570 | 3 | Jonsson, Mr. Carl | male | 32.0 | 0 | 0 | 350417 | 7.8542 | NaN | S | 1 | 0.922265 | 0.077735 | 0.922265 |
| 283 | 284 | 3 | Dorking, Mr. Edward Arthur | male | 19.0 | 0 | 0 | A/5. 10482 | 8.0500 | NaN | S | 1 | 0.921180 | 0.078820 | 0.921180 |
| 821 | 822 | 3 | Lulic, Mr. Nikola | male | 27.0 | 0 | 0 | 315098 | 8.6625 | NaN | S | 1 | 0.919709 | 0.080291 | 0.919709 |
| 301 | 302 | 3 | McCoy, Mr. Bernard | male | NaN | 2 | 0 | 367226 | 23.2500 | NaN | Q | 1 | 0.918546 | 0.081454 | 0.918546 |
| 288 | 289 | 2 | Hosono, Mr. Masabumi | male | 42.0 | 0 | 0 | 237798 | 13.0000 | NaN | S | 1 | 0.907043 | 0.092957 | 0.907043 |
| 36 | 37 | 3 | Mamee, Mr. Hanna | male | NaN | 0 | 0 | 2677 | 7.2292 | NaN | C | 1 | 0.906803 | 0.093197 | 0.906803 |
| 127 | 128 | 3 | Madsen, Mr. Fridtjof Arne | male | 24.0 | 0 | 0 | C 17369 | 7.1417 | NaN | S | 1 | 0.906605 | 0.093395 | 0.906605 |
| 391 | 392 | 3 | Jansson, Mr. Carl Olof | male | 21.0 | 0 | 0 | 350034 | 7.7958 | NaN | S | 1 | 0.905367 | 0.094633 | 0.905367 |

Rows with the least distance vs other class

Rows in this category are the closest to the decision boundary vs the other class and are good candidates for additional labeling

| | PassengerId | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | Survived | 0 | 1 | score_diff |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 182 | 183 | 3 | Asplund, Master. Clarence Gustaf Hugo | male | 9.0 | 4 | 2 | 347077 | 31.3875 | NaN | S | 0 | 0.503872 | 0.496128 | 0.007743 |
| 475 | 476 | 1 | Clifford, Mr. George Quincy | male | NaN | 0 | 0 | 110465 | 52.0000 | A14 | S | 0 | 0.509178 | 0.490822 | 0.018356 |
| 347 | 348 | 3 | Davison, Mrs. Thomas Henry (Mary E Finck) | female | NaN | 1 | 0 | 386525 | 16.1000 | NaN | S | 1 | 0.510786 | 0.489214 | 0.021572 |
| 192 | 193 | 3 | Andersen-Jensen, Miss. Carla Christine Nielsine | female | 19.0 | 1 | 0 | 350046 | 7.8542 | NaN | S | 1 | 0.512167 | 0.487833 | 0.024334 |
| 330 | 331 | 3 | McCoy, Miss. Agnes | female | NaN | 2 | 0 | 367226 | 23.2500 | NaN | Q | 1 | 0.478502 | 0.521498 | 0.042996 |
| 572 | 573 | 1 | Flynn, Mr. John Irwin ("Irving") | male | 36.0 | 0 | 0 | PC 17474 | 26.3875 | E25 | S | 1 | 0.478234 | 0.521766 | 0.043532 |
| 792 | 793 | 3 | Sage, Miss. Stella Anna | female | NaN | 8 | 2 | CA. 2343 | 69.5500 | NaN | S | 0 | 0.525041 | 0.474959 | 0.050082 |
| 172 | 173 | 3 | Johnson, Miss. Eleanor Ileen | female | 1.0 | 1 | 1 | 347742 | 11.1333 | NaN | S | 1 | 0.526793 | 0.473207 | 0.053585 |
| 328 | 329 | 3 | Goldsmith, Mrs. Frank John (Emily Alice Brown) | female | 31.0 | 1 | 1 | 363291 | 20.5250 | NaN | S | 1 | 0.531574 | 0.468426 | 0.063149 |
| 593 | 594 | 3 | Bourke, Miss. Mary | female | NaN | 0 | 2 | 364848 | 7.7500 | NaN | Q | 0 | 0.463840 | 0.536160 | 0.072319 |

Regression Example

The previous section covered classification; now let’s try regression. The output differs slightly: predictions are shown as a scatter plot, and there is no table of rows closest to a decision boundary. This time we also keep the fitted model by passing return_state=True and save_model_to_state=True, so that we can use it later to predict on the test data.

The dataset is large, so for this tutorial we keep only a few columns and the first 500 rows.

df_train = pd.read_csv('https://autogluon.s3.amazonaws.com/datasets/AmesHousingPriceRegression/train_data.csv')
df_test = pd.read_csv('https://autogluon.s3.amazonaws.com/datasets/AmesHousingPriceRegression/test_data.csv')
target_col = 'SalePrice'

keep_cols = [
  'Overall.Qual', 'Gr.Liv.Area', 'Neighborhood', 'Total.Bsmt.SF', 'BsmtFin.SF.1',
  'X1st.Flr.SF', 'Bsmt.Qual', 'Garage.Cars', 'Half.Bath', 'Year.Remod.Add', target_col
]

df_train = df_train[[c for c in df_train.columns if c in keep_cols]][:500]
df_test = df_test[[c for c in df_test.columns if c in keep_cols]][:500]


state = auto.quick_fit(df_train, target_col, return_state=True, save_model_to_state=True)
No path specified. Models will be saved in: "AutogluonModels/ag-20230302_162336/"

Model Prediction for SalePrice

[Figure: scatter plot of actual vs. predicted SalePrice values]

Model Leaderboard

| | model | score_test | score_val | pred_time_test | pred_time_val | fit_time | pred_time_test_marginal | pred_time_val_marginal | fit_time_marginal | stack_level | can_infer | fit_order |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | LightGBMXT | -29100.820216 | -31075.785774 | 0.003217 | 0.003058 | 0.126165 | 0.003217 | 0.003058 | 0.126165 | 1 | True | 1 |
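Both scores are negative because AutoGluon reports every metric in higher-is-better form, flipping the sign of error metrics. Assuming the metric here is the default regression metric, root mean squared error (RMSE), score_test corresponds to an RMSE of roughly 29,100 on the held-out rows. A minimal sketch of the sign convention, using three actual/predicted pairs from the highest-error table of this example (so the value differs from the leaderboard score):

```python
import numpy as np


def rmse(y_true, y_pred) -> float:
    """Root mean squared error: lower is better."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))


# Actual and predicted SalePrice for three rows of the highest-error table.
y_true = [274970, 392000, 445000]
y_pred = [150482.0625, 277549.3125, 344158.6875]

error = rmse(y_true, y_pred)
# Higher-is-better convention: flip the sign so a smaller error gives a larger score.
score = -error
```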

Feature Importance for Trained Model

| | importance | stddev | p_value | n | p99_high | p99_low |
|---|---|---|---|---|---|---|
| Overall.Qual | 16126.273271 | 1376.470881 | 0.000006 | 5 | 18960.445840 | 13292.100702 |
| Gr.Liv.Area | 8862.693281 | 480.183424 | 0.000001 | 5 | 9851.397587 | 7873.988974 |
| Total.Bsmt.SF | 5299.844900 | 870.222500 | 0.000084 | 5 | 7091.645055 | 3508.044746 |
| Garage.Cars | 4472.147453 | 660.340484 | 0.000055 | 5 | 5831.797636 | 3112.497270 |
| X1st.Flr.SF | 3804.848692 | 692.065035 | 0.000126 | 5 | 5229.820166 | 2379.877219 |
| BsmtFin.SF.1 | 3725.145846 | 369.988099 | 0.000012 | 5 | 4486.956454 | 2963.335237 |
| Year.Remod.Add | 3562.868687 | 1081.770172 | 0.000906 | 5 | 5790.248423 | 1335.488951 |
| Half.Bath | 3020.213571 | 1041.365031 | 0.001457 | 5 | 5164.398563 | 876.028580 |
| Neighborhood | 624.378438 | 297.685200 | 0.004689 | 5 | 1237.316379 | 11.440497 |
| Bsmt.Qual | 0.000000 | 0.000000 | 0.500000 | 5 | 0.000000 | 0.000000 |
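AutoGluon computes these values with permutation importance: the drop in the score when one feature’s values are shuffled, repeated n = 5 times. A minimal NumPy sketch of that idea with a toy hand-written model, assuming a higher-is-better score (negated RMSE); toy_model, neg_rmse and permutation_importance are hypothetical names, not AutoGluon functions:

```python
import numpy as np


def toy_model(X: np.ndarray) -> np.ndarray:
    """Hypothetical fitted model: uses column 0 only and ignores column 1."""
    return 3.0 * X[:, 0]


def neg_rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Higher-is-better score: negated root mean squared error."""
    return -float(np.sqrt(np.mean((y_true - y_pred) ** 2)))


def permutation_importance(model, X, y, num_shuffles=5, seed=0):
    """Return (means, stddevs) of the score drop per column over num_shuffles shuffles."""
    rng = np.random.default_rng(seed)
    baseline = neg_rmse(y, model(X))
    means, stds = [], []
    for col in range(X.shape[1]):
        drops = []
        for _ in range(num_shuffles):
            X_shuffled = X.copy()
            # Shuffling one column breaks its link with the target.
            X_shuffled[:, col] = rng.permutation(X_shuffled[:, col])
            drops.append(baseline - neg_rmse(y, model(X_shuffled)))
        means.append(float(np.mean(drops)))
        stds.append(float(np.std(drops, ddof=1)))
    return means, stds


data_rng = np.random.default_rng(42)
X = data_rng.normal(size=(200, 2))
y = 3.0 * X[:, 0] + data_rng.normal(scale=0.1, size=200)
importance, stddev = permutation_importance(toy_model, X, y)
```

A feature whose shuffling never changes the score gets an importance and stddev of exactly zero, the same pattern as the Bsmt.Qual row above.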

Rows with the highest prediction error

Rows in this category are worth inspecting for the causes of the error

| | Neighborhood | Overall.Qual | Year.Remod.Add | Bsmt.Qual | BsmtFin.SF.1 | Total.Bsmt.SF | X1st.Flr.SF | Gr.Liv.Area | Half.Bath | Garage.Cars | SalePrice | SalePrice_pred | error |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 134 | Edwards | 6 | 1966 | Gd | 0.0 | 697.0 | 1575 | 2201 | 0 | 2.0 | 274970 | 150482.062500 | 124487.937500 |
| 90 | Timber | 10 | 2007 | Ex | 0.0 | 1824.0 | 1824 | 1824 | 0 | 3.0 | 392000 | 277549.312500 | 114450.687500 |
| 468 | NridgHt | 9 | 2003 | Ex | 1972.0 | 2452.0 | 2452 | 2452 | 0 | 3.0 | 445000 | 344158.687500 | 100841.312500 |
| 45 | NridgHt | 9 | 2006 | Ex | 0.0 | 1704.0 | 1722 | 2758 | 1 | 3.0 | 418000 | 327791.531250 | 90208.468750 |
| 118 | Somerst | 7 | 2006 | Gd | 788.0 | 960.0 | 960 | 2318 | 1 | 2.0 | 294323 | 218767.953125 | 75555.046875 |
| 318 | Crawfor | 7 | 2002 | Gd | 1406.0 | 1902.0 | 1902 | 1902 | 0 | 2.0 | 335000 | 265314.125000 | 69685.875000 |
| 26 | Mitchel | 5 | 2006 | NaN | 0.0 | 0.0 | 1771 | 1771 | 0 | 2.0 | 115000 | 179189.171875 | 64189.171875 |
| 233 | NoRidge | 8 | 2000 | Gd | 655.0 | 1145.0 | 1145 | 2198 | 1 | 3.0 | 250000 | 311554.031250 | 61554.031250 |
| 322 | NoRidge | 8 | 1993 | Gd | 1129.0 | 1390.0 | 1402 | 2225 | 1 | 3.0 | 285000 | 341849.781250 | 56849.781250 |
| 340 | ClearCr | 7 | 2005 | Gd | 226.0 | 1385.0 | 1363 | 1363 | 0 | 2.0 | 241500 | 185450.109375 | 56049.890625 |
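In this table, error matches the absolute difference between SalePrice and SalePrice_pred, so over-predictions (such as row 26) rank alongside under-predictions. A minimal pandas sketch that recomputes and ranks the error for three rows copied from the table:

```python
import pandas as pd

# Actual and predicted prices for rows 26, 134 and 90 of the table above.
rows = pd.DataFrame(
    {
        "SalePrice": [115000, 274970, 392000],
        "SalePrice_pred": [179189.171875, 150482.0625, 277549.3125],
    },
    index=[26, 134, 90],
)
# Absolute error, so over- and under-predictions are ranked together.
rows["error"] = (rows["SalePrice"] - rows["SalePrice_pred"]).abs()
# Largest errors first, as in the "highest prediction error" table.
ranked = rows.sort_values("error", ascending=False)
```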

Using a fitted model

Now let’s retrieve the model from the state, predict on df_test, and quickly visualize the distribution of the predictions with auto.analyze_interaction():

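model = state.model  # fitted model kept in the state by save_model_to_state=True
y_pred = model.predict(df_test)  # predicted SalePrice for each test row
auto.analyze_interaction(
    train_data=pd.DataFrame({'SalePrice_Pred': y_pred}),
    x='SalePrice_Pred',
    fit_distributions=['johnsonsu', 'norm', 'exponnorm'],  # candidate distributions to fit
)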
[Figure: distribution of SalePrice_Pred with the fitted johnsonsu, norm and exponnorm distributions]
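The names passed to fit_distributions match distributions in scipy.stats (johnsonsu, norm, exponnorm). To illustrate what fitting one of them means: for the normal distribution, the maximum-likelihood fit reduces to the sample mean and the population standard deviation (ddof=0). A minimal NumPy sketch, using made-up prediction values rather than the tutorial’s y_pred; fit_normal_mle and normal_log_likelihood are hypothetical helpers:

```python
import numpy as np


def fit_normal_mle(values):
    """Maximum-likelihood estimates (loc, scale) of a normal distribution."""
    x = np.asarray(values, dtype=float)
    # MLE of the mean is the sample mean; MLE of the scale uses ddof=0.
    return float(x.mean()), float(x.std(ddof=0))


def normal_log_likelihood(values, loc, scale):
    """Total log-likelihood of the values under N(loc, scale**2)."""
    x = np.asarray(values, dtype=float)
    z = (x - loc) / scale
    return float(np.sum(-0.5 * z**2 - np.log(scale) - 0.5 * np.log(2 * np.pi)))


# Illustrative predicted prices (hypothetical, not the tutorial's y_pred).
preds = [150000.0, 180000.0, 210000.0, 165000.0, 195000.0]
loc, scale = fit_normal_mle(preds)
ll = normal_log_likelihood(preds, loc, scale)
```

Moving either parameter away from the fitted value lowers the log-likelihood, which is what makes these the best-fitting normal parameters for the data.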