# Source code for autogluon.eda.visualization.interaction
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple, Type
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from scipy.cluster import hierarchy as hc
from autogluon.common.features.types import R_BOOL, R_CATEGORY, R_DATETIME, R_FLOAT, R_INT, R_OBJECT
from ..state import AnalysisState
from .base import AbstractVisualization
from .jupyter import JupyterMixin
# Public API of this module: the names exported by ``from ... import *``.
# NOTE(review): FeatureDistanceAnalysisVisualization is defined in this module but not listed here - confirm that is intentional.
__all__ = ["CorrelationVisualization", "CorrelationSignificanceVisualization", "FeatureInteractionVisualization"]
class _AbstractCorrelationChart(AbstractVisualization, JupyterMixin, ABC):
    """
    Base class for annotated heatmap renderers of per-dataset matrices.

    Subclasses choose the state key that holds ``{dataset_name: matrix}`` and the
    ``seaborn.heatmap`` arguments, then delegate to :meth:`_render_internal`.

    Parameters
    ----------
    headers: bool, default = False
        if `True` then render headers
    namespace: Optional[str], default = None
        namespace to use; can be nested like `ns_a.ns_b.ns_c`
    fig_args: Optional[Dict[str, Any]], default = None
        kwargs to pass into chart figure (``plt.subplots``)
    kwargs
        forwarded to the base class constructor
    """

    def __init__(
        self,
        headers: bool = False,
        namespace: Optional[str] = None,
        fig_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        super().__init__(namespace, **kwargs)
        self.headers = headers
        if fig_args is None:
            fig_args = {}
        self.fig_args = fig_args

    def _render_internal(self, state: AnalysisState, render_key: str, header: str, chart_args: Dict[str, Any]) -> None:
        """
        Render one annotated heatmap per dataset stored in ``state[render_key]``.

        Parameters
        ----------
        state: AnalysisState
            analysis state; ``state[render_key]`` maps dataset name -> matrix to plot. The header
            additionally reads ``correlations_method``, ``correlations_focus_field`` and
            ``correlations_focus_field_threshold``.
        render_key: str
            state key holding the per-dataset matrices
        header: str
            header text describing the matrix, e.g. ``"correlation matrix"``
        chart_args: Dict[str, Any]
            extra keyword arguments for ``seaborn.heatmap`` (color limits, colormap, ...)
        """
        for ds, corr in state[render_key].items():
            # Don't render single cell. Size by the matrix actually being drawn rather than by
            # `state.correlations[ds]`, which need not be present for every `render_key`.
            cells_num = len(corr)
            if cells_num <= 1:
                continue
            fig_args = self.fig_args.copy()
            # Default: one inch per cell in each direction.
            fig_args.setdefault("figsize", (cells_num, cells_num))
            if state.correlations_focus_field is not None:
                focus_field_header = (
                    f"; focus: absolute correlation for `{state.correlations_focus_field}` "
                    f">= `{state.correlations_focus_field_threshold}`"
                )
            else:
                focus_field_header = ""
            self.render_header_if_needed(state, f"`{ds}` - `{state.correlations_method}` {header}{focus_field_header}")
            _, ax = plt.subplots(**fig_args)
            sns.heatmap(
                corr,
                annot=True,
                ax=ax,
                linewidths=0.5,
                linecolor="lightgrey",
                fmt=".2f",
                square=True,
                cbar_kws={"shrink": 0.5},
                **chart_args,
            )
            plt.yticks(rotation=0)
            # `pyplot.show` takes no figure argument (only the `block` keyword); passing the figure
            # positionally fails on backends whose `show` is keyword-only.
            plt.show()
class CorrelationVisualization(_AbstractCorrelationChart):
    """
    Display feature correlations matrix.

    This report renders correlations between variables in a form of heatmap.
    The details of the report to be rendered depend on the configuration of
    :py:class:`~autogluon.eda.analysis.interaction.Correlation`

    Parameters
    ----------
    headers: bool, default = False
        if `True` then render headers
    namespace: Optional[str], default = None
        namespace to use; can be nested like `ns_a.ns_b.ns_c`
    fig_args: Optional[Dict[str, Any]], default = None
        kwargs to pass into chart figure

    See Also
    --------
    :py:class:`~autogluon.eda.analysis.interaction.Correlation`
    """

    def can_handle(self, state: AnalysisState) -> bool:
        return "correlations" in state

    def _render(self, state: AnalysisState) -> None:
        # `phik` values are drawn on a [0, 1] color scale; every other method on [-1, 1].
        args = {"vmin": 0 if state.correlations_method == "phik" else -1, "vmax": 1, "center": 0, "cmap": "Spectral"}
        self._render_internal(state, "correlations", "correlation matrix", args)
class _AbstractFeatureInteractionPlotRenderer(ABC):
    """
    Base class for the chart renderers used by ``FeatureInteractionVisualization``.

    Subclasses implement :meth:`_render` to draw onto a provided axis; :meth:`render`
    owns figure creation and display.
    """

    @abstractmethod
    def _render(self, state, ds, params, param_types, ax, data, chart_args):
        """
        Draw the chart onto ``ax``.

        ``params`` is ``(x, y, hue)`` column names, ``param_types`` their feature types
        (``"numeric"``/``"category"``/``"datetime"``/``None``), ``data`` the data to plot and
        ``chart_args`` the seaborn keyword arguments.
        """
        raise NotImplementedError  # pragma: no cover

    def render(self, state, ds, params, param_types, data, fig_args, chart_args):
        """Create a figure via ``plt.subplots(**fig_args)``, draw the chart on it and show it."""
        _, ax = plt.subplots(**fig_args)
        self._render(state, ds, params, param_types, ax, data, chart_args)
        # `pyplot.show` takes no figure argument (only the `block` keyword); passing the figure
        # positionally fails on backends whose `show` is keyword-only.
        plt.show()
class CorrelationSignificanceVisualization(_AbstractCorrelationChart):
    """
    Display feature correlations significance matrix.

    This report renders correlations significance matrix in a form of heatmap.
    The details of the report to be rendered depend on the configuration of
    :py:class:`~autogluon.eda.analysis.interaction.Correlation` and
    :py:class:`~autogluon.eda.analysis.interaction.CorrelationSignificance` analyses.

    Parameters
    ----------
    headers: bool, default = False
        if `True` then render headers
    namespace: Optional[str], default = None
        namespace to use; can be nested like `ns_a.ns_b.ns_c`
    fig_args: Optional[Dict[str, Any]], default = None
        kwargs to pass into chart figure

    See Also
    --------
    :py:class:`~autogluon.eda.analysis.interaction.Correlation`
    :py:class:`~autogluon.eda.analysis.interaction.CorrelationSignificance`
    """

    def can_handle(self, state: AnalysisState) -> bool:
        return "significance_matrix" in state

    def _render(self, state: AnalysisState) -> None:
        # `vmax`/`center` are fixed; with `robust=True` seaborn derives the missing `vmin`
        # from robust quantiles of the data instead of its minimum.
        args = {"center": 3, "vmax": 5, "cmap": "Spectral", "robust": True}
        self._render_internal(state, "significance_matrix", "correlation significance matrix", args)
class FeatureDistanceAnalysisVisualization(AbstractVisualization, JupyterMixin):
    """
    Render the feature-distance dendrogram and list near-duplicate feature groups.

    Reads ``state.feature_distance``: ``linkage`` and ``columns`` drive the dendrogram; when
    ``near_duplicates`` is non-empty, each group's ``nodes`` and ``distance`` are listed in a
    markdown note together with ``near_duplicates_threshold``.

    Parameters
    ----------
    headers: bool, default = False
        if `True` then render headers
    namespace: Optional[str], default = None
        namespace to use; can be nested like `ns_a.ns_b.ns_c`
    fig_args: Optional[Dict[str, Any]], default = None
        kwargs to pass into chart figure
    kwargs
        forwarded to the base class constructor. ``self._kwargs`` (presumably these kwargs) is
        passed to ``scipy.cluster.hierarchy.dendrogram`` and overrides ``orientation="left"``.
    """

    def __init__(
        self,
        headers: bool = False,
        namespace: Optional[str] = None,
        fig_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        super().__init__(namespace, **kwargs)
        self.headers = headers
        if fig_args is None:
            fig_args = {}
        self.fig_args = fig_args

    def can_handle(self, state: AnalysisState) -> bool:
        return self.all_keys_must_be_present(state, "feature_distance")

    def _render(self, state: AnalysisState) -> None:
        feature_distance = state.feature_distance

        fig_args = self.fig_args.copy()
        # Height grows with the number of leaves: four feature labels per inch.
        fig_args.setdefault("figsize", (12, len(feature_distance.columns) / 4))
        _, ax = plt.subplots(**fig_args)
        ax.grid(False)
        dendrogram_args = {"orientation": "left", **self._kwargs}
        hc.dendrogram(
            ax=ax,
            Z=feature_distance.linkage,
            labels=feature_distance.columns,
            leaf_font_size=10,
            **dendrogram_args,
        )
        # `pyplot.show` takes no figure argument (only the `block` keyword); passing the figure
        # positionally fails on backends whose `show` is keyword-only.
        plt.show()

        near_duplicates = feature_distance.near_duplicates
        if len(near_duplicates) > 0:
            message = (
                f"**The following feature groups are considered as near-duplicates**:\n\n"
                f"Distance threshold: <= `{feature_distance.near_duplicates_threshold}`. "
                f"Consider keeping only some of the columns within each group:\n"
            )
            message += "".join(
                f'\n - `{"`, `".join(sorted(group["nodes"]))}` - distance `{group["distance"]:.2f}`'
                for group in near_duplicates
            )
            self.render_markdown(message)
class FeatureInteractionVisualization(AbstractVisualization, JupyterMixin):
    """
    Feature interaction visualization.

    This report renders feature interaction analysis results.
    The details of the report to be rendered depend on the variable types combination in `x`/`y`/`hue`.
    `key` is used to link analysis and visualization - this allows to have multiple analyses/visualizations in one composite analysis.

    Parameters
    ----------
    key: str
        key used to store the analysis in the state; the value is placed in the state by FeatureInteraction.
        If the key is not provided to the analysis, then use one of the form: 'x:A|y:B|hue:C'
        (omit corresponding x/y/hue if the value not provided)
        See also :class:`autogluon.eda.analysis.interaction.FeatureInteraction`
    numeric_as_categorical_threshold: int, default = 20
        a feature with this many or fewer unique values is treated as categorical regardless of its raw type
    max_categories_to_consider_render: int, default = 30
        an interaction involving a categorical feature with more unique values than this is not rendered;
        a markdown note is rendered instead
    headers: bool, default = False
        if `True` then render headers
    namespace: Optional[str], default = None
        namespace to use; can be nested like `ns_a.ns_b.ns_c`
    fig_args: Optional[Dict[str, Any]], default = None
        kwargs to pass into chart figure
    kwargs
        parameters to pass as a chart args

    See Also
    --------
    :py:class:`~autogluon.eda.analysis.interaction.FeatureInteraction`
    """

    def __init__(
        self,
        key: str,
        numeric_as_categorical_threshold: int = 20,
        max_categories_to_consider_render: int = 30,
        headers: bool = False,
        namespace: Optional[str] = None,
        fig_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        super().__init__(namespace, **kwargs)
        self.key = key
        self.headers = headers
        self.numeric_as_categorical_threshold = numeric_as_categorical_threshold
        self.max_categories_to_consider_render = max_categories_to_consider_render
        if fig_args is None:
            fig_args = {}
        self.fig_args = fig_args

    def can_handle(self, state: AnalysisState) -> bool:
        return self.all_keys_must_be_present(state, "interactions", "raw_type")

    def _render(self, state: AnalysisState) -> None:
        for ds in state.interactions.keys():
            if self.key not in state.interactions[ds]:
                continue
            # Each dataset is handled independently: a dataset whose interaction cannot be
            # rendered must not prevent the remaining datasets from being rendered.
            self._render_interaction_for_dataset(state, ds)

    def _render_interaction_for_dataset(self, state: AnalysisState, ds: str) -> None:
        """
        Render the interaction stored under ``self.key`` for dataset ``ds``.

        Returns without drawing when a categorical feature has too many categories (a markdown
        note is rendered instead) or when no chart is defined for the x/y/hue type combination.
        """
        interaction = state.interactions[ds][self.key]
        interaction_features = interaction["features"]
        # Work on a copy: categorical columns are converted in place further down.
        df = interaction["data"].copy()
        x, x_type = self._get_value_and_type(ds, df, state, interaction_features, "x")
        y, y_type = self._get_value_and_type(ds, df, state, interaction_features, "y")
        hue, hue_type = self._get_value_and_type(ds, df, state, interaction_features, "hue")

        features = "/".join([f"`{interaction_features[k]}`" for k in ["x", "y", "hue"] if k in interaction_features])

        # Don't render high-cardinality category variables
        for f, t in [(x, x_type), (y, y_type), (hue, hue_type)]:
            if t != "category":
                continue
            num_categories = df[f].nunique()
            if num_categories > self.max_categories_to_consider_render:
                self.render_markdown(
                    f"Interaction {features} in `{ds}` is not rendered due to `{f}` "
                    f"having too many categories (`{num_categories}` > `{self.max_categories_to_consider_render}`) "
                    f"for comfortable read."
                )
                return

        y, y_type, hue, hue_type = self._swap_y_and_hue_if_necessary(x_type, y, y_type, hue, hue_type)
        renderer_cls: Optional[Type[_AbstractFeatureInteractionPlotRenderer]] = self._get_chart_renderer(
            x_type, y_type, hue_type
        )
        if renderer_cls is None:
            # No chart is defined for this x/y/hue type combination.
            return
        renderer: _AbstractFeatureInteractionPlotRenderer = renderer_cls()  # Create instance

        df = self._convert_categoricals_to_objects(df, x, x_type, y, y_type, hue, hue_type)
        chart_args, data, is_single_var = self._prepare_chart_args(df, x, x_type, y, y_type, hue)
        if self.headers:
            prefix = "" if is_single_var else "Feature interaction between "
            self.render_header_if_needed(state, f"{prefix}{features} in `{ds}`")

        fig_args = self.fig_args.copy()
        fig_args.setdefault("figsize", (12, 6))
        renderer.render(
            state=state,
            ds=ds,
            params=(x, y, hue),
            param_types=(x_type, y_type, hue_type),
            data=data,
            fig_args=fig_args,
            chart_args=chart_args,
        )

    def _prepare_chart_args(self, df, x, x_type, y, y_type, hue) -> Tuple[Dict[str, Any], pd.DataFrame, bool]:
        """
        Build the seaborn keyword arguments and the data to plot.

        Returns ``(chart_args, data, is_single_var)``. ``chart_args`` holds the non-``None`` of
        x/y/hue merged with ``self._kwargs``. When only one of x/y is set and it is numeric, ``data``
        is that single column (a Series, not a DataFrame) and the matching x/y argument is dropped.
        """
        chart_args = {"x": x, "y": y, "hue": hue, **self._kwargs}
        chart_args = {k: v for k, v in chart_args.items() if v is not None}
        data = df
        is_single_var = False
        if x is not None and y is None and hue is None:
            is_single_var = True
            if x_type == "numeric":
                data = df[x]
                chart_args.pop("x")
        elif y is not None and x is None and hue is None:
            is_single_var = True
            if y_type == "numeric":
                data = df[y]
                chart_args.pop("y")
        return chart_args, data, is_single_var

    def _convert_categoricals_to_objects(self, df, x, x_type, y, y_type, hue, hue_type):
        """Cast every column typed ``"category"`` to ``object`` in place and return ``df``."""
        # convert to categoricals for plots
        for col, typ in zip([x, y, hue], [x_type, y_type, hue_type]):
            if typ == "category":
                df[col] = df[col].astype("object")
        return df

    def _swap_y_and_hue_if_necessary(self, x_type, y, y_type, hue, hue_type):
        """Move a categorical ``y`` into ``hue`` when ``x`` is set and ``hue`` is not; return ``(y, y_type, hue, hue_type)``."""
        # swap y <-> hue when category vs category is provided and no hue is specified
        if (x_type is not None) and y_type == "category" and hue_type is None:
            hue, hue_type = y, y_type
            y, y_type = None, None
        return y, y_type, hue, hue_type

    def _get_value_and_type(
        self, ds: str, df: pd.DataFrame, state: AnalysisState, interaction_features: Dict[str, Any], param: str
    ) -> Tuple[Any, Optional[str]]:
        """Return ``(column, feature_type)`` for ``param`` (``"x"``/``"y"``/``"hue"``); both are ``None`` if unset."""
        col = interaction_features.get(param, None)
        value_type = self._map_raw_type_to_feature_type(
            col, state.raw_type[ds].get(col, None), df, self.numeric_as_categorical_threshold
        )
        return col, value_type

    def _get_chart_renderer(
        self, x_type: Optional[str], y_type: Optional[str], hue_type: Optional[str]
    ) -> Optional[Type[_AbstractFeatureInteractionPlotRenderer]]:
        """Return the renderer class for an ``(x_type, y_type, hue_type)`` combination, or ``None`` if unsupported."""
        # NOTE: after `_swap_y_and_hue_if_necessary`, ("category", "category", None) and
        # ("numeric", "category", None) are not produced by `_render`; they stay for direct callers.
        types = {
            ("numeric", None, None): self._HistPlotRenderer,
            ("category", None, None): self._CountPlotRenderer,
            (None, "category", None): self._CountPlotRenderer,
            ("category", None, "category"): self._CountPlotRenderer,
            (None, "category", "category"): self._CountPlotRenderer,
            ("numeric", None, "category"): self._HistPlotRenderer,
            (None, "numeric", "category"): self._HistPlotRenderer,
            ("category", "category", None): self._BarPlotRenderer,
            ("category", "category", "category"): self._BarPlotRenderer,
            ("category", "numeric", None): self._BoxPlotRenderer,
            ("numeric", "category", None): self._KdePlotRenderer,
            ("category", "numeric", "category"): self._BoxPlotRenderer,
            ("numeric", "category", "category"): self._KdePlotRenderer,
            ("numeric", "numeric", None): self._RegPlotRenderer,
            ("numeric", "numeric", "category"): self._ScatterPlotRenderer,
            ("datetime", "numeric", None): self._LinePlotRenderer,
        }
        return types.get((x_type, y_type, hue_type), None)

    def _map_raw_type_to_feature_type(
        self,
        col: Optional[str],
        raw_type: Optional[str],
        df: pd.DataFrame,
        numeric_as_categorical_threshold: int = 20,
    ) -> Optional[str]:
        """
        Map a raw feature type to ``"numeric"``, ``"category"``, ``"datetime"`` or ``None``.

        Low-cardinality columns (``nunique <= numeric_as_categorical_threshold``) are
        ``"category"`` regardless of ``raw_type``; unknown raw types map to ``None``.
        """
        if col is None:
            return None
        elif df[col].nunique() <= numeric_as_categorical_threshold:
            return "category"
        elif raw_type in [R_INT, R_FLOAT]:
            return "numeric"
        elif raw_type in [R_DATETIME]:
            return "datetime"
        elif raw_type in [R_OBJECT, R_CATEGORY, R_BOOL]:
            return "category"
        else:
            return None

    class _HistPlotRenderer(_AbstractFeatureInteractionPlotRenderer):
        """Histogram; for a single numeric ``x`` with fitted distributions, overlays their PDFs."""

        def _render(self, state, ds, params, param_types, ax, data, chart_args, num_point_to_fit=200):
            x = params[0]
            fitted_distributions_present = (
                ("distributions_fit" in state)
                and (param_types == ("numeric", None, None))  # (x, y, hue)
                and (state.distributions_fit[ds].get(x, None) is not None)
            )
            if fitted_distributions_present:
                # Density scale so the histogram is comparable with the overlaid PDFs.
                chart_args["stat"] = "density"
            sns.histplot(ax=ax, data=data, **chart_args)
            if fitted_distributions_present:
                dists = state.distributions_fit[ds][x]
                x_min, x_max = ax.get_xlim()
                xs = np.linspace(x_min, x_max, num_point_to_fit)
                for dist, v in dists.items():
                    # `dist` names a scipy.stats distribution; `v["param"]` are its fitted parameters.
                    _dist = getattr(stats, dist)
                    ax.plot(
                        xs,
                        _dist.pdf(xs, *v["param"]),
                        ls="--",
                        label=f'{dist}: pvalue {v["pvalue"]:.2f}',
                    )
                ax.set_xlim(x_min, x_max)  # set the limits back to the ones of the distplot
                ax.legend()

    class _KdePlotRenderer(_AbstractFeatureInteractionPlotRenderer):
        """Kernel density estimate plot; any ``fill`` argument is discarded."""

        def _render(self, state, ds, params, param_types, ax, data, chart_args):
            chart_args.pop("fill", None)
            chart = sns.kdeplot(ax=ax, data=data, **chart_args)
            plt.setp(chart.get_xticklabels(), rotation=90)

    class _BoxPlotRenderer(_AbstractFeatureInteractionPlotRenderer):
        """Box plot with vertical x tick labels."""

        def _render(self, state, ds, params, param_types, ax, data, chart_args):
            chart = sns.boxplot(ax=ax, data=data, **chart_args)
            plt.setp(chart.get_xticklabels(), rotation=90)

    class _CountPlotRenderer(_AbstractFeatureInteractionPlotRenderer):
        """Count plot with each bar annotated by its count."""

        def _render(self, state, ds, params, param_types, ax, data, chart_args):
            chart = sns.countplot(ax=ax, data=data, **chart_args)
            plt.setp(chart.get_xticklabels(), rotation=90)
            for container in ax.containers:
                ax.bar_label(container)

    class _BarPlotRenderer(_AbstractFeatureInteractionPlotRenderer):
        """Bar plot without error bars."""

        def _render(self, state, ds, params, param_types, ax, data, chart_args):
            chart_args["errorbar"] = None  # Don't show ci ticks
            chart = sns.barplot(ax=ax, data=data, **chart_args)
            plt.setp(chart.get_xticklabels(), rotation=90)

    class _ScatterPlotRenderer(_AbstractFeatureInteractionPlotRenderer):
        """Scatter plot."""

        def _render(self, state, ds, params, param_types, ax, data, chart_args):
            sns.scatterplot(ax=ax, data=data, **chart_args)

    class _RegPlotRenderer(_AbstractFeatureInteractionPlotRenderer):
        """Scatter plot with a fitted regression line."""

        def _render(self, state, ds, params, param_types, ax, data, chart_args):
            sns.regplot(ax=ax, data=data, **chart_args)

    class _LinePlotRenderer(_AbstractFeatureInteractionPlotRenderer):
        """Line plot."""

        def _render(self, state, ds, params, param_types, ax, data, chart_args):
            sns.lineplot(ax=ax, data=data, **chart_args)