{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# category_encoders tutorial\n", "\n", "This tutorial shows how to pass fitted category_encoders encoders to Shapash so that it can reverse the data preprocessing and display explicit labels.\n", "\n", "We use Kaggle's [Titanic](https://www.kaggle.com/c/titanic) dataset.\n", "\n", "This tutorial:\n", "- Encodes data with category_encoders\n", "- Builds a binary classifier (XGBoost)\n", "- Uses Shapash\n", "- Shows the inverse-transformed data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "from category_encoders import OrdinalEncoder\n", "from category_encoders import OneHotEncoder\n", "from category_encoders import TargetEncoder\n", "from xgboost import XGBClassifier\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Titanic data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from shapash.data.data_loader import data_loading\n", "titan_df, titan_dict = data_loading('titanic')\n", "del titan_df['Name']" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SurvivedPclassSexAgeSibSpParchFareEmbarkedTitle
PassengerId
10Third classmale22.0107.25SouthamptonMr
21First classfemale38.01071.28CherbourgMrs
31Third classfemale26.0007.92SouthamptonMiss
41First classfemale35.01053.10SouthamptonMrs
50Third classmale35.0008.05SouthamptonMr
\n", "
" ], "text/plain": [ " Survived Pclass Sex Age SibSp Parch Fare \\\n", "PassengerId \n", "1 0 Third class male 22.0 1 0 7.25 \n", "2 1 First class female 38.0 1 0 71.28 \n", "3 1 Third class female 26.0 0 0 7.92 \n", "4 1 First class female 35.0 1 0 53.10 \n", "5 0 Third class male 35.0 0 0 8.05 \n", "\n", " Embarked Title \n", "PassengerId \n", "1 Southampton Mr \n", "2 Cherbourg Mrs \n", "3 Southampton Miss \n", "4 Southampton Mrs \n", "5 Southampton Mr " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "titan_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare data for the model with category_encoders\n", "\n", "Create the target and the feature matrix" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "y = titan_df['Survived']\n", "X = titan_df.drop('Survived', axis=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Split the data **before** training the category encoders: `TargetEncoder` learns from the labels, so fitting it on the whole dataset would leak test-set targets into the features the model is evaluated on." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "Xtrain_raw, Xtest_raw, ytrain, ytest = train_test_split(X, y, train_size=0.75, random_state=1)\n", "\n", "# Train category encoders on the training split only, each one on the previous encoder's output\n", "onehot = OneHotEncoder(cols=['Pclass']).fit(Xtrain_raw)\n", "ordinal = OrdinalEncoder(cols=['Embarked','Title']).fit(onehot.transform(Xtrain_raw))\n", "target = TargetEncoder(cols=['Sex']).fit(ordinal.transform(onehot.transform(Xtrain_raw)), ytrain)\n", "\n", "# Apply the fitted encoders, in the same order, to both splits\n", "Xtrain = target.transform(ordinal.transform(onehot.transform(Xtrain_raw)))\n", "Xtest = target.transform(ordinal.transform(onehot.transform(Xtest_raw)))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Listed in the order the encoders are applied; passed to SmartExplainer as `preprocessing`\n", "encoder = [onehot,ordinal,target]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fit a model" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "XGBClassifier(base_score=0.5, booster=None, colsample_bylevel=1,\n", " colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,\n", " importance_type='gain', interaction_constraints=None,\n", " learning_rate=0.300000012, max_delta_step=0, max_depth=6,\n", " min_child_weight=2, missing=nan, 
monotone_constraints=None,\n", " n_estimators=200, n_jobs=0, num_parallel_tree=1,\n", " objective='binary:logistic', random_state=0, reg_alpha=0,\n", " reg_lambda=1, scale_pos_weight=1, subsample=1, tree_method=None,\n", " validate_parameters=False, verbosity=None)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fit once: clf.fit returns the model, so this last expression displays the fitted estimator\n", "clf = XGBClassifier(n_estimators=200,min_child_weight=2)\n", "clf.fit(Xtrain, ytrain)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using Shapash" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from shapash import SmartExplainer" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "xpl = SmartExplainer(\n", " model=clf,\n", " preprocessing=encoder,\n", ")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Backend: Shap TreeExplainer\n" ] } ], "source": [ "xpl.compile(x=Xtest,\n", "y_target=ytest, # Optional: allows to display True Values vs Predicted Values\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize data in pandas" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PclassSexAgeSibSpParchFareEmbarkedTitle
PassengerId
863First classfemale48.00025.93SouthamptonMrs
224Third classmale29.5007.90SouthamptonMr
85Second classfemale17.00010.50SouthamptonMiss
681Third classfemale29.5008.14QueenstownMiss
536Second classfemale7.00226.25SouthamptonMiss
624Third classmale21.0007.85SouthamptonMr
149Second classmale36.50226.00SouthamptonMr
4First classfemale35.01053.10SouthamptonMrs
35First classmale28.01082.17CherbourgMr
242Third classfemale29.51015.50QueenstownMiss
\n", "
" ], "text/plain": [ " Pclass Sex Age SibSp Parch Fare Embarked \\\n", "PassengerId \n", "863 First class female 48.0 0 0 25.93 Southampton \n", "224 Third class male 29.5 0 0 7.90 Southampton \n", "85 Second class female 17.0 0 0 10.50 Southampton \n", "681 Third class female 29.5 0 0 8.14 Queenstown \n", "536 Second class female 7.0 0 2 26.25 Southampton \n", "624 Third class male 21.0 0 0 7.85 Southampton \n", "149 Second class male 36.5 0 2 26.00 Southampton \n", "4 First class female 35.0 1 0 53.10 Southampton \n", "35 First class male 28.0 1 0 82.17 Cherbourg \n", "242 Third class female 29.5 1 0 15.50 Queenstown \n", "\n", " Title \n", "PassengerId \n", "863 Mrs \n", "224 Mr \n", "85 Miss \n", "681 Miss \n", "536 Miss \n", "624 Mr \n", "149 Mr \n", "4 Mrs \n", "35 Mr \n", "242 Miss " ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xpl.x_init" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Pclass_1Pclass_2Pclass_3SexAgeSibSpParchFareEmbarkedTitle
PassengerId
8630100.74203848.00025.9312
2241000.18890829.5007.9011
850010.74203817.00010.5013
6811000.74203829.5008.1433
5360010.7420387.00226.2513
6241000.18890821.0007.8511
1490010.18890836.50226.0011
40100.74203835.01053.1012
350100.18890828.01082.1721
2421000.74203829.51015.5033
\n", "
" ], "text/plain": [ " Pclass_1 Pclass_2 Pclass_3 Sex Age SibSp Parch \\\n", "PassengerId \n", "863 0 1 0 0.742038 48.0 0 0 \n", "224 1 0 0 0.188908 29.5 0 0 \n", "85 0 0 1 0.742038 17.0 0 0 \n", "681 1 0 0 0.742038 29.5 0 0 \n", "536 0 0 1 0.742038 7.0 0 2 \n", "624 1 0 0 0.188908 21.0 0 0 \n", "149 0 0 1 0.188908 36.5 0 2 \n", "4 0 1 0 0.742038 35.0 1 0 \n", "35 0 1 0 0.188908 28.0 1 0 \n", "242 1 0 0 0.742038 29.5 1 0 \n", "\n", " Fare Embarked Title \n", "PassengerId \n", "863 25.93 1 2 \n", "224 7.90 1 1 \n", "85 10.50 1 3 \n", "681 8.14 3 3 \n", "536 26.25 1 3 \n", "624 7.85 1 1 \n", "149 26.00 1 1 \n", "4 53.10 1 2 \n", "35 82.17 2 1 \n", "242 15.50 3 3 " ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xpl.x_encoded" ] } ], "metadata": { "hide_input": false, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }