{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Using Shapash with Lime explainer - Titanic\n", "\n", "You can compute your local contributions with the [Lime](https://github.com/marcotcr/lime) library and summarize them with Shapash.\n", "\n", "Contents:\n", "- Build a binary classifier (Random Forest)\n", "- Create an explainer using Lime\n", "- Compile the Shapash SmartExplainer\n", "- Display a local_plot\n", "- Export with to_pandas\n", "\n", "Data from the Kaggle [Titanic](https://www.kaggle.com/c/titanic) competition." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import lime.lime_tabular\n", "import numpy as np\n", "import pandas as pd\n", "from category_encoders import OrdinalEncoder\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.model_selection import train_test_split" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from shapash.data.data_loader import data_loading" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "titan_df, titan_dict = data_loading('titanic')\n", "# The passenger name is not used as a model feature; drop it without mutating in place\n", "titan_df = titan_df.drop(columns=['Name'])" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SurvivedPclassSexAgeSibSpParchFareEmbarkedTitle
PassengerId
10Third classmale22.0107.25SouthamptonMr
21First classfemale38.01071.28CherbourgMrs
31Third classfemale26.0007.92SouthamptonMiss
41First classfemale35.01053.10SouthamptonMrs
50Third classmale35.0008.05SouthamptonMr
\n", "
" ], "text/plain": [ " Survived Pclass Sex Age SibSp Parch Fare \\\n", "PassengerId \n", "1 0 Third class male 22.0 1 0 7.25 \n", "2 1 First class female 38.0 1 0 71.28 \n", "3 1 Third class female 26.0 0 0 7.92 \n", "4 1 First class female 35.0 1 0 53.10 \n", "5 0 Third class male 35.0 0 0 8.05 \n", "\n", " Embarked Title \n", "PassengerId \n", "1 Southampton Mr \n", "2 Cherbourg Mrs \n", "3 Southampton Miss \n", "4 Southampton Mrs \n", "5 Southampton Mr " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "titan_df.head()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Create Classification Model" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "y = titan_df['Survived']\n", "X = titan_df.drop(columns='Survived')" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "varcat = ['Pclass', 'Sex', 'Embarked', 'Title']" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Fit the encoder on the raw features and keep the encoded frame under a new name,\n", "# so re-running this cell never re-encodes already-encoded data.\n", "categ_encoding = OrdinalEncoder(cols=varcat,\n", "                                handle_unknown='ignore',\n", "                                return_df=True).fit(X)\n", "X_encoded = categ_encoding.transform(X)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Train/test split and Random Forest fit" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,\n", " criterion='gini', max_depth=None, max_features='auto',\n", " max_leaf_nodes=None, max_samples=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=3, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, n_estimators=100,\n", " n_jobs=None, oob_score=False, random_state=None,\n", " verbose=0, warm_start=False)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Xtrain, Xtest, ytrain, ytest = train_test_split(X_encoded, y, 
train_size=0.75, random_state=1)\n", "\n", "# A fixed random_state makes the fitted forest, and every explanation below, reproducible\n", "rf = RandomForestClassifier(n_estimators=100, min_samples_leaf=3, random_state=1)\n", "rf.fit(Xtrain, ytrain)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Create Lime Explainer" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Train the tabular explainer on the encoded training features.\n", "# class_names expects one label per predict_proba column (ordered like rf.classes_),\n", "# not the target vector itself.\n", "explainer = lime.lime_tabular.LimeTabularExplainer(\n", "    Xtrain.values,\n", "    mode='classification',\n", "    feature_names=list(Xtrain.columns),\n", "    class_names=[str(c) for c in rf.classes_],\n", "    random_state=1,  # LIME samples perturbations randomly; seed it for reproducible contributions\n", ")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Apply the Explainer to the Test Sample and Preprocess the Results" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def features_check(s, columns=None):\n", "    \"\"\"Return the feature name referenced by one LIME rule string, for use by Shapash.\n", "\n", "    A column matches when it appears in the rule as a whole, space-delimited word\n", "    (presumably rules look like 'Age <= 28.00'). ``columns`` defaults to ``Xtest.columns``.\n", "    If several columns match, the last one in column order is returned.\n", "\n", "    Raises:\n", "        ValueError: if no column name appears in the rule (previously an UnboundLocalError).\n", "    \"\"\"\n", "    if columns is None:\n", "        columns = Xtest.columns\n", "    padded_rule = f' {s} '\n", "    matches = [w for w in columns if f' {w} ' in padded_rule]\n", "    if not matches:\n", "        raise ValueError(f'No feature name found in LIME rule: {s!r}')\n", "    return matches[-1]" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 57.8 s, sys: 7.34 s, total: 1min 5s\n", "Wall time: 10.9 s\n" ] } ], "source": [ "%%time\n", "# Compute the local Lime explanation of every row in the test sample\n", "contrib_l = []\n", "for ind in Xtest.index:\n", "    exp = explainer.explain_instance(Xtest.loc[ind].values, rf.predict_proba, num_features=Xtest.shape[1])\n", "    contrib_l.append({features_check(rule): weight for rule, weight in exp.as_list()})" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "contribution_df = pd.DataFrame(contrib_l, index=Xtest.index)\n", "# Sort the columns as in the original dataset\n", "contribution_df = contribution_df[list(Xtest.columns)]" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "ypred = pd.DataFrame(rf.predict(Xtest), columns=['pred'], index=Xtest.index)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Use 
Shapash With Lime Contributions" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from shapash import SmartExplainer" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "xpl = SmartExplainer(\n", "    model=rf,\n", "    preprocessing=categ_encoding,  # fitted encoder, so original category labels can be shown\n", "    features_dict=titan_dict,  # human-readable feature names from data_loading\n", ")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Use the `contributions` parameter of `compile` to declare the Lime contributions" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "xpl.compile(\n", "    x=Xtest,\n", "    contributions=contribution_df,  # Lime contributions as a pd.DataFrame\n", "    y_pred=ypred,\n", "    y_target=ytest,  # optional: enables comparing true vs. predicted values\n", ")" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "xpl.plot.local_plot(index=3)" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
predprobafeature_1value_1contribution_1feature_2value_2contribution_2feature_3value_3contribution_3
86310.801675Sexfemale0.257817Title of passengerMrs0.188714Ticket classFirst class0.0880992
22400.965208Sexmale0.248462Title of passengerMr0.199544Ticket classThird class0.0838383
8510.799397Sexfemale0.25465Title of passengerMiss0.193198Age170.0981314
68110.786956Sexfemale0.252464Title of passengerMiss0.187045Relatives such as brother or wife00.0522808
53610.936170Sexfemale0.250703Title of passengerMiss0.193096Age70.104632
\n", "
" ], "text/plain": [ " pred proba feature_1 value_1 contribution_1 feature_2 \\\n", "863 1 0.801675 Sex female 0.257817 Title of passenger \n", "224 0 0.965208 Sex male 0.248462 Title of passenger \n", "85 1 0.799397 Sex female 0.25465 Title of passenger \n", "681 1 0.786956 Sex female 0.252464 Title of passenger \n", "536 1 0.936170 Sex female 0.250703 Title of passenger \n", "\n", " value_2 contribution_2 feature_3 value_3 \\\n", "863 Mrs 0.188714 Ticket class First class \n", "224 Mr 0.199544 Ticket class Third class \n", "85 Miss 0.193198 Age 17 \n", "681 Miss 0.187045 Relatives such as brother or wife 0 \n", "536 Miss 0.193096 Age 7 \n", "\n", " contribution_3 \n", "863 0.0880992 \n", "224 0.0838383 \n", "85 0.0981314 \n", "681 0.0522808 \n", "536 0.104632 " ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Export the 3 largest positive contributions per row, together with the predicted probability\n", "summary_df = xpl.to_pandas(max_contrib=3, positive=True, proba=True)\n", "summary_df.head()" ] } ],
"metadata": { "celltoolbar": "Aucun(e)", "hide_input": false, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.11" } }, "nbformat": 4, "nbformat_minor": 4 }