{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ColumnTransformer tutorial\n",
"\n",
"This tutorial shows how to pass a scikit-learn ColumnTransformer to Shapash so that the data preprocessing can be reversed and explicit labels displayed.\n",
"\n",
"We use Kaggle's [Titanic](https://www.kaggle.com/c/titanic) dataset.\n",
"\n",
"Contents:\n",
"- Encode data with ColumnTransformer\n",
"- Build a binary classifier (XGBoost)\n",
"- Use Shapash\n",
"- Show the inverse-transformed data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We implement an inverse transform function for ColumnTransformer that is based on column position.\n",
"\n",
"As a consequence, the transformed features output by the ColumnTransformer must not be subsampled or reordered before they are given to Shapash."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from xgboost import XGBClassifier\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import OrdinalEncoder\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load titanic Data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from shapash.data.data_loader import data_loading\n",
"\n",
"# data_loading returns the dataset plus a second object, kept here as titan_dict.\n",
"titan_df, titan_dict = data_loading('titanic')\n",
"# The Name column is not used in this tutorial, so it is dropped up front.\n",
"titan_df = titan_df.drop(columns='Name')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Survived | \n",
" Pclass | \n",
" Sex | \n",
" Age | \n",
" SibSp | \n",
" Parch | \n",
" Fare | \n",
" Embarked | \n",
" Title | \n",
"
\n",
" \n",
" PassengerId | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" 1 | \n",
" 0 | \n",
" Third class | \n",
" male | \n",
" 22.0 | \n",
" 1 | \n",
" 0 | \n",
" 7.25 | \n",
" Southampton | \n",
" Mr | \n",
"
\n",
" \n",
" 2 | \n",
" 1 | \n",
" First class | \n",
" female | \n",
" 38.0 | \n",
" 1 | \n",
" 0 | \n",
" 71.28 | \n",
" Cherbourg | \n",
" Mrs | \n",
"
\n",
" \n",
" 3 | \n",
" 1 | \n",
" Third class | \n",
" female | \n",
" 26.0 | \n",
" 0 | \n",
" 0 | \n",
" 7.92 | \n",
" Southampton | \n",
" Miss | \n",
"
\n",
" \n",
" 4 | \n",
" 1 | \n",
" First class | \n",
" female | \n",
" 35.0 | \n",
" 1 | \n",
" 0 | \n",
" 53.10 | \n",
" Southampton | \n",
" Mrs | \n",
"
\n",
" \n",
" 5 | \n",
" 0 | \n",
" Third class | \n",
" male | \n",
" 35.0 | \n",
" 0 | \n",
" 0 | \n",
" 8.05 | \n",
" Southampton | \n",
" Mr | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Survived Pclass Sex Age SibSp Parch Fare \\\n",
"PassengerId \n",
"1 0 Third class male 22.0 1 0 7.25 \n",
"2 1 First class female 38.0 1 0 71.28 \n",
"3 1 Third class female 26.0 0 0 7.92 \n",
"4 1 First class female 35.0 1 0 53.10 \n",
"5 0 Third class male 35.0 0 0 8.05 \n",
"\n",
" Embarked Title \n",
"PassengerId \n",
"1 Southampton Mr \n",
"2 Cherbourg Mrs \n",
"3 Southampton Miss \n",
"4 Southampton Mrs \n",
"5 Southampton Mr "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preview the first five rows of the loaded dataset.\n",
"titan_df.head(5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare data for the model\n",
"\n",
"Create Target"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Features: every column except the label.\n",
"X = titan_df.drop(columns='Survived')\n",
"# Label: the index is reset to 0..n-1 so that it lines up with the encoded\n",
"# feature frame built below, which is created without an explicit index.\n",
"y = titan_df['Survived'].reset_index(drop=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 0\n",
"1 1\n",
"2 1\n",
"3 1\n",
"4 0\n",
" ..\n",
"886 0\n",
"887 1\n",
"888 0\n",
"889 1\n",
"890 0\n",
"Name: Survived, Length: 891, dtype: int64"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Display the target built in the previous cell (Survived, re-indexed from 0).\n",
"y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Fit a ColumnTransformer that combines several transformers."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# One (name, transformer, columns) entry per encoding; the columns that are not\n",
"# listed are kept as they are thanks to remainder='passthrough'.\n",
"encoders = [\n",
"    ('onehot', OneHotEncoder(), ['Pclass', 'Sex']),\n",
"    ('ordinal', OrdinalEncoder(), ['Embarked', 'Title']),\n",
"]\n",
"enc_columntransfo = ColumnTransformer(transformers=encoders, remainder='passthrough')\n",
"\n",
"# The transformed values carry no column names, so the frame built from them\n",
"# gets default integer column labels and a default index.\n",
"encoded_values = enc_columntransfo.fit_transform(X, y)\n",
"X_transform = pd.DataFrame(encoded_values)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reassign the original column names to the remainder (passthrough) columns."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" 0 | \n",
" 1 | \n",
" 2 | \n",
" 3 | \n",
" 4 | \n",
" 5 | \n",
" 6 | \n",
" Age | \n",
" SibSp | \n",
" Parch | \n",
" Fare | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 2.0 | \n",
" 11.0 | \n",
" 22.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 7.25 | \n",
"
\n",
" \n",
" 1 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 12.0 | \n",
" 38.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 71.28 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" 0 1 2 3 4 5 6 Age SibSp Parch Fare\n",
"0 0.0 0.0 1.0 0.0 1.0 2.0 11.0 22.0 1.0 0.0 7.25\n",
"1 1.0 0.0 0.0 1.0 0.0 0.0 12.0 38.0 1.0 0.0 71.28"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Find the fitted 'remainder' entry by name instead of relying on its position\n",
"# in transformers_; fall back to an empty list when nothing was passed through.\n",
"idx_col = next(\n",
"    (cols for name, _, cols in enc_columntransfo.transformers_ if name == 'remainder'),\n",
"    [],\n",
")\n",
"# scikit-learn stores the remainder columns either as integer positions or as\n",
"# column names depending on its version/settings, so normalise them to names.\n",
"remainder_names = [col if isinstance(col, str) else X.columns[col] for col in idx_col]\n",
"# The passthrough columns come last in the output: keep the integer labels of\n",
"# the encoded columns and give the trailing ones their original names.\n",
"start = len(X_transform.columns) - len(remainder_names)\n",
"X_transform.columns = X_transform.columns.tolist()[:start] + remainder_names\n",
"X_transform.head(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fit a model"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"XGBClassifier(base_score=0.5, booster=None, colsample_bylevel=1,\n",
" colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,\n",
" importance_type='gain', interaction_constraints=None,\n",
" learning_rate=0.300000012, max_delta_step=0, max_depth=6,\n",
" min_child_weight=2, missing=nan, monotone_constraints=None,\n",
" n_estimators=200, n_jobs=0, num_parallel_tree=1, random_state=0,\n",
" reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,\n",
" tree_method=None, validate_parameters=False, verbosity=None)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Hold out 25% of the rows for explanation; random_state makes the split reproducible.\n",
"Xtrain, Xtest, ytrain, ytest = train_test_split(X_transform, y, train_size=0.75, random_state=1)\n",
"\n",
"# Build the estimator, then fit it exactly once (it used to be fitted twice);\n",
"# fit() returns the fitted classifier, which is displayed as the cell output.\n",
"clf = XGBClassifier(n_estimators=200, min_child_weight=2)\n",
"clf.fit(Xtrain, ytrain)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using Shapash"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from shapash import SmartExplainer"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# The fitted ColumnTransformer is passed as preprocessing so that Shapash can\n",
"# map the encoded features back to their original values.\n",
"xpl = SmartExplainer(\n",
"    model=clf,\n",
"    preprocessing=enc_columntransfo,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# y_target is optional: it allows displaying true values vs predicted values.\n",
"xpl.compile(x=Xtest, y_target=ytest)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualize data in pandas"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" onehot_Pclass | \n",
" onehot_Sex | \n",
" ordinal_Embarked | \n",
" ordinal_Title | \n",
" Age | \n",
" SibSp | \n",
" Parch | \n",
" Fare | \n",
"
\n",
" \n",
" \n",
" \n",
" 862 | \n",
" First class | \n",
" female | \n",
" Southampton | \n",
" Mrs | \n",
" 48.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 25.93 | \n",
"
\n",
" \n",
" 223 | \n",
" Third class | \n",
" male | \n",
" Southampton | \n",
" Mr | \n",
" 29.5 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 7.90 | \n",
"
\n",
" \n",
" 84 | \n",
" Second class | \n",
" female | \n",
" Southampton | \n",
" Miss | \n",
" 17.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 10.50 | \n",
"
\n",
" \n",
" 680 | \n",
" Third class | \n",
" female | \n",
" Queenstown | \n",
" Miss | \n",
" 29.5 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 8.14 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" onehot_Pclass onehot_Sex ordinal_Embarked ordinal_Title Age SibSp \\\n",
"862 First class female Southampton Mrs 48.0 0.0 \n",
"223 Third class male Southampton Mr 29.5 0.0 \n",
"84 Second class female Southampton Miss 17.0 0.0 \n",
"680 Third class female Queenstown Miss 29.5 0.0 \n",
"\n",
" Parch Fare \n",
"862 0.0 25.93 \n",
"223 0.0 7.90 \n",
"84 0.0 10.50 \n",
"680 0.0 8.14 "
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# A ColumnTransformer can apply several transformers to the same column, so\n",
"# each inverse-transformed column is named TransformerName_ColumnName:\n",
"# for example, Pclass is now onehot_Pclass.\n",
"xpl.x_init.head(n=4)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" 0 | \n",
" 1 | \n",
" 2 | \n",
" 3 | \n",
" 4 | \n",
" 5 | \n",
" 6 | \n",
" Age | \n",
" SibSp | \n",
" Parch | \n",
" Fare | \n",
"
\n",
" \n",
" \n",
" \n",
" 862 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 2.0 | \n",
" 12.0 | \n",
" 48.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 25.93 | \n",
"
\n",
" \n",
" 223 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 2.0 | \n",
" 11.0 | \n",
" 29.5 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 7.90 | \n",
"
\n",
" \n",
" 84 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 2.0 | \n",
" 8.0 | \n",
" 17.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 10.50 | \n",
"
\n",
" \n",
" 680 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 8.0 | \n",
" 29.5 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 8.14 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" 0 1 2 3 4 5 6 Age SibSp Parch Fare\n",
"862 1.0 0.0 0.0 1.0 0.0 2.0 12.0 48.0 0.0 0.0 25.93\n",
"223 0.0 0.0 1.0 0.0 1.0 2.0 11.0 29.5 0.0 0.0 7.90\n",
"84 0.0 1.0 0.0 1.0 0.0 2.0 8.0 17.0 0.0 0.0 10.50\n",
"680 0.0 0.0 1.0 1.0 0.0 1.0 8.0 29.5 0.0 0.0 8.14"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Encoded counterpart of the rows shown above, with the transformed columns.\n",
"xpl.x_encoded.head(n=4)"
]
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "shapash_picking",
"language": "python",
"name": "shapash_picking"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}