"""
Smart explainer module
"""
import copy
import logging
import shutil
import tempfile
import numpy as np
import pandas as pd
import shapash.explainer.smart_predictor
from shapash.backend import BaseBackend, get_backend_cls_from_name
from shapash.backend.shap_backend import get_shap_interaction_values
from shapash.manipulation.select_lines import keep_right_contributions
from shapash.manipulation.summarize import create_grouped_features_values
from shapash.report import check_report_requirements
from shapash.style.style_utils import colors_loading, select_palette
from shapash.utils.check import (
check_additional_data,
check_columns_order,
check_features_name,
check_label_dict,
check_model,
check_postprocessing,
check_y,
)
from shapash.utils.custom_thread import CustomThread
from shapash.utils.explanation_metrics import find_neighbors, get_distance, get_min_nb_features, shap_neighbors
from shapash.utils.io import load_pickle, save_pickle
from shapash.utils.model import predict, predict_error, predict_proba
from shapash.utils.transform import apply_postprocessing, handle_categorical_missing, inverse_transform
from shapash.utils.utils import get_host_name
from shapash.webapp.smart_app import SmartApp
from .smart_plotter import SmartPlotter
# Configure the root logger at INFO level as a side effect of importing this module.
# NOTE(review): per the stdlib docs, basicConfig does nothing if the root logger already
# has handlers; otherwise it installs a handler and INFO level for the whole host
# application. Library code usually leaves logging configuration to the application.
logging.basicConfig(level=logging.INFO)
[docs]class SmartExplainer:
"""
The main class of the Shapash library, designed to make machine learning model
results more interpretable and understandable.
`SmartExplainer` links together the model, encoders, datasets, predictions,
and label dictionaries. It provides a variety of methods to visualize and
analyze model explanations both in notebooks and in the Shapash WebApp.
Parameters
----------
model : object
The model to be explained. Used for consistency checks and, in some cases,
to compute `predict` and `predict_proba` values.
backend : str or shapash.backend.BaseBackend, default='shap'
Defines the backend used to compute feature contributions and importances.
Options:
- `'shap'`: use SHAP as backend.
- `'lime'`: use LIME as backend.
You can also pass an instance of a custom backend class that inherits from
`shapash.backend.BaseBackend`.
preprocessing : category_encoders, ColumnTransformer, list, dict, optional (default: None)
Different types of preprocessing are available:
- A single category_encoders (OrdinalEncoder/OnehotEncoder/BaseNEncoder/BinaryEncoder/TargetEncoder)
- A single ColumnTransformer with scikit-learn encoding or category_encoders transformers
- A list with multiple category_encoders with optional (dict, list of dict)
- A list with a single ColumnTransformer with optional (dict, list of dict)
- A dict
- A list of dict
postprocessing : dict, optional (default: None)
Dictionary of postprocessing modifications to apply to the x_init dataframe.
Keys are feature names (or column numbers, or feature labels referring to feature names),
and each value describes how that feature is modified.
Different types of postprocessing are available, but the syntax is always the same:
one key per feature, with 5 possible types of modification:
>>> {
'feature1' : { 'type' : 'prefix', 'rule' : 'age: ' },
'feature2' : { 'type' : 'suffix', 'rule' : '$/week ' },
'feature3' : { 'type' : 'transcoding', 'rule': { 'code1' : 'single', 'code2' : 'married'}},
'feature4' : { 'type' : 'regex' , 'rule': { 'in' : 'AND', 'out' : ' & ' }},
'feature5' : { 'type' : 'case' , 'rule': 'lower' }
}
Only one transformation per feature is possible.
features_groups : dict, optional (default: None)
Dictionary containing features that should be grouped together. This option allows
computing and displaying the contributions and importance of each group of features.
Each key defines a group name, and its value is a list of feature names.
Features that are grouped together will still be displayed in the webapp when clicking
on a group.
Example:
>>> {
... 'feature_group_1': ['feature3', 'feature7', 'feature24'],
... 'feature_group_2': ['feature1', 'feature12']
... }
features_dict : dict, optional
Mapping from technical feature names to domain-specific (readable) names.
label_dict : dict, optional
Mapping from numeric labels to human-readable class names (for classification tasks).
title_story : str, optional
Custom title used in visualizations and reports. Default is empty.
palette_name : str, optional
Name of the color palette used for visualizations (see the `style` folder for options).
colors_dict : dict, optional
Dictionary containing the full color palette configuration.
Can be used to override default plot colors.
**backend_kwargs : dict
Additional keyword arguments passed to the backend.
Attributes
----------
data: dict
Data dictionary has 3 entries. Each key returns a pd.DataFrame (regression) or a list of pd.DataFrame
(classification - the length of the lists is equal to the number of labels).
All pd.DataFrame have the same shape (n_samples, n_features).
In the regression case, the data can be regarded as a single array
of shape (n_samples, n_features, 3).
data['contrib_sorted']: pandas.DataFrame (regression) or list of pandas.DataFrame (classification)
Contains local contributions of the prediction set, with common line index.
Columns are 'contrib_1', 'contrib_2', ... and contains the top contributions
for each line from left to right. In multi-class problems, this is a list of
contributions, one for each class.
data['var_dict']: pandas.DataFrame (regression) or list of pandas.DataFrame (classification)
Must contain only ints. It gives, for each line, the list of most important features
regarding the local decomposition. In order to save space, columns are denoted by
integers, the conversion being done with the columns_dict member. In multi-class
problems, this is a list of dataframes, one for each class.
data['x_sorted']: pandas.DataFrame (regression) or list of pandas.DataFrame (classification)
It gives, for each line, the list of most important features values regarding the local
decomposition. These values can only be understood with respect to data['var_dict']
backend_name : str
Name of the backend if specified as a string.
x_encoded : pandas.DataFrame
Preprocessed dataset used by the model.
x_init : pandas.DataFrame
Dataset obtained by inverting the preprocessing of `x_encoded`, with postprocessing applied if defined.
x_contrib_plot : pandas.DataFrame
Inverse-transformed dataset without postprocessing (used for plots).
y_pred : pandas.DataFrame
Model predictions.
contributions : pandas.DataFrame or list
Local feature contributions. Aggregated if preprocessing expands features
(e.g., one-hot encoding).
features_dict : dict
Mapping from technical feature names to domain names.
inv_features_dict : dict
Reverse mapping of `features_dict`.
label_dict : dict
Mapping from numeric labels to class names.
inv_label_dict : dict
Reverse mapping of `label_dict`.
columns_dict : dict
Mapping from feature index to technical feature name.
plot : SmartPlotter
Object providing access to all plotting functions.
model : object
The model being explained.
features_desc : dict
Number of unique values per feature in `x_init`.
features_imp : pandas.Series or list
Computed feature importance values.
local_neighbors : dict
Data displayed in local neighbor plots (normalized SHAP values, etc.).
features_stability : dict
Data used for stability plots, including:
- `'amplitude'`: average contribution values for selected instances.
- `'stability'`: metric assessing stability across neighborhoods.
preprocessing : category_encoders object, ColumnTransformer, list, or dict
Preprocessing transformations applied to raw input data.
postprocessing : dict
Postprocessing rules applied after inverse preprocessing.
y_target : pandas.Series or pandas.DataFrame, optional
True target values.
Example
-------
>>> xpl = SmartExplainer(model, features_dict=featd, label_dict=labeld)
>>> xpl.compile(x=x_encoded, y_target=y)
>>> xpl.plot.features_importance()
"""
def __init__(
    self,
    model,
    backend="shap",
    preprocessing=None,
    postprocessing=None,
    features_groups=None,
    features_dict=None,
    label_dict=None,
    title_story: str = None,
    palette_name=None,
    colors_dict=None,
    **backend_kwargs,
):
    """
    Store the model, backend and display options of the explainer.

    Parameters are described in the class docstring. The heavy work
    (contributions, inverse preprocessing, ...) happens in ``compile()``.

    Raises
    ------
    ValueError
        If ``features_dict`` or ``label_dict`` is given but is not a dict.
    NotImplementedError
        If ``backend`` is neither a backend name (str) nor a BaseBackend instance.
    """
    # Validate the user-supplied mappings, features_dict first.
    for arg_name, arg_value in (("features_dict", features_dict), ("label_dict", label_dict)):
        if arg_value is not None and not isinstance(arg_value, dict):
            raise ValueError(f"\n{arg_name} must be a dict\n")

    self.model = model
    self.preprocessing = preprocessing
    self.backend_name = None
    if isinstance(backend, str):
        # A named backend is instantiated in compile(), once x is known.
        self.backend_name = backend
    elif isinstance(backend, BaseBackend):
        self.backend = backend
        # Hand our preprocessing to a backend instance that has none of its own.
        if backend.preprocessing is None and self.preprocessing is not None:
            self.backend.preprocessing = self.preprocessing
    else:
        raise NotImplementedError(f"Unknown backend : {backend}")
    self.backend_kwargs = backend_kwargs

    # Copied because features_dict is later edited in place (check_features_dict).
    self.features_dict = copy.deepcopy(features_dict) if features_dict is not None else {}
    self.label_dict = label_dict
    self.title_story = "" if title_story is None else title_story

    self.palette_name = palette_name or "default"
    palette = select_palette(colors_loading(), self.palette_name)
    self.colors_dict = copy.deepcopy(palette)
    if colors_dict is not None:
        self.colors_dict.update(colors_dict)
    self.plot = SmartPlotter(self, self.colors_dict)

    self._case, self._classes = check_model(self.model)
    self.postprocessing = postprocessing
    self.check_label_dict()
    if self.label_dict:
        self.inv_label_dict = {name: code for code, name in self.label_dict.items()}
    self.features_groups = features_groups

    # Placeholders, filled in by compile() and later computations.
    self.local_neighbors = None
    self.features_stability = None
    self.features_compacity = None
    self.contributions = None
    self.explain_data = None
    self.features_imp = None
