{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Category_encoder tutorial\n",
"\n",
"This tutorial shows how to use category_encoder encoders to reverse data preprocessing and display explicit labels.\n",
"\n",
"We used Kaggle's [Titanic](https://www.kaggle.com/c/titanic) dataset.\n",
"\n",
"This Tutorial:\n",
"- Encode data with Category_encoder\n",
"- Build a Binary Classifier (Random Forest)\n",
"- Using Shapash\n",
"- Show inversed data"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from category_encoders import OrdinalEncoder\n",
"from category_encoders import OneHotEncoder\n",
"from category_encoders import TargetEncoder\n",
"from xgboost import XGBClassifier\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load titanic Data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from shapash.data.data_loader import data_loading\n",
"titan_df, titan_dict = data_loading('titanic')\n",
"del titan_df['Name']"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Survived | \n",
" Pclass | \n",
" Sex | \n",
" Age | \n",
" SibSp | \n",
" Parch | \n",
" Fare | \n",
" Embarked | \n",
" Title | \n",
"
\n",
" \n",
" PassengerId | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" 1 | \n",
" 0 | \n",
" Third class | \n",
" male | \n",
" 22.0 | \n",
" 1 | \n",
" 0 | \n",
" 7.25 | \n",
" Southampton | \n",
" Mr | \n",
"
\n",
" \n",
" 2 | \n",
" 1 | \n",
" First class | \n",
" female | \n",
" 38.0 | \n",
" 1 | \n",
" 0 | \n",
" 71.28 | \n",
" Cherbourg | \n",
" Mrs | \n",
"
\n",
" \n",
" 3 | \n",
" 1 | \n",
" Third class | \n",
" female | \n",
" 26.0 | \n",
" 0 | \n",
" 0 | \n",
" 7.92 | \n",
" Southampton | \n",
" Miss | \n",
"
\n",
" \n",
" 4 | \n",
" 1 | \n",
" First class | \n",
" female | \n",
" 35.0 | \n",
" 1 | \n",
" 0 | \n",
" 53.10 | \n",
" Southampton | \n",
" Mrs | \n",
"
\n",
" \n",
" 5 | \n",
" 0 | \n",
" Third class | \n",
" male | \n",
" 35.0 | \n",
" 0 | \n",
" 0 | \n",
" 8.05 | \n",
" Southampton | \n",
" Mr | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Survived Pclass Sex Age SibSp Parch Fare \\\n",
"PassengerId \n",
"1 0 Third class male 22.0 1 0 7.25 \n",
"2 1 First class female 38.0 1 0 71.28 \n",
"3 1 Third class female 26.0 0 0 7.92 \n",
"4 1 First class female 35.0 1 0 53.10 \n",
"5 0 Third class male 35.0 0 0 8.05 \n",
"\n",
" Embarked Title \n",
"PassengerId \n",
"1 Southampton Mr \n",
"2 Cherbourg Mrs \n",
"3 Southampton Miss \n",
"4 Southampton Mrs \n",
"5 Southampton Mr "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"titan_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare data for the model with Category Encoder\n",
"\n",
"Create Target"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"y = titan_df['Survived']\n",
"X = titan_df.drop('Survived', axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train category encoder"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"#Train category encoder\n",
"onehot = OneHotEncoder(cols=['Pclass']).fit(X)\n",
"result_1 = onehot.transform(X)\n",
"ordinal = OrdinalEncoder(cols=['Embarked','Title']).fit(result_1)\n",
"result_2 = ordinal.transform(result_1)\n",
"target = TargetEncoder(cols=['Sex']).fit(result_2,y)\n",
"result_3 =target.transform(result_2)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"encoder = [onehot,ordinal,target]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fit a model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"XGBClassifier(base_score=0.5, booster=None, colsample_bylevel=1,\n",
" colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,\n",
" importance_type='gain', interaction_constraints=None,\n",
" learning_rate=0.300000012, max_delta_step=0, max_depth=6,\n",
" min_child_weight=2, missing=nan, monotone_constraints=None,\n",
" n_estimators=200, n_jobs=0, num_parallel_tree=1,\n",
" objective='binary:logistic', random_state=0, reg_alpha=0,\n",
" reg_lambda=1, scale_pos_weight=1, subsample=1, tree_method=None,\n",
" validate_parameters=False, verbosity=None)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Xtrain, Xtest, ytrain, ytest = train_test_split(result_3, y, train_size=0.75, random_state=1)\n",
"\n",
"clf = XGBClassifier(n_estimators=200,min_child_weight=2).fit(Xtrain,ytrain)\n",
"clf.fit(Xtrain, ytrain)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using Shapash"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from shapash import SmartExplainer"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"xpl = SmartExplainer(\n",
" model=clf,\n",
" preprocessing=encoder,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Backend: Shap TreeExplainer\n"
]
}
],
"source": [
"xpl.compile(x=Xtest,\n",
"y_target=ytest, # Optional: allows to display True Values vs Predicted Values\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualize data in pandas"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Pclass | \n",
" Sex | \n",
" Age | \n",
" SibSp | \n",
" Parch | \n",
" Fare | \n",
" Embarked | \n",
" Title | \n",
"
\n",
" \n",
" PassengerId | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" 863 | \n",
" First class | \n",
" female | \n",
" 48.0 | \n",
" 0 | \n",
" 0 | \n",
" 25.93 | \n",
" Southampton | \n",
" Mrs | \n",
"
\n",
" \n",
" 224 | \n",
" Third class | \n",
" male | \n",
" 29.5 | \n",
" 0 | \n",
" 0 | \n",
" 7.90 | \n",
" Southampton | \n",
" Mr | \n",
"
\n",
" \n",
" 85 | \n",
" Second class | \n",
" female | \n",
" 17.0 | \n",
" 0 | \n",
" 0 | \n",
" 10.50 | \n",
" Southampton | \n",
" Miss | \n",
"
\n",
" \n",
" 681 | \n",
" Third class | \n",
" female | \n",
" 29.5 | \n",
" 0 | \n",
" 0 | \n",
" 8.14 | \n",
" Queenstown | \n",
" Miss | \n",
"
\n",
" \n",
" 536 | \n",
" Second class | \n",
" female | \n",
" 7.0 | \n",
" 0 | \n",
" 2 | \n",
" 26.25 | \n",
" Southampton | \n",
" Miss | \n",
"
\n",
" \n",
" 624 | \n",
" Third class | \n",
" male | \n",
" 21.0 | \n",
" 0 | \n",
" 0 | \n",
" 7.85 | \n",
" Southampton | \n",
" Mr | \n",
"
\n",
" \n",
" 149 | \n",
" Second class | \n",
" male | \n",
" 36.5 | \n",
" 0 | \n",
" 2 | \n",
" 26.00 | \n",
" Southampton | \n",
" Mr | \n",
"
\n",
" \n",
" 4 | \n",
" First class | \n",
" female | \n",
" 35.0 | \n",
" 1 | \n",
" 0 | \n",
" 53.10 | \n",
" Southampton | \n",
" Mrs | \n",
"
\n",
" \n",
" 35 | \n",
" First class | \n",
" male | \n",
" 28.0 | \n",
" 1 | \n",
" 0 | \n",
" 82.17 | \n",
" Cherbourg | \n",
" Mr | \n",
"
\n",
" \n",
" 242 | \n",
" Third class | \n",
" female | \n",
" 29.5 | \n",
" 1 | \n",
" 0 | \n",
" 15.50 | \n",
" Queenstown | \n",
" Miss | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Pclass Sex Age SibSp Parch Fare Embarked \\\n",
"PassengerId \n",
"863 First class female 48.0 0 0 25.93 Southampton \n",
"224 Third class male 29.5 0 0 7.90 Southampton \n",
"85 Second class female 17.0 0 0 10.50 Southampton \n",
"681 Third class female 29.5 0 0 8.14 Queenstown \n",
"536 Second class female 7.0 0 2 26.25 Southampton \n",
"624 Third class male 21.0 0 0 7.85 Southampton \n",
"149 Second class male 36.5 0 2 26.00 Southampton \n",
"4 First class female 35.0 1 0 53.10 Southampton \n",
"35 First class male 28.0 1 0 82.17 Cherbourg \n",
"242 Third class female 29.5 1 0 15.50 Queenstown \n",
"\n",
" Title \n",
"PassengerId \n",
"863 Mrs \n",
"224 Mr \n",
"85 Miss \n",
"681 Miss \n",
"536 Miss \n",
"624 Mr \n",
"149 Mr \n",
"4 Mrs \n",
"35 Mr \n",
"242 Miss "
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xpl.x_init"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Pclass_1 | \n",
" Pclass_2 | \n",
" Pclass_3 | \n",
" Sex | \n",
" Age | \n",
" SibSp | \n",
" Parch | \n",
" Fare | \n",
" Embarked | \n",
" Title | \n",
"
\n",
" \n",
" PassengerId | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" 863 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0.742038 | \n",
" 48.0 | \n",
" 0 | \n",
" 0 | \n",
" 25.93 | \n",
" 1 | \n",
" 2 | \n",
"
\n",
" \n",
" 224 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0.188908 | \n",
" 29.5 | \n",
" 0 | \n",
" 0 | \n",
" 7.90 | \n",
" 1 | \n",
" 1 | \n",
"
\n",
" \n",
" 85 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0.742038 | \n",
" 17.0 | \n",
" 0 | \n",
" 0 | \n",
" 10.50 | \n",
" 1 | \n",
" 3 | \n",
"
\n",
" \n",
" 681 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0.742038 | \n",
" 29.5 | \n",
" 0 | \n",
" 0 | \n",
" 8.14 | \n",
" 3 | \n",
" 3 | \n",
"
\n",
" \n",
" 536 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0.742038 | \n",
" 7.0 | \n",
" 0 | \n",
" 2 | \n",
" 26.25 | \n",
" 1 | \n",
" 3 | \n",
"
\n",
" \n",
" 624 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0.188908 | \n",
" 21.0 | \n",
" 0 | \n",
" 0 | \n",
" 7.85 | \n",
" 1 | \n",
" 1 | \n",
"
\n",
" \n",
" 149 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0.188908 | \n",
" 36.5 | \n",
" 0 | \n",
" 2 | \n",
" 26.00 | \n",
" 1 | \n",
" 1 | \n",
"
\n",
" \n",
" 4 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0.742038 | \n",
" 35.0 | \n",
" 1 | \n",
" 0 | \n",
" 53.10 | \n",
" 1 | \n",
" 2 | \n",
"
\n",
" \n",
" 35 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0.188908 | \n",
" 28.0 | \n",
" 1 | \n",
" 0 | \n",
" 82.17 | \n",
" 2 | \n",
" 1 | \n",
"
\n",
" \n",
" 242 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0.742038 | \n",
" 29.5 | \n",
" 1 | \n",
" 0 | \n",
" 15.50 | \n",
" 3 | \n",
" 3 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Pclass_1 Pclass_2 Pclass_3 Sex Age SibSp Parch \\\n",
"PassengerId \n",
"863 0 1 0 0.742038 48.0 0 0 \n",
"224 1 0 0 0.188908 29.5 0 0 \n",
"85 0 0 1 0.742038 17.0 0 0 \n",
"681 1 0 0 0.742038 29.5 0 0 \n",
"536 0 0 1 0.742038 7.0 0 2 \n",
"624 1 0 0 0.188908 21.0 0 0 \n",
"149 0 0 1 0.188908 36.5 0 2 \n",
"4 0 1 0 0.742038 35.0 1 0 \n",
"35 0 1 0 0.188908 28.0 1 0 \n",
"242 1 0 0 0.742038 29.5 1 0 \n",
"\n",
" Fare Embarked Title \n",
"PassengerId \n",
"863 25.93 1 2 \n",
"224 7.90 1 1 \n",
"85 10.50 1 3 \n",
"681 8.14 3 3 \n",
"536 26.25 1 3 \n",
"624 7.85 1 1 \n",
"149 26.00 1 1 \n",
"4 53.10 1 2 \n",
"35 82.17 2 1 \n",
"242 15.50 3 3 "
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xpl.x_encoded"
]
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}