Adding a custom time series forecasting model


This tutorial describes how to add a custom forecasting model that can be trained, hyperparameter-tuned, and ensembled alongside the default forecasting models.

As an example, we will implement an AutoGluon wrapper for the NHITS model from the NeuralForecast library.

This tutorial consists of the following sections:

  1. Implementing the model wrapper.

  2. Loading & preprocessing the dataset used for model development.

  3. Using the custom model in standalone mode.

  4. Using the custom model inside the TimeSeriesPredictor.

Warning

This tutorial is designed for advanced AutoGluon users.

Custom model implementations rely heavily on the private API of AutoGluon, which might change over time. For this reason, it might be necessary to update your custom model implementations as you upgrade to new versions of AutoGluon.

First, we install the NeuralForecast library that contains the implementation of the custom model used in this tutorial.

pip install -q neuralforecast==2.0
Note: you may need to restart the kernel to use updated packages.

Implement the custom model

To implement a custom model, we need to create a subclass of the AbstractTimeSeriesModel class. This subclass must implement two methods: _fit and _predict. For models that require custom preprocessing logic (e.g., to handle missing values), we also need to implement the preprocess method.

Please have a look at the following code and read the comments to understand the different components of the custom model wrapper.

import logging
import pprint
from typing import Optional, Tuple

import pandas as pd

from autogluon.timeseries import TimeSeriesDataFrame
from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
from autogluon.timeseries.utils.warning_filters import warning_filter

# Optional - disable annoying PyTorch-Lightning loggers
for logger_name in [
    "lightning.pytorch.utilities.rank_zero",
    "pytorch_lightning.accelerators.cuda",
    "lightning_fabric.utilities.seed",
]:
    logging.getLogger(logger_name).setLevel(logging.ERROR)


class NHITSModel(AbstractTimeSeriesModel):
    """AutoGluon-compatible wrapper for the NHITS model from NeuralForecast."""

    # Set these attributes to ensure that AutoGluon passes correct features to the model
    _supports_known_covariates: bool = True
    _supports_past_covariates: bool = True
    _supports_static_features: bool = True

    def preprocess(
        self,
        data: TimeSeriesDataFrame,
        known_covariates: Optional[TimeSeriesDataFrame] = None,
        is_train: bool = False,
        **kwargs,
    ) -> Tuple[TimeSeriesDataFrame, Optional[TimeSeriesDataFrame]]:
        """Method that implements model-specific preprocessing logic.

        This method is called on all data that is passed to `_fit` and `_predict` methods.
        """
        # NeuralForecast cannot handle missing values represented by NaN. Therefore, we
        # need to impute them before the data is passed to the model. First, we
        # forward-fill and backward-fill all time series
        data = data.fill_missing_values()
        # Some time series might consist completely of missing values, so the previous
        # line has no effect on them. We fill them with 0.0
        data = data.fill_missing_values(method="constant", value=0.0)
        # Some models (e.g., Chronos) can natively handle NaNs - for such models we
        # wouldn't need to define custom preprocessing logic
        return data, known_covariates

    def _get_default_hyperparameters(self) -> dict:
        """Default hyperparameters that will be provided to the inner model, i.e., the
        NHITS implementation in neuralforecast."""
        import torch
        from neuralforecast.losses.pytorch import MQLoss

        default_hyperparameters = dict(
            loss=MQLoss(quantiles=self.quantile_levels),
            input_size=2 * self.prediction_length,
            scaler_type="standard",
            enable_progress_bar=False,
            enable_model_summary=False,
            logger=False,
            accelerator="cpu",
            # The model wrapper should handle any time series length - even time series
            # with 1 observation
            start_padding_enabled=True,
            # NeuralForecast requires that names of the past/future/static covariates are
            # passed as model arguments. AutoGluon models have access to this information
            # using the `metadata` attribute that is set automatically at model creation.
            #
            # Note that NeuralForecast does not support categorical covariates, so we
            # only use the real-valued covariates here. To use categorical features in
            # your wrapper, you need to either use techniques like one-hot encoding, or
            # rely on models that natively handle categorical features.
            futr_exog_list=self.covariate_metadata.known_covariates_real,
            hist_exog_list=self.covariate_metadata.past_covariates_real,
            stat_exog_list=self.covariate_metadata.static_features_real,
        )

        if torch.cuda.is_available():
            default_hyperparameters["accelerator"] = "gpu"
            default_hyperparameters["devices"] = 1

        return default_hyperparameters

    def _fit(
        self,
        train_data: TimeSeriesDataFrame,
        val_data: Optional[TimeSeriesDataFrame] = None,
        time_limit: Optional[float] = None,
        **kwargs,
    ) -> None:
        """Fit the model on the available training data."""
        print("Entering the `_fit` method")

        # We lazily import other libraries inside the _fit method. This reduces the
        # import time for autogluon and ensures that even if one model has some problems
        # with dependencies, the training process won't crash
        from neuralforecast import NeuralForecast
        from neuralforecast.models import NHITS

        # It's important to ensure that the model respects the time_limit during `fit`.
        # Since NeuralForecast is based on PyTorch-Lightning, this can easily be enforced
        # using the `max_time` argument to `pl.Trainer`. For other model types, such as
        # ARIMA, implementing the time_limit logic may require a lot of work.
        hyperparameter_overrides = {}
        if time_limit is not None:
            hyperparameter_overrides = {"max_time": {"seconds": time_limit}}

        # The method `get_hyperparameters()` returns the defaults from
        # `_get_default_hyperparameters` overridden with any hyperparameters provided by
        # the user in `predictor.fit(..., hyperparameters={NHITSModel: {...}})`. We then
        # override these with the hyperparameters that only become available at training time.
        model_params = self.get_hyperparameters() | hyperparameter_overrides
        print(f"Hyperparameters:\n{pprint.pformat(model_params, sort_dicts=False)}")

        model = NHITS(h=self.prediction_length, **model_params)
        self.nf = NeuralForecast(models=[model], freq=self.freq)

        # Convert data into a format expected by the model. NeuralForecast expects time
        # series data in pandas.DataFrame format that is quite similar to AutoGluon, so
        # the transformation is very easy.
        #
        # Note that the `preprocess` method was already applied to train_data and val_data.
        train_df, static_df = self._to_neuralforecast_format(train_data)
        self.nf.fit(
            train_df,
            static_df=static_df,
            id_col="item_id",
            time_col="timestamp",
            target_col=self.target,
        )
        print("Exiting the `_fit` method")

    def _to_neuralforecast_format(self, data: TimeSeriesDataFrame) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
        """Convert a TimeSeriesDataFrame to the format expected by NeuralForecast."""
        df = data.to_data_frame().reset_index()
        # Drop the categorical covariates to avoid NeuralForecast errors
        df = df.drop(columns=self.covariate_metadata.covariates_cat)
        static_df = data.static_features
        if len(self.covariate_metadata.static_features_real) > 0:
            static_df = static_df.reset_index()
            static_df = static_df.drop(columns=self.covariate_metadata.static_features_cat)
        return df, static_df

    def _predict(
        self,
        data: TimeSeriesDataFrame,
        known_covariates: Optional[TimeSeriesDataFrame] = None,
        **kwargs,
    ) -> TimeSeriesDataFrame:
        """Predict future target given the historical time series data and the future values of known_covariates."""
        print("Entering the `_predict` method")

        from neuralforecast.losses.pytorch import quantiles_to_outputs

        df, static_df = self._to_neuralforecast_format(data)
        if len(self.covariate_metadata.known_covariates_real) > 0:
            futr_df, _ = self._to_neuralforecast_format(known_covariates)
        else:
            futr_df = None

        with warning_filter():
            predictions = self.nf.predict(df, static_df=static_df, futr_df=futr_df)

        # predictions must be a TimeSeriesDataFrame with columns
        # ["mean"] + [str(q) for q in self.quantile_levels]
        model_name = str(self.nf.models[0])
        rename_columns = {
            f"{model_name}{suffix}": str(quantile)
            for quantile, suffix in zip(*quantiles_to_outputs(self.quantile_levels))
        }
        predictions = predictions.rename(columns=rename_columns)
        predictions["mean"] = predictions["0.5"]
        predictions = TimeSeriesDataFrame(predictions)
        return predictions

For convenience, here is an overview of the main constraints on the inputs and outputs of the different methods.

  • Input data received by the _fit and _predict methods satisfies:

    • the index is sorted by (item_id, timestamp)

    • timestamps of observations have a regular frequency corresponding to self.freq

    • column self.target contains the target values of the time series

    • target column might contain missing values represented by NaN

    • data may contain covariates (incl. static features) with schema described in self.covariate_metadata

      • real-valued covariates have dtype float32

      • categorical covariates have dtype category

      • covariates do not contain any missing values

    • static features, if present, are available as data.static_features

  • Predictions returned by _predict must satisfy:

    • returns predictions as a TimeSeriesDataFrame object

    • predictions contain columns ["mean"] + [str(q) for q in self.quantile_levels] containing the point and quantile forecasts, respectively

    • the index of predictions contains exactly self.prediction_length future time steps of each time series present in data

    • the frequency of the prediction timestamps matches self.freq

    • the index of predictions is sorted by (item_id, timestamp)

    • predictions contain no missing values represented by NaN and no gaps

  • The runtime of the _fit method should not exceed time_limit seconds, if time_limit is provided.

  • None of the methods should modify the data in-place. If modifications are needed, create a copy of the data first.

  • All methods should work even if some time series consist of all NaNs, or only have a single observation.
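
During development, it is convenient to verify these constraints programmatically. Here is a minimal sanity check that can be run on the output of a custom model's predict method. Note that this helper and its checks are our own sketch, not part of the AutoGluon API:

from autogluon.timeseries import TimeSeriesDataFrame


def check_predictions(predictions: TimeSeriesDataFrame, data: TimeSeriesDataFrame, model) -> None:
    """Minimal sanity check for the output format of a custom model's `predict` method."""
    expected_columns = ["mean"] + [str(q) for q in model.quantile_levels]
    # Predictions must be a TimeSeriesDataFrame with point & quantile forecast columns
    assert isinstance(predictions, TimeSeriesDataFrame)
    assert set(expected_columns).issubset(predictions.columns)
    # Exactly `prediction_length` future steps for each time series in the input data
    assert set(predictions.item_ids) == set(data.item_ids)
    assert (predictions.num_timesteps_per_item() == model.prediction_length).all()
    # Forecasts must not contain any missing values
    assert not predictions[expected_columns].isna().any().any()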


We will now use this wrapper in two modes:

  1. Standalone mode (outside the TimeSeriesPredictor).

    • This mode should be used for development and debugging. In this case, we need to manually take care of preprocessing and model configuration.

  2. Inside the TimeSeriesPredictor.

    • This mode makes it easy to combine & compare the custom model with other models available in AutoGluon. The main purpose of writing a custom model wrapper is to use it in this mode.

Load and preprocess the data

First, we load the Grocery Sales dataset that we will use for development and evaluation.

from autogluon.timeseries import TimeSeriesDataFrame

raw_data = TimeSeriesDataFrame.from_path(
    "https://autogluon.s3.amazonaws.com/datasets/timeseries/grocery_sales/test.csv",
    static_features_path="https://autogluon.s3.amazonaws.com/datasets/timeseries/grocery_sales/static.csv",
)
raw_data.head()
                     scaled_price  promotion_email  promotion_homepage  unit_sales
item_id  timestamp
1062_101 2018-01-01      0.879130              0.0                 0.0       636.0
         2018-01-08      0.994517              0.0                 0.0       123.0
         2018-01-15      1.005513              0.0                 0.0       391.0
         2018-01-22      1.000000              0.0                 0.0       339.0
         2018-01-29      0.883309              0.0                 0.0       661.0
raw_data.static_features.head()
          product_code product_category product_subcategory  location_code
item_id
1062_101          1062        Beverages   Fruit Juice Mango            101
1062_102          1062        Beverages   Fruit Juice Mango            102
1062_104          1062        Beverages   Fruit Juice Mango            104
1062_106          1062        Beverages   Fruit Juice Mango            106
1062_108          1062        Beverages   Fruit Juice Mango            108
print("Types of the columns in raw data:")
print(raw_data.dtypes)
print("\nTypes of the columns in raw static features:")
print(raw_data.static_features.dtypes)

print("\nNumber of missing values per column:")
print(raw_data.isna().sum())
Types of the columns in raw data:
scaled_price          float64
promotion_email       float64
promotion_homepage    float64
unit_sales            float64
dtype: object

Types of the columns in raw static features:
product_code            int64
product_category       object
product_subcategory    object
location_code           int64
dtype: object

Number of missing values per column:
scaled_price          714
promotion_email       714
promotion_homepage    714
unit_sales            714
dtype: int64

Define the forecasting task

prediction_length = 7  # number of future steps to predict
target = "unit_sales"  # target column
known_covariates_names = ["promotion_email", "promotion_homepage"]  # covariates known in the future

Before we use the model in standalone mode, we need to apply the general AutoGluon preprocessing to the data.

The TimeSeriesFeatureGenerator captures preprocessing steps like normalizing the data types and imputing the missing values in the covariates.

from autogluon.timeseries.utils.features import TimeSeriesFeatureGenerator

feature_generator = TimeSeriesFeatureGenerator(target=target, known_covariates_names=known_covariates_names)
data = feature_generator.fit_transform(raw_data)
print("Types of the columns in preprocessed data:")
print(data.dtypes)
print("\nTypes of the columns in preprocessed static features:")
print(data.static_features.dtypes)

print("\nNumber of missing values per column:")
print(data.isna().sum())
Types of the columns in preprocessed data:
unit_sales            float64
promotion_email       float32
promotion_homepage    float32
scaled_price          float32
dtype: object

Types of the columns in preprocessed static features:
product_category       category
product_subcategory    category
product_code            float32
location_code           float32
dtype: object

Number of missing values per column:
unit_sales            714
promotion_email         0
promotion_homepage      0
scaled_price            0
dtype: int64

Using the custom model in standalone mode

Using the model in standalone mode is useful for debugging our implementation. Once we make sure that all methods work as expected, we will use the model inside the TimeSeriesPredictor.

Training

We are now ready to train the custom model on the preprocessed data.

When using the model in standalone mode, we need to manually configure its parameters.

model = NHITSModel(
    prediction_length=prediction_length,
    target=target,
    covariate_metadata=feature_generator.covariate_metadata,
    freq=data.freq,
    quantile_levels=[0.1, 0.5, 0.9],
)
model.fit(train_data=data, time_limit=20)
Entering the `_fit` method
Hyperparameters:
{'loss': MQLoss(),
 'input_size': 14,
 'scaler_type': 'standard',
 'enable_progress_bar': False,
 'enable_model_summary': False,
 'logger': False,
 'accelerator': 'gpu',
 'start_padding_enabled': True,
 'futr_exog_list': ['promotion_email', 'promotion_homepage'],
 'hist_exog_list': ['scaled_price'],
 'stat_exog_list': ['product_code', 'location_code'],
 'devices': 1,
 'max_time': {'seconds': 13.275092667600006}}
Exiting the `_fit` method
NHITS

Predicting and scoring

past_data, known_covariates = data.get_model_inputs_for_scoring(
    prediction_length=prediction_length,
    known_covariates_names=known_covariates_names,
)
predictions = model.predict(past_data, known_covariates)
predictions.head()
Entering the `_predict` method
                            0.1         0.5         0.9        mean
item_id  timestamp
1062_101 2018-06-18  182.176056  330.021057  556.481567  330.021057
         2018-06-25  175.307068  324.581757  562.373596  324.581757
         2018-07-02  177.127304  328.814392  567.216675  328.814392
         2018-07-09  176.450790  325.529816  578.507141  325.529816
         2018-07-16  175.310501  327.126129  579.647583  327.126129
model.score(data)
Entering the `_predict` method
np.float64(-0.3446280062131549)

Using the custom model inside the TimeSeriesPredictor

Now that we have made sure that our custom model works in standalone mode, we can pass it to the TimeSeriesPredictor alongside other models.

from autogluon.timeseries import TimeSeriesPredictor

train_data, test_data = raw_data.train_test_split(prediction_length)

predictor = TimeSeriesPredictor(
    prediction_length=prediction_length,
    target=target,
    known_covariates_names=known_covariates_names,
)

predictor.fit(
    train_data,
    hyperparameters={
        "Naive": {},
        "Chronos": {"model_path": "bolt_small"},
        "ETS": {},
        NHITSModel: {},
    },
    time_limit=120,
)
Beginning AutoGluon training... Time limit = 120s
AutoGluon will save models to '/home/ci/autogluon/docs/tutorials/timeseries/advanced/AutogluonModels/ag-20250527_234413'
=================== System Info ===================
AutoGluon Version:  1.3.2b20250527
Python Version:     3.11.10
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
GPU Count:          1
Memory Avail:       27.75 GB / 30.95 GB (89.7%)
Disk Space Avail:   211.57 GB / 255.99 GB (82.6%)
===================================================

Fitting with arguments:
{'enable_ensemble': True,
 'eval_metric': WQL,
 'hyperparameters': {<class '__main__.NHITSModel'>: {},
                     'Chronos': {'model_path': 'bolt_small'},
                     'ETS': {},
                     'Naive': {}},
 'known_covariates_names': ['promotion_email', 'promotion_homepage'],
 'num_val_windows': 1,
 'prediction_length': 7,
 'quantile_levels': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
 'random_seed': 123,
 'refit_every_n_windows': 1,
 'refit_full': False,
 'skip_model_selection': False,
 'target': 'unit_sales',
 'time_limit': 120,
 'verbosity': 2}
Inferred time series frequency: 'W-MON'
Provided train_data has 7656 rows (NaN fraction=6.8%), 319 time series. Median time series length is 24 (min=24, max=24).

Provided data contains following columns:
	target: 'unit_sales'
	known_covariates:
		categorical:        []
		continuous (float): ['promotion_email', 'promotion_homepage']
	past_covariates:
		categorical:        []
		continuous (float): ['scaled_price']
	static_features:
		categorical:        ['product_category', 'product_subcategory']
		continuous (float): ['product_code', 'location_code']

To learn how to fix incorrectly inferred types, please see documentation for TimeSeriesPredictor.fit

AutoGluon will gauge predictive performance using evaluation metric: 'WQL'
	This metric's sign has been flipped to adhere to being higher_is_better. The metric score can be multiplied by -1 to get the metric value.
===================================================

Starting training. Start time is 2025-05-27 23:44:13
Models that will be trained: ['Naive', 'ETS', 'Chronos[bolt_small]', 'NHITS']
Training timeseries model Naive. Training for up to 24.0s of the 119.9s of remaining time.
	-0.5412       = Validation score (-WQL)
	0.04    s     = Training runtime
	1.96    s     = Validation (prediction) runtime
Training timeseries model ETS. Training for up to 29.5s of the 117.9s of remaining time.
	-0.7039       = Validation score (-WQL)
	0.04    s     = Training runtime
	0.89    s     = Validation (prediction) runtime
Training timeseries model Chronos[bolt_small]. Training for up to 39.0s of the 117.0s of remaining time.
	-0.3320       = Validation score (-WQL)
	0.62    s     = Training runtime
	1.25    s     = Validation (prediction) runtime
Training timeseries model NHITS. Training for up to 57.5s of the 115.1s of remaining time.
	-0.4681       = Validation score (-WQL)
	20.06   s     = Training runtime
	0.10    s     = Validation (prediction) runtime
Fitting simple weighted ensemble.
	Ensemble weights: {'Chronos[bolt_small]': 0.97, 'NHITS': 0.03}
	-0.3320       = Validation score (-WQL)
	0.53    s     = Training runtime
	1.36    s     = Validation (prediction) runtime
Training complete. Models trained: ['Naive', 'ETS', 'Chronos[bolt_small]', 'NHITS', 'WeightedEnsemble']
Total runtime: 25.58 s
Best model: Chronos[bolt_small]
Best model score: -0.3320
Entering the `_fit` method
Hyperparameters:
{'loss': MQLoss(),
 'input_size': 14,
 'scaler_type': 'standard',
 'enable_progress_bar': False,
 'enable_model_summary': False,
 'logger': False,
 'accelerator': 'gpu',
 'start_padding_enabled': True,
 'futr_exog_list': ['promotion_email', 'promotion_homepage'],
 'hist_exog_list': ['scaled_price'],
 'stat_exog_list': ['product_code', 'location_code'],
 'devices': 1,
 'max_time': {'seconds': 51.78616878959283}}
Exiting the `_fit` method
Entering the `_predict` method
<autogluon.timeseries.predictor.TimeSeriesPredictor at 0x7f27ad9cd1d0>

Note that when we use the custom model inside the predictor, we don’t need to worry about:

  • manually configuring the model (setting freq, prediction_length)

  • preprocessing the data using TimeSeriesFeatureGenerator

  • setting the time limits

The TimeSeriesPredictor automatically takes care of all of the above aspects.
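
For example, generating forecasts with the fitted predictor takes a single call to predict. A short sketch, where we assume the future values of the known covariates are taken from the last prediction_length time steps of test_data:

# Future values of the known covariates for the forecast horizon
future_known_covariates = test_data.slice_by_timestep(-prediction_length, None)[known_covariates_names]
# Raw data can be passed directly - the predictor applies the same preprocessing it used during `fit`
forecasts = predictor.predict(train_data, known_covariates=future_known_covariates)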

We can also easily compare our custom model with the other models trained by the predictor.

predictor.leaderboard(test_data)
Entering the `_predict` method
Additional data provided, testing on additional data. Resulting leaderboard will be sorted according to test score (`score_test`).
                 model  score_test  score_val  pred_time_test  pred_time_val  fit_time_marginal  fit_order
0     WeightedEnsemble   -0.314278  -0.332025        0.698832       1.358480           0.528992          5
1  Chronos[bolt_small]   -0.315786  -0.331984        0.554468       1.253759           0.617178          3
2                NHITS   -0.397008  -0.468127        0.141845       0.104721          20.058523          4
3                  ETS   -0.459021  -0.703868        0.270355       0.888533           0.041151          2
4                Naive   -0.512205  -0.541231        0.187525       1.963237           0.035905          1

We can also take advantage of other predictor functionality such as feature_importance.

predictor.feature_importance(test_data, model="NHITS")
Entering the `_predict` method
... (repeated for each feature permutation) ...
Computing feature importance
                     importance     stdev    n   p99_low  p99_high
product_category       0.000000  0.000000  5.0  0.000000  0.000000
product_subcategory    0.000000  0.000000  5.0  0.000000  0.000000
product_code          -0.000775  0.010627  5.0 -0.022656  0.021106
location_code          0.000092  0.000285  5.0 -0.000494  0.000679
promotion_email        0.007981  0.009743  5.0 -0.012079  0.028041
promotion_homepage     0.005051  0.007467  5.0 -0.010324  0.020426
scaled_price          -0.000088  0.000898  5.0 -0.001937  0.001761

As expected, features product_category and product_subcategory have zero importance because our implementation ignores categorical features.
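
If the categorical features matter for your task, one option mentioned in the comments of the wrapper is to one-hot encode them instead of dropping them. Below is a rough, untested sketch of how _to_neuralforecast_format could encode the categorical static features. Keep in mind that the resulting dummy column names would then also need to be passed to NeuralForecast via stat_exog_list:

import pandas as pd

def _to_neuralforecast_format(self, data):
    df = data.to_data_frame().reset_index()
    df = df.drop(columns=self.covariate_metadata.covariates_cat)
    static_df = data.static_features
    if static_df is not None:
        static_df = static_df.reset_index()
        # One-hot encode the categorical static features instead of dropping them
        static_df = pd.get_dummies(
            static_df, columns=self.covariate_metadata.static_features_cat, dtype="float32"
        )
    return df, static_df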


Here is how we can train multiple versions of the custom model with different hyperparameter configurations:

predictor = TimeSeriesPredictor(
    prediction_length=prediction_length,
    target=target,
    known_covariates_names=known_covariates_names,
)
predictor.fit(
    train_data,
    hyperparameters={
        NHITSModel: [
            {},  # default hyperparameters
            {"input_size": 20},  # custom input_size
            {"scaler_type": "robust"},  # custom scaler_type
        ]
    },
    time_limit=60,
)
Entering the `_fit` method
Hyperparameters:
{'loss': MQLoss(),
 'input_size': 14,
 'scaler_type': 'standard',
 'enable_progress_bar': False,
 'enable_model_summary': False,
 'logger': False,
 'accelerator': 'gpu',
 'start_padding_enabled': True,
 'futr_exog_list': ['promotion_email', 'promotion_homepage'],
 'hist_exog_list': ['scaled_price'],
 'stat_exog_list': ['product_code', 'location_code'],
 'devices': 1,
 'max_time': {'seconds': 13.471844925398273}}
Exiting the `_fit` method
Entering the `_predict` method
Entering the `_fit` method
Hyperparameters:
{'loss': MQLoss(),
 'input_size': 20,
 'scaler_type': 'standard',
 'enable_progress_bar': False,
 'enable_model_summary': False,
 'logger': False,
 'accelerator': 'gpu',
 'start_padding_enabled': True,
 'futr_exog_list': ['promotion_email', 'promotion_homepage'],
 'hist_exog_list': ['scaled_price'],
 'stat_exog_list': ['product_code', 'location_code'],
 'devices': 1,
 'max_time': {'seconds': 13.839514362633167}}
Exiting the `_fit` method
Entering the `_predict` method
Entering the `_fit` method
Hyperparameters:
{'loss': MQLoss(),
 'input_size': 14,
 'scaler_type': 'robust',
 'enable_progress_bar': False,
 'enable_model_summary': False,
 'logger': False,
 'accelerator': 'gpu',
 'start_padding_enabled': True,
 'futr_exog_list': ['promotion_email', 'promotion_homepage'],
 'hist_exog_list': ['scaled_price'],
 'stat_exog_list': ['product_code', 'location_code'],
 'devices': 1,
 'max_time': {'seconds': 14.406984723390542}}
Exiting the `_fit` method
Entering the `_predict` method
Beginning AutoGluon training... Time limit = 60s
AutoGluon will save models to '/home/ci/autogluon/docs/tutorials/timeseries/advanced/AutogluonModels/ag-20250527_234444'
=================== System Info ===================
AutoGluon Version:  1.3.2b20250527
Python Version:     3.11.10
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
GPU Count:          1
Memory Avail:       26.90 GB / 30.95 GB (86.9%)
Disk Space Avail:   211.38 GB / 255.99 GB (82.6%)
===================================================

Fitting with arguments:
{'enable_ensemble': True,
 'eval_metric': WQL,
 'hyperparameters': {<class '__main__.NHITSModel'>: [{},
                                                     {'input_size': 20},
                                                     {'scaler_type': 'robust'}]},
 'known_covariates_names': ['promotion_email', 'promotion_homepage'],
 'num_val_windows': 1,
 'prediction_length': 7,
 'quantile_levels': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
 'random_seed': 123,
 'refit_every_n_windows': 1,
 'refit_full': False,
 'skip_model_selection': False,
 'target': 'unit_sales',
 'time_limit': 60,
 'verbosity': 2}
Inferred time series frequency: 'W-MON'
Provided train_data has 7656 rows (NaN fraction=6.8%), 319 time series. Median time series length is 24 (min=24, max=24).

Provided data contains following columns:
	target: 'unit_sales'
	known_covariates:
		categorical:        []
		continuous (float): ['promotion_email', 'promotion_homepage']
	past_covariates:
		categorical:        []
		continuous (float): ['scaled_price']
	static_features:
		categorical:        ['product_category', 'product_subcategory']
		continuous (float): ['product_code', 'location_code']

To learn how to fix incorrectly inferred types, please see documentation for TimeSeriesPredictor.fit

AutoGluon will gauge predictive performance using evaluation metric: 'WQL'
	This metric's sign has been flipped to adhere to being higher_is_better. The metric score can be multiplied by -1 to get the metric value.
===================================================

Starting training. Start time is 2025-05-27 23:44:44
Models that will be trained: ['NHITS', 'NHITS_2', 'NHITS_3']
Training timeseries model NHITS. Training for up to 15.0s of the 59.9s of remaining time.
	-0.4832       = Validation score (-WQL)
	13.62   s     = Training runtime
	0.10    s     = Validation (prediction) runtime
Training timeseries model NHITS_2. Training for up to 15.4s of the 46.2s of remaining time.
	-0.4523       = Validation score (-WQL)
	13.99   s     = Training runtime
	0.10    s     = Validation (prediction) runtime
Training timeseries model NHITS_3. Training for up to 16.0s of the 32.0s of remaining time.
	-0.3929       = Validation score (-WQL)
	14.57   s     = Training runtime
	0.11    s     = Validation (prediction) runtime
Fitting simple weighted ensemble.
	Ensemble weights: {'NHITS_2': 0.1, 'NHITS_3': 0.9}
	-0.3920       = Validation score (-WQL)
	0.41    s     = Training runtime
	0.21    s     = Validation (prediction) runtime
Training complete. Models trained: ['NHITS', 'NHITS_2', 'NHITS_3', 'WeightedEnsemble']
Total runtime: 43.03 s
Best model: WeightedEnsemble
Best model score: -0.3920
<autogluon.timeseries.predictor.TimeSeriesPredictor at 0x7f27a01c48d0>
predictor.leaderboard(test_data)
Entering the `_predict` method
Entering the `_predict` method
Entering the `_predict` method
Additional data provided, testing on additional data. Resulting leaderboard will be sorted according to test score (`score_test`).
              model  score_test  score_val  pred_time_test  pred_time_val  fit_time_marginal  fit_order
0           NHITS_2   -0.401187  -0.452312        0.127493       0.100950          13.990083          2
1             NHITS   -0.408660  -0.483168        0.125887       0.098063          13.619061          1
2  WeightedEnsemble   -0.417551  -0.391991        0.260669       0.206372           0.407155          4
3           NHITS_3   -0.430970  -0.392894        0.131381       0.105422          14.566792          3
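
Since NHITSModel subclasses AbstractTimeSeriesModel, it can also be hyperparameter-tuned with search spaces, just like the built-in models. A hedged sketch, with purely illustrative search ranges:

from autogluon.common import space

predictor = TimeSeriesPredictor(
    prediction_length=prediction_length,
    target=target,
    known_covariates_names=known_covariates_names,
)
predictor.fit(
    train_data,
    hyperparameters={
        NHITSModel: {
            "input_size": space.Int(7, 28),
            "scaler_type": space.Categorical("standard", "robust"),
        }
    },
    hyperparameter_tune_kwargs="auto",
    time_limit=120,
)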

Wrapping up

That’s all it takes to add a custom forecasting model to AutoGluon. If you create a custom model, consider submitting a PR so that we can add it officially to AutoGluon!

For more tutorials, refer to Forecasting Time Series - Quick Start and Forecasting Time Series - In Depth.