Source code for autogluon.tabular.models.tabm.tabm_model

"""
Code adapted from TabArena: https://github.com/autogluon/tabarena/blob/main/tabarena/tabarena/benchmark/models/ag/tabm/tabm_model.py
Partially adapted from pytabkit's TabM implementation.
"""

from __future__ import annotations

import logging
import time

import pandas as pd

from autogluon.common.utils.resource_utils import ResourceManager
from autogluon.tabular import __version__
from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel

logger = logging.getLogger(__name__)


class TabMModel(AbstractTorchModel):
    """
    TabM is an efficient ensemble of MLPs that is trained simultaneously with mostly shared parameters.
    TabM is one of the top-performing methods overall on TabArena-v0.1: https://tabarena.ai

    Paper: TabM: Advancing Tabular Deep Learning with Parameter-Efficient Ensembling
    Authors: Yury Gorishniy, Akim Kotelnikov, Artem Babenko
    Codebase: https://github.com/yandex-research/tabm
    License: Apache-2.0

    Partially adapted from pytabkit's TabM implementation.

    .. versionadded:: 1.4.0
    """

    ag_key = "TABM"
    ag_name = "TabM"
    ag_priority = 85
    seed_name = "random_state"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._imputer = None
        self._features_to_impute = None
        self._features_to_keep = None
        self._indicator_columns = None
        self._features_bool = None
        self._bool_to_cat = None

    def _fit(
        self,
        X: pd.DataFrame,
        y: pd.Series,
        X_val: pd.DataFrame = None,
        y_val: pd.Series = None,
        time_limit: float | None = None,
        num_cpus: int = 1,
        num_gpus: float = 0,
        **kwargs,
    ):
        start_time = time.time()

        try:
            # imports various dependencies such as torch
            from torch.cuda import is_available

            from ._tabm_internal import TabMImplementation
        except ImportError as err:
            logger.log(
                40,
                f"\tFailed to import tabm! To use the TabM model, "
                f"do: `pip install autogluon.tabular[tabm]=={__version__}`.",
            )
            raise err

        device = "cpu" if num_gpus == 0 else "cuda"
        if (device == "cuda") and (not is_available()):
            # FIXME: warn instead and switch to CPU.
            raise AssertionError(
                "Fit specified to use GPU, but CUDA is not available on this machine. "
                "Please switch to CPU usage instead.",
            )

        if X_val is None:
            from autogluon.core.utils import generate_train_test_split

            X, X_val, y, y_val = generate_train_test_split(
                X=X,
                y=y,
                problem_type=self.problem_type,
                test_size=0.2,
                random_state=0,
            )

        hyp = self._get_model_params()
        bool_to_cat = hyp.pop("bool_to_cat", True)

        X = self.preprocess(X, y=y, is_train=True, bool_to_cat=bool_to_cat)
        if X_val is not None:
            X_val = self.preprocess(X_val)

        self.model = TabMImplementation(
            n_threads=num_cpus,
            device=device,
            problem_type=self.problem_type,
            early_stopping_metric=self.stopping_metric,
            **hyp,
        )

        self.model.fit(
            X_train=X,
            y_train=y,
            X_val=X_val,
            y_val=y_val,
            cat_col_names=X.select_dtypes(include="category").columns.tolist(),
            time_to_fit_in_seconds=time_limit - (time.time() - start_time) if time_limit is not None else None,
        )

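    # --- Illustrative note (editor's sketch, not part of the AutoGluon source) ---
    # _get_model_params() returns the user-supplied hyperparameters for this model.
    # Only the "bool_to_cat" preprocessing flag is consumed in _fit above; everything
    # else is forwarded to TabMImplementation. A hypothetical hyperparameter dict
    # might look like:
    #
    #     {"bool_to_cat": True, "n_blocks": 3, "d_block": 512, "tabm_k": 32, "batch_size": "auto"}
    #
    # The keys and defaults shown here mirror the ones assumed in _estimate_tabm_ram below.
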
    # FIXME: bool_to_cat is a hack: Maybe move to abstract model?
    def _preprocess(
        self,
        X: pd.DataFrame,
        is_train: bool = False,
        bool_to_cat: bool = False,
        **kwargs,
    ) -> pd.DataFrame:
        """
        Converts boolean features to the "category" dtype so that the model
        treats them as categorical rather than numerical.
        """
        X = super()._preprocess(X, **kwargs)

        if is_train:
            self._bool_to_cat = bool_to_cat
            self._features_bool = self._feature_metadata.get_features(required_special_types=["bool"])
        if self._bool_to_cat and self._features_bool:
            # FIXME: Use CategoryFeatureGenerator? Or tell the model which is category
            X = X.copy(deep=True)
            X[self._features_bool] = X[self._features_bool].astype("category")

        return X

    def get_device(self) -> str:
        return self.model.device_.type

    def _set_device(self, device: str):
        device = self.to_torch_device(device)
        self.model.device_ = device
        self.model.model_ = self.model.model_.to(device)

    @classmethod
    def supported_problem_types(cls) -> list[str] | None:
        return ["binary", "multiclass", "regression"]

    def _get_default_stopping_metric(self):
        return self.eval_metric

    def _get_default_resources(self) -> tuple[int, int]:
        # Use only physical cores for better performance based on benchmarks
        num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
        num_gpus = min(1, ResourceManager.get_gpu_count_torch(cuda_only=True))
        return num_cpus, num_gpus

    def _estimate_memory_usage(self, X: pd.DataFrame, **kwargs) -> int:
        hyperparameters = self._get_model_params()
        return self.estimate_memory_usage_static(
            X=X,
            problem_type=self.problem_type,
            num_classes=self.num_classes,
            hyperparameters=hyperparameters,
            **kwargs,
        )

    @classmethod
    def _estimate_memory_usage_static(
        cls,
        *,
        X: pd.DataFrame,
        hyperparameters: dict = None,
        num_classes: int | None = 1,
        **kwargs,
    ) -> int:
        """
        Heuristic memory estimate that correlates strongly with the RealMLP memory estimate.
        """
        if num_classes is None:
            num_classes = 1
        if hyperparameters is None:
            hyperparameters = {}

        cat_sizes = []
        for col in X.select_dtypes(include=["category", "object"]):
            if isinstance(X[col].dtype, pd.CategoricalDtype):
                # Use .cat.codes for category dtype
                unique_codes = X[col].cat.codes.unique()
            else:
                # For object dtype, treat unique strings as codes
                unique_codes = X[col].astype("category").cat.codes.unique()
            cat_sizes.append(len(unique_codes))

        n_numerical = len(X.select_dtypes(include=["number"]).columns)

        # TODO: This estimates very high memory usage,
        #  we probably need to adjust batch size automatically to compensate
        mem_estimate_bytes = cls._estimate_tabm_ram(
            hyperparameters=hyperparameters,
            n_numerical=n_numerical,
            cat_sizes=cat_sizes,
            n_classes=num_classes,
            n_samples=len(X),
        )

        return mem_estimate_bytes

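    # --- Worked example (editor's sketch, not part of the AutoGluon source) ---
    # Rough sanity check of the heuristic below for a hypothetical dataset with
    # 10_000 rows, 20 numerical features, and two categorical features with 4 and
    # 12 categories. With default hyperparameters, batch_size resolves to 256 via
    # get_tabm_auto_batch_size, and the estimate comes out to roughly 1.1e9 bytes,
    # dominated by the forward/backward activation buffers (scaled by tabm_k=32)
    # and the fixed 0.3 GiB safety offset:
    #
    #     est = TabMModel._estimate_tabm_ram(
    #         hyperparameters={},
    #         n_numerical=20,
    #         cat_sizes=[4, 12],
    #         n_classes=2,
    #         n_samples=10_000,
    #     )
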
    @classmethod
    def _estimate_tabm_ram(
        cls,
        hyperparameters: dict,
        n_numerical: int,
        cat_sizes: list[int],
        n_classes: int,
        n_samples: int,
    ) -> int:
        num_emb_n_bins = hyperparameters.get("num_emb_n_bins", 48)
        d_embedding = hyperparameters.get("d_embedding", 16)
        d_block = hyperparameters.get("d_block", 512)
        # not completely sure if this is hidden blocks or all blocks, taking the safe option below
        n_blocks = hyperparameters.get("n_blocks", "auto")
        if isinstance(n_blocks, str) and n_blocks == "auto":
            n_blocks = 3
        batch_size = hyperparameters.get("batch_size", "auto")
        if isinstance(batch_size, str) and batch_size == "auto":
            batch_size = cls.get_tabm_auto_batch_size(n_samples=n_samples)
        tabm_k = hyperparameters.get("tabm_k", 32)
        predict_batch_size = hyperparameters.get("eval_batch_size", 1024)  # not completely sure

        n_params_num_emb = n_numerical * (num_emb_n_bins + 1) * d_embedding
        n_params_mlp = (
            (n_numerical + sum(cat_sizes)) * d_embedding * (d_block + tabm_k)
            + (n_blocks - 1) * d_block**2
            + n_blocks * d_block
            + d_block * (1 + max(1, n_classes))
        )

        # 4 bytes per float, up to 5 copies of parameters (1 standard, 1 .grad, 2 adam, 1 best_epoch)
        mem_params = 4 * 5 * (n_params_num_emb + n_params_mlp)

        # compute number of floats in forward pass (per batch element)
        # todo: numerical embedding layer (not sure if this is entirely correct)
        n_floats_forward = n_numerical * (num_emb_n_bins + d_embedding)
        # before and after scale
        n_floats_forward += 2 * (sum(cat_sizes) + n_numerical * d_embedding)
        # 2 for pre-act, post-act
        n_floats_forward += n_blocks * 2 * d_block + 2 * max(1, n_classes)
        # 2 for forward and backward, 4 bytes per float
        mem_forward_backward = 4 * max(batch_size * 2, predict_batch_size) * n_floats_forward * tabm_k
        # * 8 is pessimistic for the long tensors in the forward pass, 4 would probably suffice
        mem_ds = n_samples * (4 * n_numerical + 8 * len(cat_sizes))

        # some safety constants and offsets (the 5 is probably excessive)
        mem_total = 5 * mem_ds + 1.2 * mem_forward_backward + 1.2 * mem_params + 0.3 * (1024**3)

        return int(mem_total)

    def _get_default_auxiliary_params(self) -> dict:
        default_auxiliary_params = super()._get_default_auxiliary_params()
        default_auxiliary_params.update(
            {
                "max_batch_size": 16384,  # avoid excessive VRAM usage
            }
        )
        return default_auxiliary_params

    @classmethod
    def get_tabm_auto_batch_size(cls, n_samples: int) -> int:
        # by Yury Gorishniy, inferred from the choices in the TabM paper.
        if n_samples < 2_800:
            return 32
        if n_samples < 4_500:
            return 64
        if n_samples < 6_400:
            return 128
        if n_samples < 32_000:
            return 256
        if n_samples < 108_000:
            return 512
        return 1024

    @classmethod
    def _class_tags(cls):
        return {
            "can_estimate_memory_usage_static": True,
            "reset_torch_threads": True,
        }

    def _more_tags(self) -> dict:
        # TODO: Need to add train params support, track best epoch
        #  How to force stopping at a specific epoch?
        return {"can_refit_full": False}
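
# --- Usage sketch (editor's sketch, not part of the AutoGluon source) ---
# A minimal way to train TabM through TabularPredictor, selecting the model by the
# ag_key registered above. Assumes the optional dependencies are installed via
# `pip install autogluon.tabular[tabm]`; "train.csv" and the "class" label column
# are hypothetical placeholders.
#
#     from autogluon.tabular import TabularPredictor
#
#     train_data = pd.read_csv("train.csv")  # hypothetical dataset
#     predictor = TabularPredictor(label="class").fit(
#         train_data,
#         hyperparameters={"TABM": {}},  # ag_key defined in TabMModel
#     )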