"""autogluon.eda.visualization.anomaly: visualizations for anomaly detection scores."""

from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import seaborn as sns

from .. import AnalysisState
from .base import AbstractVisualization
from .jupyter import JupyterMixin

# Public API of this module.
__all__ = ["AnomalyScoresVisualization"]

from ..utils.common import expand_nested_args_into_nested_maps, get_empty_dict_if_none


class AnomalyScoresVisualization(AbstractVisualization, JupyterMixin):
    """
    Visualize anomaly scores across datasets.

    The report depends on :py:class:`~autogluon.eda.analysis.anomaly.AnomalyDetectorAnalysis`.

    Parameters
    ----------
    threshold_stds: float, default = 3
        anomaly cutoff expressed in standard deviations of the `train_data` scores:
        points with ``score >= train_scores.std() * threshold_stds`` are marked as anomalies.
        Note that the cutoff is not offset by the mean of the scores. The same cutoff
        (computed on `train_data`) is applied to every dataset in the state.
    headers: bool, default = False
        if `True` then render headers
    namespace: Optional[str], default = None
        namespace to use; can be nested like `ns_a.ns_b.ns_c`
    fig_args: Optional[Dict[str, Any]], default = None
        kwargs to pass into ``plt.subplots`` when creating the figure for each dataset.
    chart_args
        kwargs to pass into the scatter plots. The chart contains two scatterplots combined:
        normal and anomaly data points. Each can be customized by prefixing the argument
        with `normal.` or `anomaly.` respectively (see the example below); those arguments
        are passed to ``sns.scatterplot`` for the corresponding group of points.

    Examples
    --------
    >>> import autogluon.eda.analysis as eda
    >>> import autogluon.eda.visualization as viz
    >>> import autogluon.eda.auto as auto
    >>> import pandas as pd
    >>> import numpy as np
    >>>
    >>> df_train = pd.DataFrame(...)
    >>> df_test = pd.DataFrame(...)
    >>> label = 'target'
    >>> threshold_stds = 3  # mark 3 standard deviations score values as anomalies
    >>>
    >>> chart_args={
    >>>     'normal.color': 'lightgrey',
    >>>     'anomaly.color': 'orange',
    >>> }
    >>>
    >>> state = auto.analyze(
    >>>     train_data=df_train,
    >>>     test_data=df_test,
    >>>     label=label,
    >>>     return_state=True,
    >>>     anlz_facets=[
    >>>         eda.dataset.ProblemTypeControl(),
    >>>         eda.transform.ApplyFeatureGenerator(category_to_numbers=True, children=[
    >>>             eda.anomaly.AnomalyDetectorAnalysis(
    >>>                 store_explainability_data=True  # Store additional functions for explainability
    >>>             ),
    >>>         ])
    >>>     ],
    >>>     viz_facets=[
    >>>         viz.anomaly.AnomalyScoresVisualization(
    >>>             threshold_stds=threshold_stds,
    >>>             headers=True,
    >>>             fig_args=dict(figsize=(8, 4)),
    >>>             **chart_args,  # pass chart args customizations
    >>>         )
    >>>     ]
    >>> )
    >>>
    >>> # explain top anomalies
    >>> train_anomaly_scores = state.anomaly_detection.scores.train_data
    >>> anomaly_idx = train_anomaly_scores[train_anomaly_scores >= train_anomaly_scores.std() * threshold_stds]
    >>> anomaly_idx = anomaly_idx.sort_values(ascending=False).index
    >>>
    >>> auto.explain_rows(
    >>>     # Use helper function stored via `store_explainability_data=True`
    >>>     **state.anomaly_detection.explain_rows_fns.train_data(anomaly_idx[:3]),
    >>>     plot='waterfall',
    >>> )
    """

    def __init__(
        self,
        threshold_stds: float = 3,
        headers: bool = False,
        namespace: Optional[str] = None,
        fig_args: Optional[Dict[str, Any]] = None,
        **chart_args,
    ) -> None:
        super().__init__(namespace, **chart_args)
        self.threshold_stds = threshold_stds
        self.headers = headers
        self.fig_args = get_empty_dict_if_none(fig_args)
        # Dotted keys such as `normal.color` are expected to be expanded into nested maps
        # (e.g. {"normal": {"color": ...}}), which `_render` reads per group of points.
        self.chart_args = expand_nested_args_into_nested_maps(get_empty_dict_if_none(chart_args))

    def can_handle(self, state: AnalysisState) -> bool:
        """Return True when the state contains `anomaly_detection` results."""
        return self.all_keys_must_be_present(state, "anomaly_detection")

    def _render(self, state: AnalysisState) -> None:
        """Draw one scatterplot of anomaly scores per dataset, split by the anomaly cutoff."""
        scores = state.anomaly_detection.scores
        # The cutoff is computed from the train_data scores only and reused for every dataset,
        # so all datasets are judged against the same value.
        threshold = scores.train_data.std() * self.threshold_stds
        for ds, ds_scores in scores.items():
            self.render_header_if_needed(
                state, f"`{ds}` anomalies for {self.threshold_stds}-sigma outlier scores", ds=ds
            )
            # Replace the original index with row positions exposed as an `index` column (x-axis).
            # NOTE(review): assumes ds_scores is a pandas Series named `score` - the plots below
            # select the `index` and `score` columns.
            data = ds_scores.reset_index(drop=True).reset_index()
            _, ax = plt.subplots(**self.fig_args)
            common_chart_args = dict(ax=ax, x="index", y="score")
            sns.scatterplot(
                data=data[data.score < threshold], **common_chart_args, **self.chart_args.get("normal", {})
            )
            sns.scatterplot(
                data=data[data.score >= threshold], **common_chart_args, **self.chart_args.get("anomaly", {})
            )
            # Dashed cutoff line, labelled with its value at the left edge of the chart.
            ax.axhline(
                y=threshold,
                color="r",
                linestyle="--",
            )
            ax.text(
                x=0,
                y=threshold,
                s=f"{threshold:.4f}",
                color="red",
                rotation="vertical",
                horizontalalignment="right",
                verticalalignment="top",
            )
            plt.tight_layout(h_pad=0.3, w_pad=0.5)
            # `plt.show` does not take a figure positionally; backends define `block` as
            # keyword-only, so `plt.show(fig)` fails outside the Jupyter inline backend.
            plt.show()