def compile(
    self,
    x,
    contributions=None,
    y_pred=None,
    proba_values=None,
    y_target=None,
    columns_order=None,
    additional_data=None,
    additional_features_dict=None,
):
    """
    Prepare and structure all data needed for interpreting the model and its predictions.

    The `compile` method is the first essential step to make your model explainable
    with Shapash. It organizes the model's inputs, outputs, and contributions into
    a consistent format, applies inverse preprocessing, and computes all elements
    required for visualization and summaries.
    Depending on dataset size and backend, this step may take some time.

    Parameters
    ----------
    x : pandas.DataFrame
        Prediction dataset as seen by the model, i.e. **after** preprocessing
        (encoded). Shapash applies the inverse of `preprocessing` to it to build
        `x_init`, the dataset shown to the end user.
    contributions : pandas.DataFrame, numpy.ndarray, or list, optional
        Local feature contributions for each sample. If None, they are computed
        by the backend.
        - If a `DataFrame`, its index and columns must match those of `x`.
        - If a `numpy.ndarray`, Shapash will automatically generate the corresponding
          index and column names based on `x`.
        - In multi-class settings, provide a list of contributions (one per class).
    y_pred : pandas.Series or pandas.DataFrame, optional
        Model predictions. Must have the same index as `x_init`.
        If None and the model has a `predict` method, predictions are computed.
        Useful for customizing predicted values, for example when applying
        a custom threshold in classification tasks.
    proba_values : pandas.Series or pandas.DataFrame, optional
        Prediction probabilities. Must have the same index as `x_init`.
        If None, for classification models exposing `predict_proba`,
        probabilities are computed.
    y_target : pandas.Series or pandas.DataFrame, optional
        True target values, used to compute prediction errors.
        Must have the same index as `x_init`.
    columns_order : list or str, optional
        Defines the display order of columns in the dataset.
        - If a **list** is provided, it specifies the exact order of columns.
          Every column of `x` must appear in it (a ValueError is raised otherwise);
          columns of `additional_data` may be listed by their original name.
        - If set to `'additional_data_first'`, all additional columns are placed first.
        - If set to `'additional_data_last'`, all additional columns are placed last.
        This option helps control column order in the Shapash WebApp.
    additional_data : pandas.DataFrame, optional
        Additional features not used by the model but relevant for visualization or
        filtering in the WebApp. Must have the same index as `x_init`.
        Its columns are stored with a `_` prefix.
    additional_features_dict : dict, optional
        Mapping of additional feature names (technical names) to user-friendly
        domain names, used to improve readability in plots and dashboards.

    Example
    -------
    >>> xpl.compile(x=x_test)
    >>> xpl.plot.features_importance()
    """
    # A backend given by name is instantiated here, with x passed as masker.
    if isinstance(self.backend_name, str):
        backend_cls = get_backend_cls_from_name(self.backend_name)
        self.backend = backend_cls(
            model=self.model, preprocessing=self.preprocessing, masker=x, **self.backend_kwargs
        )
    self.x_encoded = handle_categorical_missing(x)
    x_init = inverse_transform(self.x_encoded, self.preprocessing)
    self.x_init = handle_categorical_missing(x_init)

    # Predictions and probabilities are computed from the model when not supplied.
    self.y_pred = check_y(self.x_init, y_pred, y_name="y_pred")
    if (self.y_pred is None) and (hasattr(self.model, "predict")):
        self.predict()
    self.proba_values = check_y(self.x_init, proba_values, y_name="proba_values")
    if (self._case == "classification") and (self.proba_values is None) and (hasattr(self.model, "predict_proba")):
        self.predict_proba()
    self.y_target = check_y(self.x_init, y_target, y_name="y_target")
    self.prediction_error = predict_error(
        self.y_target, self.y_pred, self._case, proba_values=self.proba_values, classes=self._classes
    )

    self._get_contributions_from_backend_or_user(x, contributions)
    self.check_contributions()

    # columns_dict maps column position -> technical column name of x_init.
    self.columns_dict = {i: col for i, col in enumerate(self.x_init.columns)}
    self.check_features_dict()
    self.inv_features_dict = {v: k for k, v in self.features_dict.items()}
    self._apply_all_postprocessing_modifications()

    # Rank each row's contributions and store them in self.data.
    self.data = self.state.assign_contributions(self.state.rank_contributions(self.contributions, self.x_init))
    self.features_desc = dict(self.x_init.nunique())
    if self.features_groups is not None:
        self._compile_features_groups(self.features_groups)

    self.additional_features_dict = (
        dict()
        if additional_features_dict is None
        else self._compile_additional_features_dict(additional_features_dict)
    )
    self.additional_data = self._compile_additional_data(additional_data)
    self.columns_order = self._compile_columns_order(columns_order)
    self.plot._tuning_round_digit()
def _get_contributions_from_backend_or_user(self, x, contributions):
    """
    Set ``explain_data``, ``contributions`` and ``state`` using the backend.

    When ``contributions`` is None the backend computes them from ``x``;
    otherwise the user-provided values are kept as ``explain_data`` and the
    backend only formats and aggregates them.
    """
    backend = self.backend
    if contributions is not None:
        # User-provided contributions: only format / aggregate them.
        self.explain_data = contributions
        self.contributions = backend.format_and_aggregate_local_contributions(
            x=x,
            contributions=contributions,
        )
    else:
        # No contributions given: let the backend compute them.
        self.explain_data = backend.run_explainer(x=x)
        self.contributions = backend.get_local_contributions(x=x, explain_data=self.explain_data)
    self.state = backend.state
