{ "cells": [ { "cell_type": "markdown", "id": "bc074d5d", "metadata": {}, "source": [ "# Shapash + Keras in Jupyter: Titanic Survival Classification\n", "\n", "This tutorial shows how to:\n", "- train a tabular deep learning model (Keras)\n", "- predict Titanic passenger survival\n", "- explain predictions in a notebook with Shapash" ] }, { "cell_type": "markdown", "id": "3d7d57dd", "metadata": {}, "source": [ "## 1. Imports" ] }, { "cell_type": "code", "execution_count": null, "id": "31d554ff", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "from IPython.display import clear_output, display\n", "\n", "from category_encoders import one_hot\n", "from sklearn.metrics import accuracy_score, classification_report\n", "from sklearn.model_selection import train_test_split\n", "\n", "import keras\n", "from keras import layers\n", "\n", "from shapash import SmartExplainer\n", "from shapash.data.data_loader import data_loading" ] }, { "cell_type": "markdown", "id": "10183481", "metadata": {}, "source": [ "## 2. Loading the Titanic data" ] }, { "cell_type": "code", "execution_count": 2, "id": "e19f182e", "metadata": {}, "outputs": [ { "data": { "application/vnd.microsoft.datawrangler.viewer.v0+json": { "columns": [ { "name": "PassengerId", "rawType": "int64", "type": "integer" }, { "name": "Survived", "rawType": "int64", "type": "integer" }, { "name": "Pclass", "rawType": "object", "type": "string" }, { "name": "Name", "rawType": "object", "type": "string" }, { "name": "Sex", "rawType": "object", "type": "string" }, { "name": "Age", "rawType": "float64", "type": "float" }, { "name": "SibSp", "rawType": "int64", "type": "integer" }, { "name": "Parch", "rawType": "int64", "type": "integer" }, { "name": "Fare", "rawType": "float64", "type": "float" }, { "name": "Embarked", "rawType": "object", "type": "string" }, { "name": "Title", "rawType": "object", "type": "string" } ], "ref": "ca3ebc4f-a0c9-4d54-a2b7-ba0b38f7bdb1", "rows": [ [ "1", "0", "Third class", "Braund Owen Harris", "male", "22.0", "1", "0", "7.25", "Southampton", "Mr" ], [ "2", "1", "First class", "Cumings John Bradley (Florence Briggs Thayer)", "female", "38.0", "1", "0", "71.28", "Cherbourg", "Mrs" ], [ "3", "1", "Third class", "Heikkinen Laina", "female", "26.0", "0", "0", "7.92", "Southampton", "Miss" ], [ "4", "1", "First class", "Futrelle Jacques Heath (Lily May Peel)", "female", "35.0", "1", "0", "53.1", "Southampton", "Mrs" ], [ "5", "0", "Third class", "Allen William Henry", "male", "35.0", "0", "0", "8.05", "Southampton", "Mr" ] ], "shape": { "columns": 10, "rows": 5 } }, "text/html": [ "
| \n", " | Survived | \n", "Pclass | \n", "Name | \n", "Sex | \n", "Age | \n", "SibSp | \n", "Parch | \n", "Fare | \n", "Embarked | \n", "Title | \n", "
|---|---|---|---|---|---|---|---|---|---|---|
| PassengerId | \n", "\n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " |
| 1 | \n", "0 | \n", "Third class | \n", "Braund Owen Harris | \n", "male | \n", "22.0 | \n", "1 | \n", "0 | \n", "7.25 | \n", "Southampton | \n", "Mr | \n", "
| 2 | \n", "1 | \n", "First class | \n", "Cumings John Bradley (Florence Briggs Thayer) | \n", "female | \n", "38.0 | \n", "1 | \n", "0 | \n", "71.28 | \n", "Cherbourg | \n", "Mrs | \n", "
| 3 | \n", "1 | \n", "Third class | \n", "Heikkinen Laina | \n", "female | \n", "26.0 | \n", "0 | \n", "0 | \n", "7.92 | \n", "Southampton | \n", "Miss | \n", "
| 4 | \n", "1 | \n", "First class | \n", "Futrelle Jacques Heath (Lily May Peel) | \n", "female | \n", "35.0 | \n", "1 | \n", "0 | \n", "53.10 | \n", "Southampton | \n", "Mrs | \n", "
| 5 | \n", "0 | \n", "Third class | \n", "Allen William Henry | \n", "male | \n", "35.0 | \n", "0 | \n", "0 | \n", "8.05 | \n", "Southampton | \n", "Mr | \n", "
| \n", " | Survived | \n", "feature_1 | \n", "value_1 | \n", "contribution_1 | \n", "feature_2 | \n", "value_2 | \n", "contribution_2 | \n", "feature_3 | \n", "value_3 | \n", "contribution_3 | \n", "... | \n", "contribution_4 | \n", "feature_5 | \n", "value_5 | \n", "contribution_5 | \n", "feature_6 | \n", "value_6 | \n", "contribution_6 | \n", "feature_7 | \n", "value_7 | \n", "contribution_7 | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 735 | \n", "Deceased | \n", "Sex | \n", "male | \n", "0.127684 | \n", "Passenger fare | \n", "13.0 | \n", "0.050851 | \n", "Relatives such as brother or wife | \n", "0 | \n", "-0.028539 | \n", "... | \n", "0.014882 | \n", "Age | \n", "23.0 | \n", "-0.010472 | \n", "Ticket class | \n", "2 | \n", "-0.007397 | \n", "Relatives like children or parents | \n", "0 | \n", "-0.001051 | \n", "
| 625 | \n", "Deceased | \n", "Sex | \n", "male | \n", "0.116236 | \n", "Relatives such as brother or wife | \n", "0 | \n", "-0.028845 | \n", "Ticket class | \n", "3 | \n", "0.025976 | \n", "... | \n", "0.021563 | \n", "Age | \n", "21.0 | \n", "-0.017173 | \n", "Port of embarkation | \n", "Southampton | \n", "0.014301 | \n", "Relatives like children or parents | \n", "0 | \n", "-0.000764 | \n", "
| 104 | \n", "Deceased | \n", "Sex | \n", "male | \n", "0.125619 | \n", "Passenger fare | \n", "8.65 | \n", "0.088861 | \n", "Relatives such as brother or wife | \n", "0 | \n", "-0.025193 | \n", "... | \n", "0.020886 | \n", "Port of embarkation | \n", "Southampton | \n", "0.013359 | \n", "Age | \n", "33.0 | \n", "0.012455 | \n", "Relatives like children or parents | \n", "0 | \n", "-0.000961 | \n", "
| 388 | \n", "Survived | \n", "Sex | \n", "female | \n", "0.286759 | \n", "Relatives such as brother or wife | \n", "0 | \n", "0.032435 | \n", "Passenger fare | \n", "13.0 | \n", "-0.023769 | \n", "... | \n", "-0.014274 | \n", "Ticket class | \n", "2 | \n", "0.011565 | \n", "Relatives like children or parents | \n", "0 | \n", "0.006016 | \n", "Age | \n", "36.0 | \n", "0.002935 | \n", "
| 342 | \n", "Survived | \n", "Sex | \n", "female | \n", "0.206648 | \n", "Relatives such as brother or wife | \n", "3 | \n", "-0.121622 | \n", "Passenger fare | \n", "263.0 | \n", "0.106484 | \n", "... | \n", "0.015445 | \n", "Port of embarkation | \n", "Southampton | \n", "-0.010942 | \n", "Age | \n", "24.0 | \n", "-0.008412 | \n", "Relatives like children or parents | \n", "2 | \n", "0.005304 | \n", "
| 352 | \n", "Deceased | \n", "Sex | \n", "male | \n", "0.115636 | \n", "Passenger fare | \n", "35.0 | \n", "-0.064915 | \n", "Ticket class | \n", "1 | \n", "-0.048789 | \n", "... | \n", "-0.031078 | \n", "Port of embarkation | \n", "Southampton | \n", "0.011676 | \n", "Relatives like children or parents | \n", "0 | \n", "-0.004565 | \n", "Age | \n", "29.5 | \n", "0.002919 | \n", "
| 367 | \n", "Survived | \n", "Sex | \n", "female | \n", "0.244876 | \n", "Passenger fare | \n", "75.25 | \n", "0.170982 | \n", "Ticket class | \n", "1 | \n", "0.05393 | \n", "... | \n", "0.038142 | \n", "Age | \n", "60.0 | \n", "-0.01554 | \n", "Relatives such as brother or wife | \n", "1 | \n", "-0.008888 | \n", "Relatives like children or parents | \n", "0 | \n", "0.004335 | \n", "
| 296 | \n", "Deceased | \n", "Sex | \n", "male | \n", "0.10905 | \n", "Ticket class | \n", "1 | \n", "-0.045848 | \n", "Passenger fare | \n", "27.72 | \n", "-0.035266 | \n", "... | \n", "-0.030661 | \n", "Port of embarkation | \n", "Cherbourg | \n", "-0.030228 | \n", "Relatives like children or parents | \n", "0 | \n", "-0.004339 | \n", "Age | \n", "29.5 | \n", "0.004128 | \n", "
| 428 | \n", "Survived | \n", "Sex | \n", "female | \n", "0.242509 | \n", "Relatives such as brother or wife | \n", "0 | \n", "0.040903 | \n", "Passenger fare | \n", "26.0 | \n", "0.026343 | \n", "... | \n", "0.023752 | \n", "Port of embarkation | \n", "Southampton | \n", "-0.011517 | \n", "Ticket class | \n", "2 | \n", "0.011443 | \n", "Relatives like children or parents | \n", "0 | \n", "0.004754 | \n", "
| 824 | \n", "Survived | \n", "Sex | \n", "female | \n", "0.266229 | \n", "Relatives such as brother or wife | \n", "0 | \n", "0.03253 | \n", "Passenger fare | \n", "12.48 | \n", "-0.031449 | \n", "... | \n", "-0.025467 | \n", "Port of embarkation | \n", "Southampton | \n", "-0.013692 | \n", "Relatives like children or parents | \n", "1 | \n", "-0.008707 | \n", "Age | \n", "27.0 | \n", "0.005659 | \n", "
10 rows × 22 columns
\n", "