Source code for shapash.data.data_loader
"""
Data loader module
"""
import json
import os
from pathlib import Path
from urllib.error import URLError
from urllib.parse import urlparse
from urllib.request import urlopen
import pandas as pd
def _safe_urlopen(url: str, **kwargs):
"""Open an HTTPS URL safely with forwarded request options.
Parameters
----------
url : str
URL to open. Only HTTPS schemes are allowed.
**kwargs
Keyword arguments forwarded to :func:`urllib.request.urlopen`.
Returns
-------
file-like
The response object returned by :func:`urllib.request.urlopen`.
Raises
------
ValueError
If the URL scheme is not HTTPS.
"""
if urlparse(url).scheme != "https":
raise ValueError(f"Only HTTPS URLs are permitted: {url}")
return urlopen(url, **kwargs) # noqa: S310
def _find_file(data_path, github_data_url, filename):
"""
Finds file path on disk if it exists or gets file path on github.
Parameters
----------
data_path : str
Data folder path
github_data_url : str
Github data url
filename : str
Name of the file
Returns
-------
str
Founded file path.
"""
file = os.path.join(data_path, filename)
if os.path.isfile(file) is False:
file = github_data_url + filename
try:
with _safe_urlopen(file, timeout=10):
pass
except URLError as exc:
raise ConnectionError(f"Internet connection is required to download: {file}") from exc
return file
[docs]def data_loading(dataset):
"""
data_loading allows shapash user to try the library with small but clear datasets.
Titanic, house_prices or telco_customer_churn data.
Example
----------
>>> from shapash.data.data_loader import data_loading
>>> house_df, house_dict = data_loading('house_prices')
Parameters
----------
dataset : String
Dataset's name to return.
- 'titanic'
- 'house_prices'
- 'telco_customer_churn'
Returns
-------
data : pandas.DataFrame
Dataset required
dict : (Dictionnary, Optional)
If exist, columns labels dictionnary associated to the dataset.
"""
data_path = str(Path(__file__).parents[2] / "data")
if dataset == "house_prices":
github_data_url = "https://github.com/MAIF/shapash/raw/master/data/"
data_house_prices_path = _find_file(data_path, github_data_url, "house_prices_dataset.csv")
dict_house_prices_path = _find_file(data_path, github_data_url, "house_prices_labels.json")
data = pd.read_csv(data_house_prices_path, header=0, index_col=0, engine="python")
if github_data_url in dict_house_prices_path:
with _safe_urlopen(dict_house_prices_path) as openfile:
dic = json.load(openfile)
else:
with open(dict_house_prices_path) as openfile:
dic = json.load(openfile)
return data, dic
elif dataset == "titanic":
github_data_url = "https://github.com/MAIF/shapash/raw/master/data/"
data_titanic_path = _find_file(data_path, github_data_url, "titanicdata.csv")
dict_titanic_path = _find_file(data_path, github_data_url, "titaniclabels.json")
data = pd.read_csv(data_titanic_path, header=0, index_col=0, engine="python")
if github_data_url in dict_titanic_path:
with _safe_urlopen(dict_titanic_path) as openfile:
dic = json.load(openfile)
else:
with open(dict_titanic_path) as openfile:
dic = json.load(openfile)
return data, dic
elif dataset == "telco_customer_churn":
github_data_url = "https://github.com/IBM/telco-customer-churn-on-icp4d/raw/master/data/"
data_telco_path = _find_file(data_path, github_data_url, "Telco-Customer-Churn.csv")
data = pd.read_csv(data_telco_path, header=0, index_col=0, engine="python")
return data
elif dataset == "us_car_accident":
github_data_url = "https://github.com/MAIF/shapash/raw/master/data/"
data_accidents_path = _find_file(data_path, github_data_url, "US_Accidents_extract.csv")
data = pd.read_csv(data_accidents_path, header=0, engine="python")
return data
else:
raise ValueError("Dataset not found. Check the docstring for available values")