def _apply_all_postprocessing_modifications(self):
    """
    Normalize ``self.postprocessing``, validate it and apply it to ``x_init``.

    If a rule turns numeric columns into strings, an untouched copy of
    ``x_init`` is kept in ``x_contrib_plot`` before postprocessing.
    """
    normalized = self.modify_postprocessing(self.postprocessing)
    check_postprocessing(self.x_init, normalized)
    self.postprocessing_modifications = self.check_postprocessing_modif_strings(normalized)
    self.postprocessing = normalized
    if self.postprocessing_modifications:
        # Keep numeric values available for plots once they become strings.
        self.x_contrib_plot = copy.deepcopy(self.x_init)
    self.x_init = self.apply_postprocessing(normalized)
def _compile_features_groups(self, features_groups):
    """
    Perform the computations required for groups of features.

    Parameters
    ----------
    features_groups : dict
        Mapping from group name to the list of feature names it contains.
        Every step below uses this argument (not ``self.features_groups``).

    Raises
    ------
    AssertionError
        If the current backend does not support groups of features.
    """
    if self.backend.support_groups is False:
        raise AssertionError(f"Selected backend ({self.backend.name}) does not support groups of features.")
    # Aggregate the per-feature contributions of each group.
    self.contributions_groups = self.state.compute_grouped_contributions(self.contributions, features_groups)
    self.features_imp_groups = None
    # Register group names in features_dict / inv_features_dict / features_desc.
    self._update_features_dict_with_groups(features_groups=features_groups)
    # Build the values of each group of features (original comment refers to
    # t-SNE projections, computed inside create_grouped_features_values).
    self.x_init_groups = create_grouped_features_values(
        x_init=self.x_init,
        x_encoded=self.x_encoded,
        preprocessing=self.preprocessing,
        features_groups=features_groups,
        features_dict=self.features_dict,
        how="dict_of_values",
    )
    # Same ranked data structure as self.data, but for groups.
    self.data_groups = self.state.assign_contributions(
        self.state.rank_contributions(self.contributions_groups, self.x_init_groups)
    )
    self.columns_dict_groups = {i: col for i, col in enumerate(self.x_init_groups.columns)}
def _compile_additional_features_dict(self, additional_features_dict):
    """
    Validate ``additional_features_dict`` and prefix its keys and values with ``_``.

    The ``_`` prefix matches the prefix given to the columns of additional data.

    Returns
    -------
    dict
        New dict ``{"_<name>": "_<label>"}``.

    Raises
    ------
    ValueError
        If ``additional_features_dict`` is not a dict.
    """
    if not isinstance(additional_features_dict, dict):
        raise ValueError("\nadditional_features_dict must be a dict\n")
    prefixed = {}
    for technical_name, domain_name in additional_features_dict.items():
        prefixed[f"_{technical_name}"] = f"_{domain_name}"
    return prefixed
def _compile_additional_data(self, additional_data):
    """
    Validate additional data and register labels for its columns.

    Columns are returned with a ``_`` prefix. Every prefixed column that has no
    entry in ``additional_features_dict`` gets one mapping to itself.

    Returns
    -------
    pandas.DataFrame or None
        The prefixed additional data, or None when none is given.
    """
    if additional_data is None:
        return additional_data
    check_additional_data(self.x_init, additional_data)
    model_features = self.columns_dict.values()
    for feature in additional_data.columns:
        # A label given in features_dict for a non-model column moves to additional_features_dict.
        # NOTE(review): when called from compile(), check_features_dict() has already removed
        # keys that are not dataset columns, so this can only match group names — confirm intent.
        if feature in self.features_dict and feature not in model_features:
            self.additional_features_dict[f"_{feature}"] = f"_{self.features_dict[feature]}"
            del self.features_dict[feature]
    prefixed_data = additional_data.add_prefix("_")
    unlabeled = set(prefixed_data.columns) - set(self.additional_features_dict)
    for feature in unlabeled:
        self.additional_features_dict[feature] = feature
    return prefixed_data
def _compile_columns_order(self, columns_order):
    """
    Validate a user-supplied column display order.

    Parameters
    ----------
    columns_order : list, str or None
        Anything that is not a list (e.g. None or a string option) is
        returned unchanged.

    Returns
    -------
    list, str or None
        For a list: a new list in which a name gets the ``_`` prefix when that
        prefixed name is a key of ``additional_features_dict``.

    Raises
    ------
    ValueError
        If a column of ``x_encoded`` is absent from the list, or if the list names
        a column that is neither in ``x_encoded`` nor in ``additional_features_dict``
        (``check_columns_order`` may raise as well).
    """
    if not isinstance(columns_order, list):
        return columns_order
    check_columns_order(columns_order)
    additional_names = self.additional_features_dict
    normalized = []
    for col in columns_order:
        prefixed = f"_{col}"
        normalized.append(prefixed if prefixed in additional_names else col)
    # NOTE(review): validated against x_encoded columns; when preprocessing changes the
    # columns these differ from x_init columns — confirm which set is intended.
    expected = set(self.x_encoded.columns)
    requested = set(normalized)
    missing_cols = expected - requested
    if missing_cols:
        raise ValueError(f"The following columns are missing from columns_order: {missing_cols}")
    extra_cols = requested - expected - set(additional_names)
    if extra_cols:
        raise ValueError(
            f"The following columns in columns_order do not exist in x or additional data: {extra_cols}"
        )
    return normalized
def define_style(self, palette_name=None, colors_dict=None):
    """
    Set the color set to use in plots.

    Parameters
    ----------
    palette_name : str, optional
        Name of the palette to load. If None, the current ``palette_name`` is reused.
    colors_dict : dict, optional
        Colors overriding those of the selected palette.

    Raises
    ------
    ValueError
        If neither ``palette_name`` nor ``colors_dict`` is given.
    """
    if palette_name is None and colors_dict is None:
        raise ValueError("At least one of palette_name or colors_dict parameters must be defined")
    new_palette_name = palette_name or self.palette_name
    new_colors_dict = copy.deepcopy(select_palette(colors_loading(), new_palette_name))
    if colors_dict is not None:
        new_colors_dict.update(colors_dict)
    # Remember the palette so that a later call with only colors_dict builds on it
    # instead of falling back to the palette chosen at construction time.
    self.palette_name = new_palette_name
    self.colors_dict.update(new_colors_dict)
    self.plot.define_style_attributes(colors_dict=self.colors_dict)
