{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ColumnTransformer tutorial\n", "\n", "This tutorial shows how to pass a fitted ColumnTransformer to Shapash so that it can reverse the data preprocessing and display explicit labels.\n", "\n", "We use Kaggle's [Titanic](https://www.kaggle.com/c/titanic) dataset.\n", "\n", "Contents:\n", "- Encode data with ColumnTransformer\n", "- Build a binary classifier (XGBoost)\n", "- Use Shapash\n", "- Show the inverse-transformed data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We implement an inverse transform function for ColumnTransformer that relies on column positions.\n", "\n", "Consequently, the transformed features output by the ColumnTransformer must not be sub-sampled or reordered before they are given to Shapash." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "from xgboost import XGBClassifier\n", "from sklearn.compose import ColumnTransformer\n", "from sklearn.preprocessing import OrdinalEncoder\n", "from sklearn.preprocessing import OneHotEncoder\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Titanic data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from shapash.data.data_loader import data_loading\n", "\n", "titan_df, titan_dict = data_loading('titanic')\n", "del titan_df['Name']" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SurvivedPclassSexAgeSibSpParchFareEmbarkedTitle
PassengerId
10Third classmale22.0107.25SouthamptonMr
21First classfemale38.01071.28CherbourgMrs
31Third classfemale26.0007.92SouthamptonMiss
41First classfemale35.01053.10SouthamptonMrs
50Third classmale35.0008.05SouthamptonMr
\n", "
" ], "text/plain": [ " Survived Pclass Sex Age SibSp Parch Fare \\\n", "PassengerId \n", "1 0 Third class male 22.0 1 0 7.25 \n", "2 1 First class female 38.0 1 0 71.28 \n", "3 1 Third class female 26.0 0 0 7.92 \n", "4 1 First class female 35.0 1 0 53.10 \n", "5 0 Third class male 35.0 0 0 8.05 \n", "\n", " Embarked Title \n", "PassengerId \n", "1 Southampton Mr \n", "2 Cherbourg Mrs \n", "3 Southampton Miss \n", "4 Southampton Mrs \n", "5 Southampton Mr " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "titan_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare data for the model\n", "\n", "Create the target. The index of `y` is reset so that it lines up with the default RangeIndex of the transformed DataFrame built below." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "y = titan_df.reset_index(drop=True)['Survived']\n", "X = titan_df.drop('Survived', axis=1)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 0\n", "1 1\n", "2 1\n", "3 1\n", "4 0\n", " ..\n", "886 0\n", "887 1\n", "888 0\n", "889 1\n", "890 0\n", "Name: Survived, Length: 891, dtype: int64" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "titan_df.reset_index(drop=True)['Survived']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Fit a ColumnTransformer that combines several transformers." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "enc_columntransfo = ColumnTransformer(\n", " transformers=[\n", " ('onehot', OneHotEncoder(), ['Pclass','Sex']),\n", " ('ordinal', OrdinalEncoder(), ['Embarked','Title'])\n", " ],\n", " remainder='passthrough')\n", "X_transform = pd.DataFrame(enc_columntransfo.fit_transform(X, y))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Rename the remainder (passthrough) columns with their original names." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456AgeSibSpParchFare
00.00.01.00.01.02.011.022.01.00.07.25
11.00.00.01.00.00.012.038.01.00.071.28
\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 Age SibSp Parch Fare\n", "0 0.0 0.0 1.0 0.0 1.0 2.0 11.0 22.0 1.0 0.0 7.25\n", "1 1.0 0.0 0.0 1.0 0.0 0.0 12.0 38.0 1.0 0.0 71.28" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Columns that no transformer handled: third entry ('remainder') of transformers_\n", "# (integer positions or column names, depending on the scikit-learn version)\n", "idx_col = enc_columntransfo.transformers_[2][2]\n", "remainder_names = [col if isinstance(col, str) else X.columns[col] for col in idx_col]\n", "# The passthrough columns are the last ones of the output: restore their original names\n", "start = len(X_transform.columns) - len(remainder_names)\n", "X_transform.columns = X_transform.columns.tolist()[:start] + remainder_names\n", "X_transform.head(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fit a model" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "XGBClassifier(base_score=0.5, booster=None, colsample_bylevel=1,\n", " colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,\n", " importance_type='gain', interaction_constraints=None,\n", " learning_rate=0.300000012, max_delta_step=0, max_depth=6,\n", " min_child_weight=2, missing=nan, monotone_constraints=None,\n", " n_estimators=200, n_jobs=0, num_parallel_tree=1, random_state=0,\n", " reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,\n", " tree_method=None, validate_parameters=False, verbosity=None)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Xtrain, Xtest, ytrain, ytest = train_test_split(X_transform, y, train_size=0.75, random_state=1)\n", "\n", "clf = XGBClassifier(n_estimators=200, min_child_weight=2)\n", "clf.fit(Xtrain, ytrain)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using Shapash" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from shapash import SmartExplainer" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "xpl = SmartExplainer(model=clf, preprocessing=enc_columntransfo)" ] }, { "cell_type": "code", 
"execution_count": 11, "metadata": {}, "outputs": [], "source": [ "xpl.compile(x=Xtest,\n", "y_target=ytest, # Optional: allows to display True Values vs Predicted Values\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize data in pandas" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
onehot_Pclassonehot_Sexordinal_Embarkedordinal_TitleAgeSibSpParchFare
862First classfemaleSouthamptonMrs48.00.00.025.93
223Third classmaleSouthamptonMr29.50.00.07.90
84Second classfemaleSouthamptonMiss17.00.00.010.50
680Third classfemaleQueenstownMiss29.50.00.08.14
\n", "
" ], "text/plain": [ " onehot_Pclass onehot_Sex ordinal_Embarked ordinal_Title Age SibSp \\\n", "862 First class female Southampton Mrs 48.0 0.0 \n", "223 Third class male Southampton Mr 29.5 0.0 \n", "84 Second class female Southampton Miss 17.0 0.0 \n", "680 Third class female Queenstown Miss 29.5 0.0 \n", "\n", " Parch Fare \n", "862 0.0 25.93 \n", "223 0.0 7.90 \n", "84 0.0 10.50 \n", "680 0.0 8.14 " ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# A ColumnTransformer can apply several transformers to the same column,\n", "# so each restored column is named TransformerName_ColumnName (e.g. onehot_Pclass)\n", "xpl.x_init.head(4)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456AgeSibSpParchFare
8621.00.00.01.00.02.012.048.00.00.025.93
2230.00.01.00.01.02.011.029.50.00.07.90
840.01.00.01.00.02.08.017.00.00.010.50
6800.00.01.01.00.01.08.029.50.00.08.14
\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 Age SibSp Parch Fare\n", "862 1.0 0.0 0.0 1.0 0.0 2.0 12.0 48.0 0.0 0.0 25.93\n", "223 0.0 0.0 1.0 0.0 1.0 2.0 11.0 29.5 0.0 0.0 7.90\n", "84 0.0 1.0 0.0 1.0 0.0 2.0 8.0 17.0 0.0 0.0 10.50\n", "680 0.0 0.0 1.0 1.0 0.0 1.0 8.0 29.5 0.0 0.0 8.14" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xpl.x_encoded.head(4)" ] } ], "metadata": { "hide_input": false, "kernelspec": { "display_name": "shapash_picking", "language": "python", "name": "shapash_picking" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }