# Source code for shapash.data.data_loader
"""
Data loader module
"""
import json
import os
from pathlib import Path
from urllib.error import URLError
from urllib.request import urlopen
import pandas as pd
def _find_file(data_path, github_data_url, filename):
"""
Finds file path on disk if it exists or gets file path on github.
Parameters
----------
data_path : str
Data folder path
github_data_url : str
Github data url
filename : str
Name of the file
Returns
-------
str
Founded file path.
"""
file = os.path.join(data_path, filename)
if os.path.isfile(file) is False:
file = github_data_url + filename
try:
urlopen(file)
except URLError:
raise Exception(f"Internet connection is required to download: {file}")
return file
[docs]def data_loading(dataset):
"""
data_loading allows shapash user to try the library with small but clear datasets.
Titanic, house_prices or telco_customer_churn data.
Example
----------
>>> from shapash.data.data_loader import data_loading
>>> house_df, house_dict = data_loading('house_prices')
Parameters
----------
dataset : String
Dataset's name to return.
- 'titanic'
- 'house_prices'
- 'telco_customer_churn'
Returns
-------
data : pandas.DataFrame
Dataset required
dict : (Dictionnary, Optional)
If exist, columns labels dictionnary associated to the dataset.
"""
data_path = str(Path(__file__).parents[2] / "data")
if dataset == "house_prices":
github_data_url = "https://github.com/MAIF/shapash/raw/master/data/"
data_house_prices_path = _find_file(data_path, github_data_url, "house_prices_dataset.csv")
dict_house_prices_path = _find_file(data_path, github_data_url, "house_prices_labels.json")
data = pd.read_csv(data_house_prices_path, header=0, index_col=0, engine="python")
if github_data_url in dict_house_prices_path:
with urlopen(dict_house_prices_path) as openfile:
dic = json.load(openfile)
else:
with open(dict_house_prices_path) as openfile:
dic = json.load(openfile)
return data, dic
elif dataset == "titanic":
github_data_url = "https://github.com/MAIF/shapash/raw/master/data/"
data_titanic_path = _find_file(data_path, github_data_url, "titanicdata.csv")
dict_titanic_path = _find_file(data_path, github_data_url, "titaniclabels.json")
data = pd.read_csv(data_titanic_path, header=0, index_col=0, engine="python")
if github_data_url in dict_titanic_path:
with urlopen(dict_titanic_path) as openfile:
dic = json.load(openfile)
else:
with open(dict_titanic_path) as openfile:
dic = json.load(openfile)
return data, dic
elif dataset == "telco_customer_churn":
github_data_url = "https://github.com/IBM/telco-customer-churn-on-icp4d/raw/master/data/"
data_telco_path = _find_file(data_path, github_data_url, "Telco-Customer-Churn.csv")
data = pd.read_csv(data_telco_path, header=0, index_col=0, engine="python")
return data
elif dataset == "us_car_accident":
github_data_url = "https://github.com/MAIF/shapash/raw/master/data/"
data_accidents_path = _find_file(data_path, github_data_url, "US_Accidents_extract.csv")
data = pd.read_csv(data_accidents_path, header=0, engine="python")
return data
else:
raise ValueError("Dataset not found. Check the docstring for available values")