def add(
    self,
    y_pred=None,
    proba_values=None,
    y_target=None,
    label_dict=None,
    features_dict=None,
    title_story: str = None,
    columns_order=None,
    additional_data=None,
    additional_features_dict=None,
):
    """
    Add or update metadata and outputs without recompiling the explainer.

    The `add` method lets users attach or modify supplementary information such as
    predictions, label or feature dictionaries, and display options **without**
    rerunning the full `compile()` process (which can be time-consuming for large datasets).
    It can be used to:
    - Add or update `y_pred` (used to color plots or export results).
    - Add or update `label_dict` and `features_dict` for clearer labels in visualizations.
    - Include additional data or adjust column display order in the WebApp.

    Parameters
    ----------
    y_pred : pandas.Series or pandas.DataFrame, optional
        Model predictions (one column only). Must have the same index as `x_init`.
        Used in plots (e.g., to color scatter plots) and in export methods like `to_pandas()`.
    proba_values : pandas.Series or pandas.DataFrame, optional
        Prediction probabilities (one column only). Must have the same index as `x_init`.
    y_target : pandas.Series or pandas.DataFrame, optional
        True target values (one column only). Must have the same index as `x_init`.
        Whenever a target is known, prediction errors are recomputed.
    label_dict : dict, optional
        Mapping of integer labels to domain names (for classification targets).
    features_dict : dict, optional
        Mapping of technical feature names to human-readable (domain) names.
        The dict is copied; the caller's object is never modified. Labels of
        feature groups may be included.
    title_story : str, optional
        Custom title for reports or visualizations.
    columns_order : list or str, optional
        Defines the display order of columns in the dataset.
        - If a **list** is provided, it specifies the exact order of columns.
          Every column of `x` must appear in it (a ValueError is raised otherwise);
          columns of the additional data may be listed by their original name.
        - If set to `'additional_data_first'`, additional columns appear first.
        - If set to `'additional_data_last'`, additional columns appear last.
    additional_data : pandas.DataFrame, optional
        Extra dataset containing features outside the model.
        Must have the same index as `x_init`.
    additional_features_dict : dict, optional
        Dictionary mapping technical feature names to human-readable names
        for columns in `additional_data`.

    Raises
    ------
    ValueError
        If `label_dict` or `features_dict` is given but is not a dict.

    Example
    -------
    >>> # Add predictions and friendly feature names after compiling
    >>> xpl.add(y_pred=preds, features_dict=feat_dict)
    >>> xpl.plot.local_plot(index=5)
    """
    if y_pred is not None:
        self.y_pred = check_y(self.x_init, y_pred, y_name="y_pred")
    if proba_values is not None:
        self.proba_values = check_y(self.x_init, proba_values, y_name="proba_values")
    if y_target is not None:
        self.y_target = check_y(self.x_init, y_target, y_name="y_target")
    # Refresh errors whenever a target is known: y_pred / proba_values may have changed.
    if hasattr(self, "y_target") and self.y_target is not None:
        self.prediction_error = predict_error(
            self.y_target, self.y_pred, self._case, proba_values=self.proba_values, classes=self._classes
        )
    if label_dict is not None:
        if not isinstance(label_dict, dict):
            raise ValueError("\nlabel_dict must be a dict\n")
        self.label_dict = label_dict
        self.check_label_dict()
        self.inv_label_dict = {v: k for k, v in self.label_dict.items()}
    if features_dict is not None:
        if not isinstance(features_dict, dict):
            raise ValueError("\nfeatures_dict must be a dict\n")
        # Work on a copy: check_features_dict edits features_dict in place and must
        # not alter the caller's dictionary (same policy as __init__).
        self.features_dict = copy.deepcopy(features_dict)
        self.check_features_dict()
        # check_features_dict drops keys that are not dataset columns, including the
        # group names registered at compile time; restore them (keeping the caller's
        # label when one is given) so they stay resolvable as after compile().
        if self.features_groups is not None:
            for group_name in self.features_groups:
                self.features_dict[group_name] = features_dict.get(group_name, group_name)
        self.inv_features_dict = {v: k for k, v in self.features_dict.items()}
    if title_story is not None:
        self.title_story = title_story
    if additional_features_dict is not None:
        self.additional_features_dict = self._compile_additional_features_dict(additional_features_dict)
    if additional_data is not None:
        self.additional_data = self._compile_additional_data(additional_data)
    if columns_order is not None:
        self.columns_order = self._compile_columns_order(columns_order)
def get_interaction_values(self, n_samples_max=None, selection=None):
    """
    Compute SHAP interaction values for the encoded dataset.

    This method calculates pairwise SHAP interaction effects between features
    for samples of `x_encoded`. It is only available when using a backend
    based on `TreeExplainer` (i.e., for tree-based models such as LightGBM,
    XGBoost, or CatBoost).
    For more details, see the official Tree SHAP paper:
    https://arxiv.org/pdf/1802.03888.pdf

    The result is cached on the instance (`x_interaction`, `interaction_values`)
    and returned directly when the requested rows are unchanged.

    Parameters
    ----------
    n_samples_max : int, optional
        Maximum number of samples to use: the **first** `n_samples_max` rows
        (after applying `selection`). If None, all rows are used.
    selection : list, optional
        Index labels (as used by `DataFrame.loc`) of the `x_encoded` rows to use.
        If None or empty, all rows of `x_encoded` are considered.

    Returns
    -------
    numpy.ndarray
        SHAP interaction values, expected shape `(n_samples, n_features, n_features)`:
        entry `[i, j, k]` is the interaction between features `j` and `k` for sample `i`.
    """
    x = self.x_encoded
    if selection:
        x = x.loc[selection]
    # Copy only the rows actually used (not the whole dataset), so the cached
    # frame stays independent of x_encoded.
    x = x[:n_samples_max].copy()
    if hasattr(self, "x_interaction") and self.x_interaction.equals(x):
        return self.interaction_values
    self.x_interaction = x
    self.interaction_values = get_shap_interaction_values(self.x_interaction, self.backend.explainer)
    return self.interaction_values
def check_postprocessing_modif_strings(self, postprocessing=None):
    """
    Tell whether postprocessing will turn numeric columns of ``x_init`` into strings.

    Only ``prefix`` and ``suffix`` rules applied to a numeric column count.

    Parameters
    ----------
    postprocessing : dict, optional
        Feature name -> rule dict (each rule has a ``"type"`` key).

    Returns
    -------
    bool
        True if at least one numeric feature gets a prefix or suffix.
    """
    if postprocessing is None:
        return False
    # Every rule is inspected (no short-circuit), as a plain loop would.
    flags = [
        rule["type"] in {"prefix", "suffix"} and pd.api.types.is_numeric_dtype(self.x_init[feature])
        for feature, rule in postprocessing.items()
    ]
    return any(flags)
def modify_postprocessing(self, postprocessing=None):
    """
    Re-key a postprocessing dict so that every key is a technical feature name.

    A key may be a technical name (key of ``features_dict``), a column number
    (key of ``columns_dict``) or a domain label (key of ``inv_features_dict``);
    lookups are tried in that order. Rules are kept unchanged.

    Parameters
    ----------
    postprocessing : dict, optional
        Postprocessing rules to re-key.

    Returns
    -------
    dict or None
        The re-keyed dict, or None when ``postprocessing`` is None or empty.

    Raises
    ------
    ValueError
        If a key cannot be resolved to a feature.
    """
    if not postprocessing:
        return None
    renamed = {}
    for key, rule in postprocessing.items():
        if key in self.features_dict:
            feature_name = key
        elif key in self.columns_dict:
            feature_name = self.columns_dict[key]
        elif key in self.inv_features_dict:
            feature_name = self.inv_features_dict[key]
        else:
            raise ValueError(f"Feature name '{key}' not found in the dataset.")
        renamed[feature_name] = rule
    return renamed
def apply_postprocessing(self, postprocessing=None):
    """
    Return ``x_init`` with the postprocessing rules applied.

    Parameters
    ----------
    postprocessing : dict, optional
        Feature name -> transformation rule.

    Returns
    -------
    pandas.DataFrame
        The transformed frame, or ``self.x_init`` itself when ``postprocessing``
        is None or empty.
    """
    if not postprocessing:
        return self.x_init
    # Inside the class this name resolves to shapash.utils.transform.apply_postprocessing.
    return apply_postprocessing(self.x_init, postprocessing)
def check_label_dict(self):
    """
    Check that ``label_dict`` matches the model classes (non-regression models only).

    Returns
    -------
    The result of the module-level ``check_label_dict`` helper, or None for
    regression models (nothing to check).
    """
    if self._case == "regression":
        return None
    return check_label_dict(self.label_dict, self._case, self._classes)
def check_features_dict(self):
    """
    Synchronize ``features_dict`` (in place) with the dataset columns.

    - Entries whose key is not a value of ``columns_dict`` are removed.
    - Dataset columns without an entry are added, labelled by their own name.

    NOTE(review): this also removes labels for feature groups and for
    additional-data columns, which are not values of ``columns_dict`` — confirm intent.
    """
    features_dict = self.features_dict
    dataset_features = set(self.columns_dict.values())
    stale = [name for name in features_dict if name not in dataset_features]
    missing = dataset_features.difference(features_dict)
    for name in stale:
        del features_dict[name]
    for name in missing:
        features_dict[name] = name
