# Source code for autogluon.features.generators.rename
import copy
import logging
from pandas import DataFrame
from .abstract import AbstractFeatureGenerator
logger = logging.getLogger(__name__)
class RenameFeatureGenerator(AbstractFeatureGenerator):
    """
    RenameFeatureGenerator renames the columns without altering their values.

    This can be used to avoid column name collisions when transforming the same feature in multiple ways,
    or to highlight that a feature was derived from a particular pipeline.

    Parameters
    ----------
    name_prefix : str, default None
        Name prefix to add to all output feature names.
    name_suffix : str, default None
        Name suffix to add to all output feature names.
    inplace : bool, default False
        If True, then the column names are renamed inplace without copying the input data.
        This will alter the input data outside of the scope of this function.
    **kwargs :
        Refer to :class:`AbstractFeatureGenerator` documentation for details on valid key word arguments.
    """

    def __init__(self, name_prefix=None, name_suffix=None, inplace=False, **kwargs):
        super().__init__(**kwargs)
        self._name_prefix = name_prefix
        self._name_suffix = name_suffix
        self.inplace = inplace
        # None until `_fit_transform` runs; afterwards a bool telling `_transform`
        # whether any column name actually changes.
        self._is_updated_name = None

    def _fit_transform(self, X: DataFrame, **kwargs) -> (DataFrame, dict):
        """
        Compute the rename mapping from the columns of X, apply it, and record whether any name changed.

        Returns
        -------
        (DataFrame, dict)
            X with renamed columns, and the `type_group_map_special` of the renamed input feature metadata.
        """
        column_rename_map, self._is_updated_name = self._get_renamed_features(X)
        # Unlike `_transform`, the copy is made here even when no name ends up changing.
        if not self.inplace:
            X = copy.deepcopy(X)
        X.columns = [column_rename_map.get(col, col) for col in X.columns]

        feature_metadata_out = self.feature_metadata_in.rename_features(column_rename_map)
        return X, feature_metadata_out.type_group_map_special

    def _transform(self, X: DataFrame) -> DataFrame:
        """
        Rename the columns of X to `self.features_out`, or return X untouched if fitting found nothing to rename.

        The input is deep-copied first unless `inplace` is True.
        """
        if self._is_updated_name:
            if not self.inplace:
                X = copy.deepcopy(X)
            # NOTE(review): `features_out` is not set in this class; presumably the base class
            # records the fitted output column names there - confirm.
            X.columns = self.features_out
        return X

    def _get_renamed_features(self, X: DataFrame) -> (dict, bool):
        """
        Build the mapping from the current column names of X to their prefixed/suffixed names.

        Returns
        -------
        (dict, bool)
            `column_rename_map` (original name -> new name) and `is_updated_name`,
            which is True when at least one column name differs after renaming.
        """
        X_columns_orig = list(X.columns)
        X_columns_new = X_columns_orig
        # Falsy values (None or '') mean "no prefix/suffix" and leave the names as they are.
        if self._name_prefix:
            X_columns_new = [self._name_prefix + column for column in X_columns_new]
        if self._name_suffix:
            X_columns_new = [column + self._name_suffix for column in X_columns_new]

        is_updated_name = X_columns_orig != X_columns_new
        column_rename_map = dict(zip(X_columns_orig, X_columns_new))
        return column_rename_map, is_updated_name

    @staticmethod
    def get_default_infer_features_in_args() -> dict:
        """Return an empty dict, i.e. no default arguments for inferring the input features."""
        return {}

    def _more_tags(self):
        """Return the tag overrides declared by this generator."""
        return {
            'feature_interactions': False,
            'allow_post_generators': False,  # TODO: This might not be necessary anymore
        }
