ColumnTransformer tutorial
This tutorial shows how to use a scikit-learn ColumnTransformer with Shapash so that the data preprocessing can be reversed and explicit labels displayed.

We use Kaggle's Titanic dataset.

Contents:
- Encode data with ColumnTransformer
- Build a binary classifier (XGBoost)
- Use Shapash
- Show the inverse-transformed data

We implement an inverse transform function for ColumnTransformer based on column position. Because of this, the features output by the ColumnTransformer must not be reordered or subsampled before they are passed to the model.
```python
import numpy as np
import pandas as pd
from xgboost import XGBClassifier
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OrdinalEncoder, OneHotEncoder
from sklearn.model_selection import train_test_split
```
Load Titanic data
```python
from shapash.data.data_loader import data_loading

titan_df, titan_dict = data_loading('titanic')
# Drop the free-text Name column, which is not used as a feature
del titan_df['Name']
```
```python
titan_df.head()
```

| PassengerId | Survived | Pclass | Sex | Age | SibSp | Parch | Fare | Embarked | Title |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 0 | Third class | male | 22.0 | 1 | 0 | 7.25 | Southampton | Mr |
| 2 | 1 | First class | female | 38.0 | 1 | 0 | 71.28 | Cherbourg | Mrs |
| 3 | 1 | Third class | female | 26.0 | 0 | 0 | 7.92 | Southampton | Miss |
| 4 | 1 | First class | female | 35.0 | 1 | 0 | 53.10 | Southampton | Mrs |
| 5 | 0 | Third class | male | 35.0 | 0 | 0 | 8.05 | Southampton | Mr |
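Prepare data for the model

Create the target y and the feature matrix X:

```python
# y gets a fresh 0..n-1 index so that it lines up with the transformed
# features built below, which lose the PassengerId index
y = titan_df.reset_index(drop=True)['Survived']
X = titan_df.drop('Survived', axis=1)
```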
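pandas aligns data on the index, so the transformed features and the target should share the same index. Wrapping the NumPy array returned by `fit_transform` in a DataFrame gives a new 0..n-1 index, while the raw target is still indexed by PassengerId starting at 1. The toy frame below is made up for illustration and mimics that situation; `np.asarray(X)` stands in for the array that `fit_transform` returns.

```python
import numpy as np
import pandas as pd

# Made-up frame indexed by PassengerId, starting at 1 like the Titanic data
df = pd.DataFrame(
    {"Survived": [0, 1, 1], "Fare": [7.25, 71.28, 7.92]},
    index=pd.Index([1, 2, 3], name="PassengerId"),
)
X = df.drop("Survived", axis=1)

# Wrapping a NumPy array (stand-in for fit_transform output) gives index 0, 1, 2
X_transform = pd.DataFrame(np.asarray(X))

y_raw = df["Survived"]                      # index 1, 2, 3
y = df.reset_index(drop=True)["Survived"]   # index 0, 1, 2

print(X_transform.index.equals(y_raw.index))  # False
print(X_transform.index.equals(y.index))      # True
```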
```python
titan_df.reset_index(drop=True)['Survived']
```

```
0 0
1 1
2 1
3 1
4 0
..
886 0
887 1
888 0
889 1
890 0
Name: Survived, Length: 891, dtype: int64
```
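Fit a ColumnTransformer with multiple transformers

```python
enc_columntransfo = ColumnTransformer(
    transformers=[
        # One-hot encode passenger class and sex
        ('onehot', OneHotEncoder(), ['Pclass', 'Sex']),
        # Map each port of embarkation and each title to an integer code
        ('ordinal', OrdinalEncoder(), ['Embarked', 'Title'])
    ],
    # Pass the remaining columns (Age, SibSp, Parch, Fare) through unchanged
    remainder='passthrough')
X_transform = pd.DataFrame(enc_columntransfo.fit_transform(X, y))
```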
Reassign the original column names to the remainder (passthrough) columns.

```python
# The remainder is the last entry of transformers_; its third element holds
# the passthrough columns, given here as integer positions in X
idx_col = enc_columntransfo.transformers_[-1][2]
# Passthrough columns come last in the output, so rename the last N columns
start = len(X_transform.columns) - len(idx_col)
X_transform.columns = X_transform.columns.tolist()[:start] + X.columns[idx_col].tolist()
X_transform.head(2)
```
|  | 0 | 1 | 2 | 3 | 4 | 5 | 6 | Age | SibSp | Parch | Fare |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 2.0 | 11.0 | 22.0 | 1.0 | 0.0 | 7.25 |
| 1 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 12.0 | 38.0 | 1.0 | 0.0 | 71.28 |
Fit a model
```python
Xtrain, Xtest, ytrain, ytest = train_test_split(X_transform, y, train_size=0.75, random_state=1)
# fit() returns the fitted estimator, so a single call trains the model and assigns clf
clf = XGBClassifier(n_estimators=200, min_child_weight=2).fit(Xtrain, ytrain)
clf
```

```
XGBClassifier(base_score=0.5, booster=None, colsample_bylevel=1,
              colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
              importance_type='gain', interaction_constraints=None,
              learning_rate=0.300000012, max_delta_step=0, max_depth=6,
              min_child_weight=2, missing=nan, monotone_constraints=None,
              n_estimators=200, n_jobs=0, num_parallel_tree=1, random_state=0,
              reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,
              tree_method=None, validate_parameters=False, verbosity=None)
```
Using Shapash
```python
from shapash import SmartExplainer

xpl = SmartExplainer(model=clf, preprocessing=enc_columntransfo)
xpl.compile(x=Xtest,
            y_target=ytest,  # Optional: allows displaying true values vs. predicted values
            )
```
Visualize data in pandas
```python
# A ColumnTransformer can apply several transformers to the same column, so
# each decoded column is prefixed with its transformer name:
# Pclass becomes onehot_Pclass, Embarked becomes ordinal_Embarked, and so on
xpl.x_init.head(4)
```
|  | onehot_Pclass | onehot_Sex | ordinal_Embarked | ordinal_Title | Age | SibSp | Parch | Fare |
|---|---|---|---|---|---|---|---|---|
| 862 | First class | female | Southampton | Mrs | 48.0 | 0.0 | 0.0 | 25.93 |
| 223 | Third class | male | Southampton | Mr | 29.5 | 0.0 | 0.0 | 7.90 |
| 84 | Second class | female | Southampton | Miss | 17.0 | 0.0 | 0.0 | 10.50 |
| 680 | Third class | female | Queenstown | Miss | 29.5 | 0.0 | 0.0 | 8.14 |
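The decoded columns follow the pattern `<transformer name>_<original column>`, while passthrough columns such as Age keep their original names. If plain column names are preferred for display, a small mapping can undo the prefix. This is a minimal sketch: `prefixed_to_original` is a hypothetical helper, and the DataFrame below is two rows copied from the `x_init` output rather than the real `xpl.x_init` object.

```python
import pandas as pd

# Transformer names and columns as used by enc_columntransfo in this tutorial
transformers = [
    ("onehot", ["Pclass", "Sex"]),
    ("ordinal", ["Embarked", "Title"]),
]


def prefixed_to_original(transformers):
    """Hypothetical helper: map '<transformer>_<column>' back to '<column>'."""
    return {f"{name}_{col}": col for name, cols in transformers for col in cols}


# Two rows copied from the x_init output above
x_init = pd.DataFrame({
    "onehot_Pclass": ["First class", "Third class"],
    "onehot_Sex": ["female", "male"],
    "ordinal_Embarked": ["Southampton", "Southampton"],
    "ordinal_Title": ["Mrs", "Mr"],
    "Age": [48.0, 29.5],
})

renamed = x_init.rename(columns=prefixed_to_original(transformers))
print(renamed.columns.tolist())  # ['Pclass', 'Sex', 'Embarked', 'Title', 'Age']
```

With the objects from this notebook, the `(name, columns)` pairs could be read from `enc_columntransfo.transformers` instead of being typed by hand. Renaming a copy for display leaves the explainer itself untouched.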
```python
xpl.x_encoded.head(4)
```

|  | 0 | 1 | 2 | 3 | 4 | 5 | 6 | Age | SibSp | Parch | Fare |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 862 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 2.0 | 12.0 | 48.0 | 0.0 | 0.0 | 25.93 |
| 223 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 2.0 | 11.0 | 29.5 | 0.0 | 0.0 | 7.90 |
| 84 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 2.0 | 8.0 | 17.0 | 0.0 | 0.0 | 10.50 |
| 680 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | 1.0 | 8.0 | 29.5 | 0.0 | 0.0 | 8.14 |