def _update_features_dict_with_groups(self, features_groups):
    """
    Register group names in ``features_desc``, ``features_dict`` and ``inv_features_dict``.

    Every group gets ``features_desc[group] = 1000`` (a fixed placeholder count of
    unique values — presumably meaning "many values"; confirm). A group already
    present in ``features_dict`` keeps its label; otherwise it is labelled by its
    own name in both mappings.
    """
    for group_name in features_groups:
        self.features_desc[group_name] = 1000
        if group_name in self.features_dict:
            continue
        self.features_dict[group_name] = group_name
        self.inv_features_dict[group_name] = group_name
def check_contributions(self):
    """
    Ensure contributions and the prediction set agree (as judged by ``self.state``).

    Raises
    ------
    ValueError
        If ``self.state.check_contributions`` reports a mismatch.
    """
    if self.state.check_contributions(self.contributions, self.x_init):
        return
    raise ValueError(
        "\n"
        "Prediction set and contributions should have exactly the same number of lines\n"
        "and number of columns. the order of the columns must be the same\n"
        "Please check x, contributions and preprocessing arguments.\n"
    )
def check_label_name(self, label, origin=None):
    """
    Resolve a label into its (index, model code, display value) triple.

    Parameters
    ----------
    label : int or str
        Label given as a class index, a model class code, or a display value.
    origin : {'num', 'code', 'value', None}, optional
        Form of ``label``:
        - `'num'`: integer class index (position in the model classes; -1 allowed)
        - `'code'`: class code as used by the model
        - `'value'`: display name from ``label_dict``
        - `None`: inferred, trying 'code', then 'value', then 'num'.

    Returns
    -------
    tuple
        ``(label_num, label_code, label_value)``.

    Raises
    ------
    Exception
        With message "Origin must be 'num', 'code' or 'value'." for an invalid
        ``origin``, or "Label (...) not found for origin (...)" when the label
        cannot be resolved.
    """
    if origin is None:
        if label in self._classes:
            origin = "code"
        elif self.label_dict is not None and label in self.label_dict.values():
            origin = "value"
        elif isinstance(label, int) and label in range(-1, len(self._classes)):
            origin = "num"
        else:
            # Nothing matched: the label itself is unknown.
            raise Exception({"message": f"Label ({label}) not found for origin ({origin})"})
    # Validate origin before the lookups, so lookup errors (e.g. ValueError from
    # list.index) are no longer misreported as an invalid origin.
    if origin not in ("num", "code", "value"):
        raise Exception({"message": "Origin must be 'num', 'code' or 'value'."})
    try:
        if origin == "num":
            label_num = label
            label_code = self._classes[label]
            label_value = self.label_dict[label_code] if self.label_dict else label_code
        elif origin == "code":
            label_code = label
            label_num = self._classes.index(label)
            label_value = self.label_dict[label_code] if self.label_dict else label_code
        else:
            label_code = self.inv_label_dict[label]
            label_num = self._classes.index(label_code)
            label_value = label
    except Exception as err:
        raise Exception({"message": f"Label ({label}) not found for origin ({origin})"}) from err
    return label_num, label_code, label_value
def check_features_name(self, features, use_groups=False):
    """
    Convert feature identifiers into column indices compatible with ``var_dict``.

    Parameters
    ----------
    features : list
        Items may be column ids (int), technical names or domain names (str).
    use_groups : bool, optional
        When anything other than ``False``, resolve against ``columns_dict_groups``
        (feature groups) instead of ``columns_dict``. Default is False.

    Returns
    -------
    list of int
        Column indices, as returned by the module-level ``check_features_name`` helper.
    """
    if use_groups is False:
        columns_dict = self.columns_dict
    else:
        columns_dict = self.columns_dict_groups
    # Inside the class this name resolves to shapash.utils.check.check_features_name.
    return check_features_name(columns_dict, self.features_dict, features)
def check_attributes(self, attribute):
    """
    Return the value of an attribute of this explainer, failing if it is absent.

    The value is read with `getattr`, so attributes defined at class level
    (or as properties) are returned as well as instance attributes. Reading
    `self.__dict__` directly would raise `KeyError` for those, even though
    `hasattr` accepted them.

    Parameters
    ----------
    attribute : str
        Name of the attribute to retrieve.

    Returns
    -------
    object
        The attribute's current value.

    Raises
    ------
    ValueError
        If the explainer has no attribute with that name.
    """
    if not hasattr(self, attribute):
        raise ValueError(f"The attribute '{attribute}' does not exist in this SmartExplainer instance.")
    return getattr(self, attribute)
def filter(self, features_to_hide=None, threshold=None, positive=None, max_contrib=None, display_groups=None):
    """
    Build the contribution mask used to summarize local explanations.

    Several masks are produced by `self.state` and then merged with
    `combine_masks`: a base mask, plus one for each active rule. The merged
    mask is optionally truncated to `max_contrib` entries. The result drives
    `to_pandas` and local plots.

    Parameters
    ----------
    features_to_hide : list, optional
        Features (or group names when groups are displayed) whose
        contributions are hidden. Names are resolved through
        `check_features_name`. Ignored when empty or None.
    threshold : float, optional
        Contributions below this threshold are hidden (via
        `state.cap_contributions`). A falsy value (None or 0) disables this rule.
    positive : bool, optional
        `True` hides negative contributions and `False` hides positive ones.
        `None` (default) keeps both signs.
    max_contrib : int, optional
        Maximum number of contributions kept (via
        `state.cutoff_contributions`). A falsy value (None or 0) disables the cutoff.
    display_groups : bool, optional
        Work on grouped features (`data_groups`) unless this is exactly `False`
        or no `features_groups` are defined; otherwise use `data`.

    Notes
    -----
    Sets `self.mask`, `self.masked_contributions` and `self.mask_params`. The
    latter records the four filtering arguments but not `display_groups`.

    Example
    -------
    >>> xpl.filter(features_to_hide=['Age', 'Gender'], threshold=0.01, max_contrib=10)
    >>> xpl.plot.local_plot(index=5)
    """
    display_groups = display_groups is not False and self.features_groups is not None
    data = self.data_groups if display_groups else self.data
    state = self.state
    contrib_sorted = data["contrib_sorted"]

    partial_masks = [state.init_mask(contrib_sorted, True)]
    if features_to_hide:
        hidden_ids = self.check_features_name(features_to_hide, use_groups=display_groups)
        partial_masks.append(state.hide_contributions(data["var_dict"], features_list=hidden_ids))
    if threshold:
        partial_masks.append(state.cap_contributions(contrib_sorted, threshold=threshold))
    if positive is not None:
        partial_masks.append(state.sign_contributions(contrib_sorted, positive=positive))

    self.mask = state.combine_masks(partial_masks)
    if max_contrib:
        self.mask = state.cutoff_contributions(self.mask, max_contrib=max_contrib)
    self.masked_contributions = state.compute_masked_contributions(contrib_sorted, self.mask)
    self.mask_params = dict(
        features_to_hide=features_to_hide,
        threshold=threshold,
        positive=positive,
        max_contrib=max_contrib,
    )
def save(self, path):
    """
    Serialize this SmartExplainer to a pickle file.

    The `smartapp` attribute (the WebApp object) is temporarily set to None
    while pickling to avoid serialization issues. It is restored afterwards,
    even if saving fails, so saving no longer discards the live app reference
    from the in-memory explainer.

    Parameters
    ----------
    path : str
        Destination path of the pickle file.

    Notes
    -----
    The saved file can be reloaded with `SmartExplainer.load`. In the reloaded
    object, `smartapp` is None.

    Example
    -------
    >>> xpl.save("path_to_file/xpl.pkl")
    >>> xpl_loaded = SmartExplainer.load("path_to_file/xpl.pkl")
    """
    if not hasattr(self, "smartapp"):
        save_pickle(self, path)
        return
    detached_app = self.smartapp
    self.smartapp = None
    try:
        save_pickle(self, path)
    finally:
        # Reattach so the caller's explainer is unchanged by saving.
        self.smartapp = detached_app
