# Source code for autogluon.tabular.models.fasttext.fasttext_model

__all__ = ["FastTextModel"]

import contextlib
import gc
import logging
import os
import psutil
import tempfile

import numpy as np
import pandas as pd

from autogluon.core.constants import BINARY, MULTICLASS
from autogluon.core.features.types import S_TEXT

from autogluon.core.models import AbstractModel
from autogluon.core.models.abstract.model_trial import skip_hpo
from .hyperparameters.parameters import get_param_baseline

logger = logging.getLogger(__name__)


def try_import_fasttext():
    """Verify that the optional ``fasttext`` package can be imported.

    Any failure while importing (not only ``ImportError``, since the compiled
    extension can fail in other ways) is converted into an ``ImportError``
    carrying installation instructions. The original exception is chained as
    ``__cause__`` so the real reason is not lost.

    Raises:
        ImportError: if ``fasttext`` cannot be imported.
    """
    try:
        import fasttext

        # Touch a module attribute so a broken/partial install fails here too.
        _ = fasttext.__file__
    except Exception as err:
        raise ImportError('Import fasttext failed. Please run "pip install fasttext"') from err


class FastTextModel(AbstractModel):
    """AutoGluon model that trains a fastText supervised classifier on text features.

    Only binary and multiclass problems are supported, and only features with the
    ``S_TEXT`` special type are requested. The inner fastText model cannot be
    pickled, so ``save``/``load`` store it separately at
    ``path + model_bin_file_name``.
    """

    model_bin_file_name = "fasttext.ftz"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Set by ``save``: whether ``load`` must restore the inner fastText model.
        self._load_model = None

    def _set_default_params(self):
        """Register every baseline hyperparameter as this model's default."""
        for param, val in get_param_baseline().items():
            self._set_default_param_value(param, val)

    # TODO: Investigate allowing categorical features as well
    def _get_default_auxiliary_params(self) -> dict:
        """Extend the parent auxiliary params so only S_TEXT features are passed in."""
        default_auxiliary_params = super()._get_default_auxiliary_params()
        extra_auxiliary_params = dict(
            get_features_kwargs=dict(
                required_special_types=[S_TEXT],
            )
        )
        default_auxiliary_params.update(extra_auxiliary_params)
        return default_auxiliary_params

    @classmethod
    def _get_default_ag_args(cls) -> dict:
        """Mark the model as unusable for stacking and restrict it to classification."""
        default_ag_args = super()._get_default_ag_args()
        extra_ag_args = {'valid_stacker': False, 'problem_types': [BINARY, MULTICLASS]}
        default_ag_args.update(extra_ag_args)
        return default_ag_args

    def _fit(self, X, y, sample_weight=None, **kwargs):
        """Train the fastText classifier on the text of ``X`` with labels ``y``.

        The ``quantize_model`` param (default True) is removed from the model params;
        when set, the trained model is quantized with retraining on the same data.
        All remaining params are forwarded to ``fasttext.train_supervised``.
        ``sample_weight`` is not supported and is ignored (with a log message).

        Raises:
            ValueError: if the problem type is not binary or multiclass.
            ImportError: if fasttext cannot be imported.
        """
        if self.problem_type not in (BINARY, MULTICLASS):
            raise ValueError(
                "FastText model only supports binary or multiclass classification"
            )

        try_import_fasttext()
        import fasttext

        params = self._get_model_params()
        quantize_model = params.pop('quantize_model', True)

        verbosity = kwargs.get('verbosity', 2)
        if 'verbose' not in params:
            # Map AutoGluon verbosity onto fastText's verbose levels 0/1/2.
            if verbosity <= 2:
                params['verbose'] = 0
            elif verbosity == 3:
                params['verbose'] = 1
            else:
                params['verbose'] = 2

        if sample_weight is not None:
            logger.log(15, "sample_weight not yet supported for FastTextModel, this model will ignore them in training.")

        X = self.preprocess(X)

        self._label_dtype = y.dtype
        # fastText identifies classes through "__label__<id>" prefixes in the training file.
        self._label_map = {label: f"__label__{i}" for i, label in enumerate(y.unique())}
        self._label_inv_map = {v: k for k, v in self._label_map.items()}

        # Deterministic shuffle with a private RNG: same order as seeding the global
        # numpy RNG with 0, but without clobbering the caller's global random state.
        idxs = np.random.RandomState(0).permutation(len(X))

        # fastText reads its training data from a path. Write a *closed* UTF-8 file in a
        # temporary directory: an open NamedTemporaryFile cannot be reopened by name on
        # Windows, and the platform default encoding may not be able to encode the text.
        with tempfile.TemporaryDirectory() as tmp_dir:
            train_path = os.path.join(tmp_dir, "train.txt")
            logger.debug("generate training data")
            with open(train_path, "w", encoding="utf-8") as f:
                for label, text in zip(y.iloc[idxs], (X[i] for i in idxs)):
                    f.write(f"{self._label_map[label]} {text}\n")

            mem_start = psutil.Process().memory_info().rss
            logger.debug("train FastText model")
            self.model = fasttext.train_supervised(train_path, **params)
            if quantize_model:
                self.model.quantize(input=train_path, retrain=True)
            gc.collect()
            mem_curr = psutil.Process().memory_info().rss
            # Measured RSS growth, floored at a fixed minimum (larger when not quantized).
            self._model_size_estimate = max(mem_curr - mem_start, 100000000 if quantize_model else 800000000)
            logger.debug("finish training FastText model")

    # TODO: move logic to self._preprocess_nonadaptive()
    # TODO: text features: alternate text preprocessing steps (e.g. normalizing numbers)
    # TODO: categorical features: special encoding: <feature name>_<feature value>
    def _preprocess(self, X: pd.DataFrame, **kwargs) -> list:
        """Flatten each row of ``X`` into one normalized, single-line text string.

        Cells are stringified and joined with spaces, with missing cells rendered as
        a blank. The text is then lower-cased and HTML tags are removed. Each
        non-word character is surrounded by spaces, and whitespace is collapsed to
        single spaces, so no newline survives into fastText's line-based input.

        Returns:
            list of str, one entry per row of ``X``.
        """
        X = super()._preprocess(X, **kwargs)
        # astype(str) turns NaN/None into the literal "nan"/"None"; mask them back
        # to a blank so missing values do not become vocabulary tokens.
        X_str = X.astype(str).mask(X.isna(), " ")
        # Build the joined column explicitly so a 0-row frame still yields a Series.
        text_col = pd.Series(
            [" ".join(row) for row in X_str.itertuples(index=False, name=None)],
            index=X.index,
            dtype=object,
        )
        # regex=True is required: since pandas 2.0 str.replace defaults to literal matching.
        text_col = (
            text_col
            .str.lower()
            .str.replace("<.*?>", " ", regex=True)  # remove html tags
            .str.replace("""([\\W])""", " \\1 ", regex=True)  # separate special characters
            .str.replace("\\s", " ", regex=True)
            .str.replace("[ ]+", " ", regex=True)
        )
        return text_col.to_list()

    def predict(self, X: pd.DataFrame, **kwargs) -> np.ndarray:
        """Return the top-1 predicted class for each row, cast to the training label dtype."""
        X = self.preprocess(X, **kwargs)
        pred_labels, _ = self.model.predict(X)
        return np.array(
            [self._label_inv_map[labels[0]] for labels in pred_labels],
            dtype=self._label_dtype,
        )

    def _predict_proba(self, X: pd.DataFrame, **kwargs) -> np.ndarray:
        """Return class probabilities with one column per label, in sorted label order."""
        X = self.preprocess(X, **kwargs)
        # Ask for every label so each row gets a probability for every class.
        pred_labels, pred_probs = self.model.predict(X, k=len(self.model.labels))
        recs = [
            dict(zip((self._label_inv_map[label] for label in labels), probs))
            for labels, probs in zip(pred_labels, pred_probs)
        ]
        # Columns are the original class labels; sort them for a stable column order.
        y_pred_proba: np.ndarray = pd.DataFrame(recs).sort_index(axis=1).values
        return self._convert_proba_to_unified_form(y_pred_proba)

    def save(self, path: str = None, verbose=True) -> str:
        """Pickle the wrapper and save the fastText model beside it.

        The fastText model cannot be pickled, so it is detached while the parent
        ``save`` runs. It is always re-attached afterwards, even if pickling fails,
        and is then written with ``save_model`` to ``path + model_bin_file_name``.
        """
        self._load_model = self.model is not None
        inner_model = self.model
        self.model = None
        try:
            path = super().save(path=path, verbose=verbose)
        finally:
            self.model = inner_model
        # TODO: s3 support
        if self._load_model:
            fasttext_model_file_name = path + self.model_bin_file_name
            self.model.save_model(fasttext_model_file_name)
        self._load_model = None
        return path

    @classmethod
    def load(cls, path: str, reset_paths=True, verbose=True):
        """Unpickle the wrapper and, if it was saved with a model, reload the fastText binary."""
        model: FastTextModel = super().load(path=path, reset_paths=reset_paths, verbose=verbose)
        if model._load_model:
            try_import_fasttext()
            import fasttext

            fasttext_model_file_name = model.path + cls.model_bin_file_name
            # TODO: hack to suppress a deprecation warning from fasttext;
            # remove it once official fasttext is updated beyond 0.9.2
            # https://github.com/facebookresearch/fastText/issues/1067
            with open(os.devnull, 'w') as devnull, contextlib.redirect_stderr(devnull):
                model.model = fasttext.load_model(fasttext_model_file_name)
        model._load_model = None
        return model

    def get_memory_size(self) -> int:
        """Return the memory estimate recorded at the end of ``_fit`` (fit must run first)."""
        return self._model_size_estimate

    # TODO: Add HPO
    def _hyperparameter_tune(self, **kwargs):
        """Hyperparameter tuning is not implemented; defer to ``skip_hpo``."""
        return skip_hpo(self, **kwargs)