@classmethod
def load(cls, path):
    """
    Rebuild a SmartExplainer from a file written by `save`.

    A new instance is created with `cls(model=...)`. Its `__dict__` is then
    overwritten with the state of the unpickled object, so every saved
    attribute is restored.

    Parameters
    ----------
    path : str
        Path of the pickle file to read.

    Returns
    -------
    SmartExplainer
        An instance of `cls` carrying the saved state.

    Raises
    ------
    ValueError
        If the file does not contain a SmartExplainer object.

    Example
    -------
    >>> xpl = SmartExplainer.load("path_to_file/xpl.pkl")
    >>> xpl.plot.features_importance()
    """
    # NOTE(review): load_pickle presumably unpickles the file, and unpickling can
    # execute arbitrary code. Only load files from trusted sources.
    saved = load_pickle(path)
    if not isinstance(saved, SmartExplainer):
        raise ValueError("The provided file does not contain a SmartExplainer object.")
    restored = cls(model=saved.model)
    restored.__dict__.update(saved.__dict__)
    return restored
def predict_proba(self):
    """
    Compute class probabilities for `x_encoded` and store them in `proba_values`.

    Delegates to the `predict_proba` helper imported from `shapash.utils.model`,
    passing the model, the encoded dataset and the known classes (`_classes`).

    Returns
    -------
    None
        The probabilities are stored in `self.proba_values`.

    Example
    -------
    >>> xpl.predict_proba()
    >>> xpl.proba_values.head()
    """
    # A bare name inside a method skips the class namespace: this is the imported helper.
    probabilities = predict_proba(self.model, self.x_encoded, self._classes)
    self.proba_values = probabilities
def predict(self):
    """
    Compute model predictions for `x_encoded` and store them in `y_pred`.

    When the explainer has a `y_target` attribute, the prediction error is
    also computed with `predict_error` and stored in `prediction_error`. That
    call uses the current `proba_values` and `_classes`.

    Returns
    -------
    None
        Results are stored in `self.y_pred` and, when applicable,
        `self.prediction_error`.

    Example
    -------
    >>> xpl.predict()
    >>> xpl.y_pred.head()
    >>> xpl.prediction_error
    """
    # A bare name inside a method skips the class namespace: this is the imported helper.
    self.y_pred = predict(self.model, self.x_encoded)
    if not hasattr(self, "y_target"):
        return
    self.prediction_error = predict_error(
        self.y_target,
        self.y_pred,
        self._case,
        proba_values=self.proba_values,
        classes=self._classes,
    )
def to_pandas(
    self,
    features_to_hide=None,
    threshold=None,
    positive=None,
    max_contrib=None,
    proba=False,
    use_groups=None,
):
    """
    Export a summarized view of local explainability results as a pandas DataFrame.

    For every sample, the method returns its prediction, optionally its
    probability, and its retained feature contributions.

    If none of `features_to_hide`, `threshold`, `positive` or `max_contrib` is
    given, the mask from the most recent `filter` call is reused. This only
    happens when that mask exists and still matches the shape of the
    contributions. Otherwise `filter` is called with the arguments given here.

    In classification tasks, the summary corresponds to the predicted values
    given by the user (via `compile()` or `add()`). Two main usages are
    possible:
    1. Provide a real prediction set to explain.
    2. Focus on a constant target value and analyze its explainability and
       associated probabilities (using a constant `pd.Series` passed to
       `compile()` or `add()`).
    See the **Local Plot** tutorial for detailed examples.

    Parameters
    ----------
    features_to_hide : list of str, optional
        Feature names to hide from the summary.
    threshold : float, optional
        Contributions below this threshold are hidden.
    positive : bool, optional
        `True` hides negative values, `False` hides positive values, and
        `None` (default) keeps both.
    max_contrib : int, optional
        Maximum number of contributions kept per sample. By default (None) no
        cutoff is applied by the filter.
    proba : bool, optional
        If `True`, probabilities are recomputed with `predict_proba()` and
        added to the output. Default is `False`.
    use_groups : bool, optional
        Summarize grouped features (`features_groups`) unless this is exactly
        `False`. Grouping only applies when `features_groups` are defined.

    Returns
    -------
    pandas.DataFrame
        The predictions concatenated column-wise with the per-sample summary
        returned by `keep_right_contributions`.

    Raises
    ------
    ValueError
        If `y_pred` is None. Use `compile()` or `add()` before calling this method.

    Example
    -------
    >>> summary_df = xpl.to_pandas(max_contrib=2, proba=True)
    >>> summary_df.head()
    pred proba feature_1 value_1 contribution_1 feature_2 value_2 contribution_2
    0 0 0.756416 Sex 1.0 0.322308 Pclass 3.0 0.155069
    1 3 0.628911 Sex 2.0 0.585475 Pclass 1.0 0.370504
    2 0 0.543308 Sex 2.0 -0.486667 Pclass 3.0 0.255072
    """
    use_groups = use_groups is not False and self.features_groups is not None
    data = self.data_groups if use_groups else self.data
    # Classification: y_pred is needed
    if self.y_pred is None:
        raise ValueError("You have to specify y_pred argument. Please use add() or compile() method")

    # `mask_params` alone does not guarantee a mask exists (`to_smartpredictor` sets
    # default `mask_params` without one). The mask's shape must also still match,
    # which may not hold after switching groups on or off.
    previous_mask = getattr(self, "mask", None)
    contrib_sorted = data["contrib_sorted"]
    reuse_previous_mask = (
        all(var is None for var in [features_to_hide, threshold, positive, max_contrib])
        and hasattr(self, "mask_params")
        and previous_mask is not None
        and (
            (
                isinstance(contrib_sorted, pd.DataFrame)
                and len(contrib_sorted.columns) == len(previous_mask.columns)
            )
            or (
                isinstance(contrib_sorted, list)
                and len(contrib_sorted[0].columns) == len(previous_mask[0].columns)
            )
        )
    )
    if reuse_previous_mask:
        print("to_pandas params: " + str(self.mask_params))
    else:
        self.filter(
            features_to_hide=features_to_hide,
            threshold=threshold,
            positive=positive,
            max_contrib=max_contrib,
            display_groups=use_groups,
        )

    if use_groups:
        columns_dict = dict(enumerate(self.x_init_groups.columns))
    else:
        columns_dict = self.columns_dict
    data["summary"] = self.state.summarize(
        data["contrib_sorted"], data["var_dict"], data["x_sorted"], self.mask, columns_dict, self.features_dict
    )

    # Matching with y_pred
    proba_values = None
    if proba:
        self.predict_proba()
        proba_values = self.proba_values
    y_pred, summary = keep_right_contributions(
        self.y_pred, data["summary"], self._case, self._classes, self.label_dict, proba_values
    )
    return pd.concat([y_pred, summary], axis=1)
def compute_features_import(self, force=False, local=False):
    """
    Compute normalized feature importances and store them on the explainer.

    The global importance is always recomputed through
    `backend.get_global_features_importance` with `norm=1`. When
    `features_groups` are defined, grouped importances are computed with
    `state.compute_features_import`. This happens only if they are not cached
    yet or if `force` is True.

    Parameters
    ----------
    force : bool, optional
        If `True`, recompute the grouped importances even when
        `features_imp_groups` is already set. Default is `False`.
    local : bool, optional
        If `True`, also compute importances with `norm=3` and `norm=7`
        (stored as `*_local_lev1` and `*_local_lev2`). Default is `False`.

    Returns
    -------
    None
        Results are stored in `features_imp`, `features_imp_groups` and, when
        `local=True`, in `features_imp_local_lev1`, `features_imp_local_lev2`,
        `features_imp_groups_local_lev1` and `features_imp_groups_local_lev2`.

    Example
    -------
    >>> xpl.compute_features_import()
    >>> xpl.compute_features_import(local=True)
    >>> xpl.features_imp.head()
    """

    def _backend_importance(norm):
        # Same backend call for every normalization level; only `norm` varies.
        return self.backend.get_global_features_importance(
            contributions=self.contributions, explain_data=self.explain_data, subset=None, norm=norm
        )

    self.features_imp = _backend_importance(1)
    has_groups = self.features_groups is not None
    if has_groups and (force or self.features_imp_groups is None):
        self.features_imp_groups = self.state.compute_features_import(self.contributions_groups, norm=1)

    if local:
        self.features_imp_local_lev1 = _backend_importance(3)
        self.features_imp_local_lev2 = _backend_importance(7)
        if has_groups:
            self.features_imp_groups_local_lev1 = self.state.compute_features_import(
                self.contributions_groups, norm=3
            )
            self.features_imp_groups_local_lev2 = self.state.compute_features_import(
                self.contributions_groups, norm=7
            )
def compute_features_stability(self, selection):
    """
    Compute neighborhood-stability metrics for the selected samples.

    Neighbors of each selected sample are found with `find_neighbors`. Their
    contributions are then analyzed with `shap_neighbors`. These results feed
    `local_neighbors_plot` and `local_stability_plot`.

    Parameters
    ----------
    selection : list of int
        Indices of the samples of `x_encoded` to analyze.

    Returns
    -------
    None
        With a single index, `self.local_neighbors = {"norm_shap": ...}` is set.
        Otherwise `self.features_stability = {"variability": ..., "amplitude": ...}`
        is set. Both arrays have shape `(len(selection), x_init.shape[1])`.

    Raises
    ------
    AssertionError
        For classification with more than two classes (unsupported).

    Example
    -------
    >>> xpl.compute_features_stability(selection=[5])
    >>> xpl.local_neighbors["norm_shap"]
    >>> xpl.compute_features_stability(selection=[2, 8, 12])
    >>> xpl.features_stability["variability"].shape
    """
    is_multiclass = self._case == "classification" and len(self._classes) > 2
    if is_multiclass:
        raise AssertionError("Multi-class classification is not supported")

    neighbors_per_instance = find_neighbors(selection, self.x_encoded, self.model, self._case)

    if len(selection) == 1:
        # Single instance: keep the normalized contributions of it and its neighbors.
        norm_shap, _, _ = shap_neighbors(neighbors_per_instance[0], self.x_encoded, self.contributions, self._case)
        self.local_neighbors = {"norm_shap": norm_shap}
        return

    n_instances = len(selection)
    n_features = self.x_init.shape[1]
    amplitude = np.zeros((n_instances, n_features))
    variability = np.zeros((n_instances, n_features))
    for row in range(n_instances):
        _, variability[row, :], amplitude[row, :] = shap_neighbors(
            neighbors_per_instance[row], self.x_encoded, self.contributions, self._case
        )
    self.features_stability = {"variability": variability, "amplitude": amplitude}
def compute_features_compacity(self, selection, distance, nb_features):
    """
    Compute compacity metrics (explanation size vs. fidelity) for a selection.

    Two quantities are computed from `self.contributions`:
    - with `get_min_nb_features`, the number of features needed to reach the
      approximation level `distance`;
    - with `get_distance`, the approximation level reached using
      `nb_features` features, clipped to the interval [0, 1].
    They are used by `compacity_plot`.

    Parameters
    ----------
    selection : list of int
        Indices of the samples of `x_encoded` to analyze.
    distance : float
        Target approximation level, between 0 and 1.
    nb_features : int
        Number of features used to compute the reached approximation.

    Returns
    -------
    None
        Results are stored in `self.features_compacity` as
        `{"features_needed": ..., "distance_reached": ...}`.

    Raises
    ------
    AssertionError
        For classification with more than two classes (unsupported).

    Example
    -------
    >>> xpl.compute_features_compacity(selection=[0, 5, 10], distance=0.9, nb_features=10)
    >>> xpl.features_compacity["features_needed"]
    """
    if self._case == "classification" and len(self._classes) > 2:
        raise AssertionError("Multi-class classification is not supported")

    needed = get_min_nb_features(selection, self.contributions, self._case, distance)
    # Approximations above 100% are clipped back into [0, 1].
    reached = np.clip(get_distance(selection, self.contributions, self._case, nb_features), 0, 1)
    self.features_compacity = {"features_needed": needed, "distance_reached": reached}
def init_app(self, settings: dict = None):
    """
    Create the SmartApp (Shapash WebApp) bound to this explainer.

    Use this when hosting the WebApp through a custom setup instead of
    `run_app()`. The app is only built here, not started.

    Parameters
    ----------
    settings : dict, optional
        Default WebApp settings forwarded to `SmartApp`. Possible keys include
        `'rows'`, `'points'`, `'violin'`, `'features'` (positive ints) and
        `'toggle_group'` (bool).

    Returns
    -------
    None
        The created app is stored in `self.smartapp`.

    Example
    -------
    >>> xpl.init_app(settings={"rows": 100, "features": 10})
    """
    web_app = SmartApp(self, settings)
    self.smartapp = web_app
def run_app(
    self,
    port: int = None,
    host: str = None,
    title_story: str = None,
    settings: dict = None,
) -> CustomThread:
    """
    Launch the Shapash interpretability WebApp in a background thread.

    A new `SmartApp` is built from this explainer and served in a
    `CustomThread`. The URL is logged. Stop the server with `.kill()` on the
    returned thread.

    Parameters
    ----------
    port : int, optional
        Port of the WebApp server. Defaults to `8050`.
    host : str, optional
        Host address of the WebApp server. Defaults to `"0.0.0.0"`.
    title_story : str, optional
        If given, stored in `self.title_story` (even if the app cannot start).
    settings : dict, optional
        Default WebApp settings forwarded to `SmartApp` (e.g. `'rows'`,
        `'points'`, `'violin'`, `'features'`; positive integers).

    Returns
    -------
    CustomThread
        The started thread running the WebApp server.

    Raises
    ------
    ValueError
        If the explainer has not been compiled (no `_case` attribute).

    Example
    -------
    >>> app = xpl.run_app(port=8050)
    >>> app.kill()
    """
    if title_story is not None:
        self.title_story = title_story
    if not hasattr(self, "_case"):
        raise ValueError("Explainer must be compiled before running app.")

    self.smartapp = SmartApp(self, settings)
    if host is None:
        host = "0.0.0.0"
    if port is None:
        port = 8050
    host_name = get_host_name()

    def _serve():
        # `self.smartapp.app` is presumably a Dash application. Dash 3 removed
        # `run_server` in favour of `run`, so prefer `run` and fall back to
        # `run_server` for older Dash versions that lack it.
        web_app = self.smartapp.app
        launch = web_app.run if hasattr(web_app, "run") else web_app.run_server
        launch(debug=False, host=host, port=port)

    server_instance = CustomThread(target=_serve)
    # Display the detected host name only when listening on all interfaces.
    if host_name is None or host != "0.0.0.0":
        host_name = host
    server_instance.start()
    logging.info(f"Your Shapash application run on http://{host_name}:{port}/")
    logging.info("Use the method .kill() to down your app.")
    return server_instance
def to_smartpredictor(self):
    """
    Build a `SmartPredictor` from the state of this explainer.

    The following values are passed positionally to `SmartPredictor`, in this
    order: `features_dict`, `model`, `columns_dict`, `backend`,
    `features_types`, `label_dict`, `preprocessing`, `postprocessing`,
    `features_groups` and `mask_params`.

    Side effects: `self.features_types` is (re)computed as a mapping from each
    `x_init` column name to the string of its dtype. If no filter was ever
    applied, `self.mask_params` is set to a dict of four None values.

    Returns
    -------
    SmartPredictor
        The predictor built from the values above.

    Raises
    ------
    ValueError
        If no backend is defined, or if any transferred attribute is missing
        (raised by `check_attributes`).

    Example
    -------
    >>> sp = xpl.to_smartpredictor()
    >>> sp.predict(data_sample)
    """
    if self.backend is None:
        raise ValueError(
            "\n"
            "SmartPredictor needs a backend (explainer).\n"
            "Please compile without contributions or specify the\n"
            "explainer used. Make change in compile() step.\n"
        )

    self.features_types = {name: str(self.x_init[name].dtypes) for name in self.x_init.columns}
    transferred_names = (
        "features_dict",
        "model",
        "columns_dict",
        "backend",
        "features_types",
        "label_dict",
        "preprocessing",
        "postprocessing",
        "features_groups",
    )
    constructor_args = [self.check_attributes(name) for name in transferred_names]

    if not hasattr(self, "mask_params"):
        self.mask_params = {"features_to_hide": None, "threshold": None, "positive": None, "max_contrib": None}
    return shapash.explainer.smart_predictor.SmartPredictor(*constructor_args, self.mask_params)
def check_x_y_attributes(self, x_str, y_str):
    """
    Fetch two attributes of this explainer by name, using None for missing ones.

    Values are read with `getattr`, so class-level attributes and properties
    are returned as well as instance attributes. The previous `self.__dict__`
    lookup raised `KeyError` for those, even though `hasattr` accepted them.

    Parameters
    ----------
    x_str : str
        Name of the first attribute.
    y_str : str
        Name of the second attribute.

    Returns
    -------
    list
        `[x_value, y_value]`, where a missing attribute yields None.

    Raises
    ------
    ValueError
        If `x_str` or `y_str` is not a string.

    Example
    -------
    >>> x_attr, y_attr = xpl.check_x_y_attributes("x_encoded", "y_pred")
    """
    if not (isinstance(x_str, str) and isinstance(y_str, str)):
        raise ValueError("\nx and y must be strings.\n")
    return [getattr(self, name, None) for name in (x_str, y_str)]
def generate_report(
    self,
    output_file,
    project_info_file,
    x_train=None,
    y_train=None,
    y_test=None,
    title_story=None,
    title_description=None,
    metrics=None,
    working_dir=None,
    notebook_path=None,
    kernel_name=None,
    max_points=200,
    display_interaction_plot=False,
    nb_top_interactions=5,
):
    """
    Generate an HTML report summarizing the project, data and model explainability.

    The report is produced by executing a notebook template (`execute_report`)
    and exporting it (`export_and_save_report`). A YAML file describing the
    project is required.

    Parameters
    ----------
    output_file : str
        Path of the HTML file to write.
    project_info_file : str
        Path to a YAML file with project metadata (name, author, date, ...).
    x_train : pandas.DataFrame, optional
        Training dataset. Its categorical missing values are handled with
        `handle_categorical_missing` before use.
    y_train : pandas.Series or pandas.DataFrame, optional
        Target values corresponding to `x_train`.
    y_test : pandas.Series or pandas.DataFrame, optional
        Target values of the test dataset.
    title_story : str, optional
        Title displayed at the top of the report.
    title_description : str, optional
        Short text displayed below the title.
    metrics : list of dict, optional
        Metrics to display. Each dict has `'path'` (import path of the metric
        function), and optionally `'name'` and `'use_proba_values'`.
        Example: `metrics=[{'name': 'F1 score', 'path': 'sklearn.metrics.f1_score'}]`
    working_dir : str, optional
        Directory for intermediate files. If not given, a temporary directory
        is created and removed afterwards, whether or not generation succeeds.
    notebook_path : str, optional
        Custom notebook template. Defaults to the Shapash report notebook.
    kernel_name : str, optional
        Jupyter kernel used to execute the notebook.
    max_points : int, optional, default=200
        Maximum number of points displayed in contribution plots.
    display_interaction_plot : bool, optional, default=False
        If True, include interaction plots (slower).
    nb_top_interactions : int, optional, default=5
        Number of top feature interactions to include.

    Config entries whose value is None are not forwarded to the report.

    Returns
    -------
    None
        The report is written to `output_file`.

    Raises
    ------
    AssertionError
        If the explainer has no `model` attribute (not compiled). This check
        happens before any temporary directory is created.
    Exception
        Any error raised while executing or exporting the report is propagated.

    Example
    -------
    >>> xpl.generate_report(
    ...     output_file="report.html",
    ...     project_info_file="utils/project_info.yml",
    ...     x_train=x_train,
    ...     y_train=y_train,
    ...     y_test=y_test,
    ...     title_story="House Prices Project Report",
    ...     metrics=[{"path": "sklearn.metrics.mean_squared_error", "name": "Mean Squared Error"}],
    ... )
    """
    check_report_requirements()
    if x_train is not None:
        x_train = handle_categorical_missing(x_train)
    # Avoid Import Errors with requirements specific to the Shapash Report
    from shapash.report.generation import execute_report, export_and_save_report

    # Validate before creating the temporary directory so a failure cannot leak it.
    if not hasattr(self, "model"):
        raise AssertionError(
            "Explainer object was not compiled. Please compile the explainer "
            "object using .compile(...) method before generating the report."
        )

    rm_working_dir = False
    if not working_dir:
        working_dir = tempfile.mkdtemp()
        rm_working_dir = True

    config = {
        key: value
        for key, value in dict(
            title_story=title_story,
            title_description=title_description,
            metrics=metrics,
            max_points=max_points,
            display_interaction_plot=display_interaction_plot,
            nb_top_interactions=nb_top_interactions,
        ).items()
        if value is not None
    }
    try:
        execute_report(
            working_dir=working_dir,
            explainer=self,
            project_info_file=project_info_file,
            x_train=x_train,
            y_train=y_train,
            y_test=y_test,
            config=config,
            notebook_path=notebook_path,
            kernel_name=kernel_name,
        )
        export_and_save_report(working_dir=working_dir, output_file=output_file)
    finally:
        # Only remove the directory this method created; user directories are kept.
        if rm_working_dir:
            shutil.rmtree(working_dir)
def _local_pred(self, index, label=None):
"""
Compute the model prediction or probability for a single observation.
This internal method retrieves the prediction or class probability
corresponding to a specific sample index.
Parameters
----------
index : int, str, or float
Index of the sample for which to compute the prediction.
Must correspond to a valid index in `x_encoded`.
label : int, optional
Class label for which to extract the probability in classification tasks.
If `None`, the method returns the prediction for the main target.
Returns
-------
float
The predicted value (for regression) or predicted probability
(for classification).
Notes
-----
- For classification, returns the class probability if `proba_values` are available.
- For regression, returns the predicted numeric value.
- This is an internal helper used primarily for visualization.
Example
-------
>>> # Retrieve the predicted value for observation at index 12
>>> xpl._local_pred(index=12)
0.7421
"""
if self._case == "classification":
if self.proba_values is not None:
value = self.proba_values.iloc[:, [label]].loc[index].values[0]
else:
value = None
elif self._case == "regression":
if self.y_pred is not None:
value = self.y_pred.loc[index]
else:
value = self.model.predict(self.x_encoded.loc[[index]])[0]
if isinstance(value, pd.Series):
value = value.values[0]
return value