Module 2 advanced modules ebook 1
The Modern Data Science Workflow: From Efficient Analysis to High-Performance Modeling
Part 0: Setting the Stage
Chapter 1: Beyond the "Big Four"
1.1 You've Built the Foundation
Congratulations! If you're reading this, it means you've already embarked on the exciting journey of data science and have likely mastered the foundational Python libraries: NumPy, Pandas, Matplotlib, and Seaborn. You understand how to manipulate numerical data with NumPy, wrangle and clean tabular data with Pandas, and create compelling visualizations with Matplotlib and Seaborn.
1.2 The Modern Data Science House
Just as a house needs more than just a foundation to be truly functional and comfortable, a modern data science workflow extends far beyond the basics. While the Big Four are indispensable, the demands of real-world data science projects often require additional capabilities:
- Efficient Plumbing (Automation): Automating repetitive tasks, especially in data exploration and reporting, saves immense time and reduces human error. Imagine generating comprehensive data summaries with just a few lines of code, rather than manually running dozens of commands.
- Smart Windows (Interactivity): Static plots are good, but interactive visualizations allow for deeper exploration, enabling stakeholders to drill down into data, filter, and highlight points of interest. This transforms data presentation from a monologue into a dialogue.
- A Reinforced Structure (Performance): As datasets grow larger, the traditional in-memory processing of Pandas can hit limitations, leading to MemoryError crashes or painfully slow computations. We need strategies and tools to handle data that doesn't fit into RAM and to accelerate operations on large datasets.
This book is about equipping you with these advanced capabilities, transforming your foundational knowledge into a truly modern and high-performance data science workflow. We will explore specialized libraries that address these challenges, allowing you to tackle more complex problems with greater efficiency and impact.
1.3 The Two Tiers of Mastery
To navigate the landscape of advanced data science tools, we can categorize them into two tiers, building upon your existing foundation:
- Tier 2 (The Modern Workflow Enhancers): These are tools that significantly boost your efficiency, interactivity, and the depth of your initial analysis. They streamline common tasks and enhance your ability to communicate insights. They are designed to integrate seamlessly with your existing Pandas and Matplotlib workflows, often providing higher-level abstractions or automated functionalities.
- ydata-profiling: For automated, comprehensive Exploratory Data Analysis (EDA) reports.
- Plotly/Plotly Express: For creating interactive and visually appealing data visualizations.
- SciPy (specifically scipy.stats): For statistical inference and hypothesis testing, moving beyond mere description to rigorous analysis.
- Tier 3 (The Specialized Power Tools): These tools are for when you hit the inherent limits of traditional libraries, particularly concerning data size and computational speed. They introduce new paradigms for handling large datasets and optimizing performance, often leveraging parallel processing or memory-efficient data structures.
- Dask: For handling datasets that are too large to fit into memory, enabling parallel computing with a familiar Pandas-like API.
- Polars: A new, highly performant DataFrame library designed for speed and memory efficiency, often outperforming Pandas on large datasets.
- Statsmodels: For in-depth statistical modeling and understanding the explanatory power of your variables, complementing Scikit-learn's predictive focus.
This book will guide you through each of these tools, explaining their purpose, how to use them, and, crucially, when to choose one over another. We will emphasize practical application, building on your existing knowledge of the Big Four and showing how these new tools integrate into a cohesive data science workflow.
1.4 Introducing Our Core Datasets
To ensure consistency and provide practical, hands-on examples throughout this book, we will primarily work with a few core datasets. These datasets have been chosen to represent common data science challenges and to effectively demonstrate the capabilities of the libraries we will be exploring:
- E-commerce Transaction Dataset: A messy dataset containing customer transactions, product details, and timestamps. This dataset will be used to illustrate data cleaning, feature engineering, and funnel analysis. It will highlight scenarios where automated EDA and interactive visualizations are particularly useful.
- Customer Churn Dataset: A dataset containing customer demographics, service usage, and a binary target variable indicating whether a customer churned (cancelled their service). This dataset is ideal for demonstrating statistical hypothesis testing and explanatory modeling.
- Large Web Server Log File Dataset: A simulated large log file (e.g., 15GB) that cannot fit into typical computer memory. This dataset will be crucial for showcasing the power of Dask and Polars in handling big data challenges and performance optimization.
By consistently using these datasets across different chapters, you will gain a deeper understanding of how each tool contributes to a complete data science workflow, from initial exploration to advanced modeling. Let's begin our journey into the modern data science workflow!
Part 1: The Modern Workflow Enhancers (Tier 2)
Chapter 2: Automated EDA with ydata-profiling - Your First 30 Minutes, Automated
2.1 The Pain of Manual EDA
As a data scientist, you know the drill. You get a new dataset, and before you can even think about modeling, you embark on the essential, yet often repetitive, journey of Exploratory Data Analysis (EDA). This involves a series of commands that become second nature:
- df.info(): To check data types and non-null counts.
- df.describe(): To get summary statistics for numerical columns.
- df.isnull().sum(): To quantify missing values per column.
- df.head() and df.tail(): For a quick visual inspection of the data.
- df["column"].value_counts(): To understand the distribution of categorical variables.
- df["numerical_column"].hist() or df["numerical_column"].plot(kind='density'): To visualize numerical distributions.
- df.corr(): To check correlations between numerical features.
- Creating countless plots (histograms, scatter plots, box plots) to understand individual variables and their relationships.
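Strung together, that first pass usually looks something like the minimal sketch below (the input file and column names are hypothetical placeholders):
import pandas as pd
df = pd.read_csv("your_data.csv")                  # hypothetical input file
df.info()                                          # data types and non-null counts
print(df.describe())                               # summary statistics for numerical columns
print(df.isnull().sum())                           # missing values per column
print(df.head())                                   # quick visual inspection
print(df["some_category"].value_counts())          # distribution of a categorical column
df["some_number"].hist()                           # distribution of a numerical column
print(df.corr(numeric_only=True))                  # correlations between numerical features (numeric_only needs pandas 1.5+)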
While each of these commands is vital, the process of running them, interpreting the output, and then compiling those insights into a coherent summary can be time-consuming and tedious, especially for large datasets or when you need to quickly assess multiple datasets. It's a necessary evil, but what if there was a way to automate a significant portion of this initial legwork, freeing you up to focus on deeper, more nuanced analysis?
2.2 The Two-Line Solution
Enter ydata-profiling (formerly pandas-profiling), a powerful Python library that automates the generation of comprehensive EDA reports. It takes your Pandas DataFrame and, with just a couple of lines of code, produces an interactive HTML report containing a wealth of information about your data, including descriptive statistics, missing value analysis, correlations, and much more.
Installation:
First, you need to install the library. It's recommended to do this in your environment:
pip install ydata-profiling
The Core Code:
Once installed, generating a report is incredibly simple:
import pandas as pd
from ydata_profiling import ProfileReport
import os
# Create a dummy DataFrame for demonstration
data = {
"Numerical_Feature": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 100, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"Categorical_Feature": ["A", "B", "A", "C", "B", "A", "C", "B", "A", "C", "A", "B", "A", "C", "B", "A", "C", "B", "A", "C", "A"],
"Missing_Feature": [1, 2, 3, 4, 5, None, 7, 8, None, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
"Boolean_Feature": [True, False, True, False, True, False, True, False, True, False, True, False, True, False, True, False, True, False, True, False, True],
"High_Cardinality_Feature": [f"ID_{i}" for i in range(21)],
"Constant_Feature": [5] * 21
}
df = pd.DataFrame(data)
# Generate the profile report
profile = ProfileReport(df, title="My Awesome Data Profile", html={"style":{"full_width":True}})
# Display the report in a Jupyter Notebook/Colab environment
# profile.to_notebook_iframe() # Uncomment to display in notebook
# Save the report to an HTML file
report_path = "my_data_profile.html"
profile.to_file(report_path)
print(f"Profile report saved to {report_path}")
# Clean up the dummy file (optional, for demonstration purposes)
# In a real scenario, you would keep the generated HTML report.
os.remove(report_path)
Code Explanation & Output:
- pip install ydata-profiling: This command installs the library. You only need to run this once in your environment.
- import pandas as pd and from ydata_profiling import ProfileReport: Imports the necessary Pandas library and the ProfileReport class.
- df = pd.DataFrame(data): We create a sample DataFrame with various data types and characteristics (numerical, categorical, missing values, high cardinality, constant) to demonstrate the comprehensive nature of the report.
- profile = ProfileReport(df, title="My Awesome Data Profile", html={"style":{"full_width":True}}): This is the core line. You pass your DataFrame to the ProfileReport constructor. You can also provide a title for your report and customize its appearance (e.g., full_width=True for better viewing).
- profile.to_notebook_iframe(): If you are working in a Jupyter Notebook or Google Colab environment, uncommenting this line will display the interactive report directly within your notebook output.
- profile.to_file(report_path): This saves the generated report as a standalone HTML file. This is incredibly useful for sharing your EDA findings with colleagues or stakeholders who might not have a Python environment set up.
# A file named 'my_data_profile.html' will be generated in your current directory.
# If run in a notebook, an interactive report will be displayed inline.
Profile report saved to my_data_profile.html
When you open the my_data_profile.html file in your web browser, you will be greeted with a rich, interactive dashboard summarizing your dataset. This single report replaces dozens of manual commands and plots, providing a holistic view of your data quality and characteristics.
2.3 A Guided Tour of the Report (In-Depth)
The ydata-profiling report is structured into several tabs, each providing different perspectives on your data. Let's take a detailed look at the most important sections.
The Overview Tab
This tab provides a high-level summary of your dataset, including the number of variables, observations, missing values, duplicate rows, and memory size. Crucially, it also highlights "Warnings". This section is your first stop for identifying potential data quality issues that need attention.
For each warning, ydata-profiling explains what it means and often suggests next steps. Let's break down common warnings:
- High Correlation:
- What it means: Two or more variables are strongly linearly related. For example, Price and Sales_Tax might be highly correlated.
- Why it matters: In many machine learning models (especially linear models), highly correlated features can lead to multicollinearity, which can make the model unstable, harder to interpret, and sometimes reduce its predictive power. It also means one of the features might be redundant.
- Next steps: Consider removing one of the highly correlated features, or use dimensionality reduction techniques like PCA (Principal Component Analysis) if you need to retain the information from both.
- Missing Values:
- What it means: There are NaN (Not a Number) entries in a column, indicating absent data points.
- Why it matters: Most machine learning algorithms cannot handle missing values directly. Ignoring them can lead to errors or biased results. The Overview tab will tell you which columns have missing values and their percentage.
- Next steps: Depending on the percentage and nature of missingness, you might choose to (a short Pandas sketch follows this list):
- Impute: Fill missing values with a statistical measure (mean, median, mode) or a more sophisticated method.
- Drop: Remove rows or columns with missing values (use with caution to avoid significant data loss).
- Investigate: Understand why the data is missing (e.g., not applicable, data collection error).
- High Cardinality:
- What it means: A categorical variable has a very large number of unique values (e.g., a CustomerID column with millions of unique IDs).
- Why it matters: High cardinality can cause issues for certain machine learning models, especially those that rely on one-hot encoding, as it can lead to a huge number of new features, increasing dimensionality and potentially causing memory issues or overfitting.
- Next steps: Consider techniques like target encoding, grouping rare categories, or feature hashing.
- Skewness:
- What it means: The distribution of a numerical variable is asymmetrical, with a long tail on one side. A positive skew means the tail is on the right (more high values), and a negative skew means the tail is on the left (more low values).
- Why it matters: Many statistical methods and machine learning algorithms assume that numerical features are normally distributed. Skewed data can violate these assumptions, leading to suboptimal model performance.
- Next steps: Apply transformations like logarithmic, square root, or Box-Cox transformations to reduce skewness and make the distribution more symmetrical.
- Zeros:
- What it means: A numerical column contains a significant number of zero values.
- Why it matters: While not always a problem, a high proportion of zeros can sometimes indicate a sparse feature or a specific data generation process that might need special handling. For example, in a Sales column, many zeros might mean no sales occurred, which is valid, but in a Customer_Rating column, it might indicate missing data or a default value.
- Next steps: Understand the context of the zeros. If they represent true absence, consider if a specific model (e.g., zero-inflated models) or feature engineering (e.g., creating a binary indicator for presence/absence) is needed.
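The report only flags these issues; the fixes are still ordinary Pandas/NumPy work. As a minimal, hypothetical sketch of the follow-ups suggested above (imputation, dropping, and a skew-reducing transform):
import numpy as np
import pandas as pd
# Hypothetical DataFrame with the kinds of issues flagged above
df_clean = pd.DataFrame({
    "Price": [10.0, 12.5, np.nan, 9.0, 250.0],
    "Region": ["East", None, "West", "East", "North"],
})
# Impute: a numerical column with its median, a categorical column with its mode
df_clean["Price"] = df_clean["Price"].fillna(df_clean["Price"].median())
df_clean["Region"] = df_clean["Region"].fillna(df_clean["Region"].mode()[0])
# Drop: alternatively, remove rows that still contain missing values (use with caution)
df_clean = df_clean.dropna()
# Skewness: a log transform (log1p also handles zeros) often makes a right-skewed column more symmetric
df_clean["Price_log"] = np.log1p(df_clean["Price"])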
The Variables Tab
This tab provides a detailed, individual analysis for each variable (column) in your dataset. It's incredibly comprehensive, offering insights into data types, unique values, common values, and various statistical measures. Let's take a detailed look at how to interpret the information for both a numerical and a categorical variable.
Example: Numerical Variable (Numerical_Feature)
When you click on a numerical variable in the Variables tab, you'll see a detailed view that includes:
- Statistics:
- Distinct: Number of unique values.
- Distinct (%): Percentage of unique values relative to the total number of observations.
- Missing: Number of missing values (NaN).
- Missing (%): Percentage of missing values.
- Zeros: Number of zero values.
- Zeros (%): Percentage of zero values.
- Mean: The average value.
- Standard Deviation (Std Dev): A measure of the dispersion or spread of the data around the mean.
- Minimum (Min): The smallest value.
- Maximum (Max): The largest value.
- Skewness: A measure of the asymmetry of the probability distribution of a real-valued random variable about its mean. Positive skew indicates a longer tail on the right, negative on the left.
- Kurtosis: A measure of the "tailedness" of the probability distribution of a real-valued random variable. High kurtosis means more outliers.
- Quantile Statistics:
- Minimum (Min): Smallest value.
- 5%: Value below which 5% of the data falls.
- Q1 (25%): First quartile; 25% of the data is below this value.
- Median (50%): The middle value when the data is ordered; also the second quartile.
- Q3 (75%): Third quartile; 75% of the data is below this value.
- 95%: Value below which 95% of the data falls.
- Maximum (Max): Largest value.
- Range: Max - Min.
- Interquartile Range (IQR): Q3 - Q1. The range of the middle 50% of the data.
- Histogram: A visual representation of the distribution of the numerical variable. It shows the frequency of values within different bins. This is crucial for understanding the shape of the distribution (e.g., normal, skewed, bimodal).
- Common Values: A table showing the most frequent values in the column and their counts/percentages. This can help identify dominant values or potential data entry errors.
- Extreme Values: A table showing the smallest and largest values, which can help identify outliers.
Example: Categorical Variable (Categorical_Feature)
For a categorical variable, the detailed view will include:
- Statistics:
- Distinct: Number of unique categories.
- Distinct (%): Percentage of unique categories.
- Missing: Number and percentage of missing values.
- Warning (if any): E.g., High Cardinality.
- Frequencies: A bar chart showing the frequency of each category. This helps visualize the distribution of categories.
- Common Values: A table listing the most frequent categories and their counts/percentages.
Interpreting the Variables Tab:
- Data Quality: Quickly spot missing values, zeros, and potential outliers by examining the statistics and extreme values.
- Distribution: The histogram for numerical variables and frequency chart for categorical variables provide immediate insights into how your data is distributed. This informs decisions about transformations or imputation strategies.
- Cardinality: For categorical variables, the Distinct count helps you understand if one-hot encoding is feasible or if other encoding strategies are needed.
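Most of these per-variable numbers can be reproduced (and sanity-checked) with plain Pandas. A minimal sketch, assuming we pick the Numerical_Feature column from the dummy DataFrame profiled earlier:
s = df["Numerical_Feature"]                        # a numeric column from the profiled DataFrame
print(s.nunique())                                 # Distinct
print(s.isnull().mean() * 100)                     # Missing (%)
print((s == 0).sum())                              # Zeros
print(s.mean(), s.std(), s.min(), s.max())         # Mean, Std Dev, Min, Max
print(s.skew(), s.kurtosis())                      # Skewness and kurtosis
print(s.quantile([0.05, 0.25, 0.50, 0.75, 0.95]))  # Quantile statistics
iqr = s.quantile(0.75) - s.quantile(0.25)          # Interquartile Range (IQR)
print(iqr)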
The Correlations Tab
This tab visualizes the relationships between numerical variables using different correlation methods. Understanding correlations is vital for feature selection and avoiding multicollinearity.
- Pearson Correlation:
- When to use: Measures the linear relationship between two continuous variables. It ranges from -1 (perfect negative linear correlation) to +1 (perfect positive linear correlation), with 0 indicating no linear correlation.
- Interpretation: Useful for identifying features that move together (or in opposite directions) linearly. For example, Years_of_Experience and Salary might have a strong positive Pearson correlation.
- Spearman Correlation:
- When to use: Measures the monotonic relationship between two variables (whether linear or not). It assesses how well the relationship between two variables can be described using a monotonic function. It is less sensitive to outliers than Pearson correlation.
- Interpretation: Useful when you suspect a relationship but it might not be strictly linear, or when your data contains outliers. For example, if Customer_Satisfaction_Score tends to increase as Product_Usage increases, but not necessarily in a straight line, Spearman might capture this better.
- Kendall Correlation (Tau):
- When to use: Another non-parametric measure of the strength of dependence between two variables. It is often used as an alternative to Spearman when data is not normally distributed or has many tied ranks.
- Interpretation: Similar to Spearman, it assesses monotonic relationships and is robust to outliers.
- Phik (φk) Correlation:
- When to use: A novel and robust correlation coefficient that works for numerical, ordinal, and nominal variables. It captures non-linear relationships and is particularly useful for mixed data types.
- Interpretation: Provides a more comprehensive view of relationships across all data types, including categorical variables, which traditional Pearson/Spearman cannot directly handle.
Interpreting the Correlation Heatmap:
The heatmap visually represents the correlation matrix. Darker colors (e.g., dark blue for positive, dark red for negative) indicate stronger correlations, while lighter colors indicate weaker ones. Hovering over cells usually reveals the exact correlation coefficient. This helps you quickly identify highly correlated features that might need attention.
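Outside the report, the first three coefficients are a one-liner each in Pandas; Phik requires the separate phik package. A minimal sketch, assuming df is the DataFrame being profiled:
# Pearson, Spearman and Kendall on the numerical columns only
numeric_df = df.select_dtypes(include="number")
print(numeric_df.corr(method="pearson"))
print(numeric_df.corr(method="spearman"))
print(numeric_df.corr(method="kendall"))
# Phik needs `pip install phik`; importing it adds a phik_matrix() method to DataFrames
# import phik
# print(df.phik_matrix())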
The Missing Values Tab
This tab provides detailed visualizations of missing data patterns, helping you understand not just how many values are missing, but where and how they are missing.
- Count: A simple bar chart showing the number of missing values per column.
- Matrix: A visual matrix where each row is an observation and each column is a variable. Missing values are typically represented by a different color (e.g., white or light gray), while present values are dark. This helps you spot patterns of missingness (e.g., if certain rows or columns are entirely missing, or if missingness in one column is related to another).
- Bar: Similar to the count, but a bar chart.
- Dendrogram: A tree-like diagram that groups columns based on the correlation of their missingness. Columns that are frequently missing together will be clustered closer. This can indicate underlying reasons for missing data.
Interpreting the Missing Values Tab:
- Extent of Missingness: The count and bar charts quickly show which columns have the most missing data.
- Patterns of Missingness: The matrix and dendrogram are invaluable for understanding if missingness is random, or if there are systematic patterns. For example, if Cabin and Age are often missing together, it might suggest a common reason.
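Without the report, a rough version of the same checks takes a few lines of Pandas; correlating the missingness indicators is an informal stand-in for the dendrogram view:
missing = df.isnull().astype(int)     # 1 where a value is missing, 0 where it is present
print(missing.sum())                  # missing count per column (the Count / Bar views)
print(missing.corr())                 # columns whose missingness co-occurs; fully present columns yield NaN rows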
2.4 The Limits of Automation
While ydata-profiling is an incredibly powerful tool for accelerating your EDA, it's important to understand its limitations. It's a fantastic starting point, but it doesn't replace the need for human intelligence and domain expertise.
- Domain-Specific Validation: The tool cannot understand the business context of your data. For example, it might flag a Customer_ID column as high cardinality, which is technically true, but perfectly normal and expected in a customer database. It won't know if a Price of -5 is a data entry error or a valid return. You still need to apply your domain knowledge to interpret the warnings and decide on appropriate actions.
- Complex Relationships: While it shows correlations, it won't uncover complex, non-linear relationships or interactions between many variables that might require more sophisticated statistical modeling or feature engineering.
- Time-Series Specifics: For time-series data, while it provides basic statistics, it won't perform specialized time-series analyses like seasonality decomposition or trend analysis.
- Deep Dive into Outliers: It identifies extreme values, but a deeper investigation into why those outliers exist and how they should be handled (e.g., removal, transformation, or special modeling) still requires manual effort.
- Data Cleaning Execution: The report identifies problems; it doesn't fix them. You still need to write the Pandas or NumPy code to handle missing values, transform skewed data, or address high cardinality.
In essence, ydata-profiling is your highly efficient assistant for the initial reconnaissance mission. It highlights the areas that need attention, allowing you to quickly prioritize your manual EDA efforts and dive deeper into the most critical data quality issues. It transforms the often tedious initial 30 minutes of EDA into a highly productive and insightful experience.
Chapter 3: Interactive Storytelling with Plotly
3.1 The "I Wish I Could Click That" Moment
As you've progressed through your data science journey, you've undoubtedly created numerous compelling visualizations using Matplotlib and Seaborn. You've mastered line plots to show trends, scatter plots to reveal relationships, and bar charts to compare categories. These static plots are excellent for reports, presentations, and publications where the message is clear and the audience's interaction is limited to viewing.
However, have you ever found yourself in a situation, perhaps during an exploratory analysis session or a live presentation, where you wished your plots could do more? Imagine a scatter plot with hundreds of data points. You see a cluster of outliers or an interesting pattern, and your immediate thought is: "I wish I could click on that point to see its underlying data!" Or perhaps you're presenting a time-series chart, and a stakeholder asks, "What happened in that specific month?" – and you can't zoom in or hover to reveal the exact values.
This is the "I wish I could click that" moment. Static plots, while powerful for conveying a single, clear message, fall short when you need to explore data dynamically, allow your audience to interact with the visualization, or present complex information in a digestible, interactive format. This is where interactive visualization libraries like Plotly come into play, transforming your data stories from static images into dynamic, explorable experiences.
3.2 Plotly vs. Plotly Express (px)
Plotly is a powerful open-source graphing library that allows you to create interactive, publication-quality graphs. It supports a wide range of chart types and offers extensive customization options. However, its full API can sometimes be verbose, requiring more lines of code for common tasks.
This is where Plotly Express (px) shines. Plotly Express is a high-level wrapper around Plotly, designed for rapid data exploration and visualization. It provides a concise, intuitive syntax that is very similar to Seaborn, allowing you to create complex interactive plots with just a few lines of code. For most common data science tasks, Plotly Express is the go-to choice due to its simplicity and efficiency.
For this chapter, we will primarily use plotly.express (px) because it aligns perfectly with our goal of efficient and interactive storytelling for intermediate practitioners. It handles much of the underlying Plotly complexity for you, allowing you to focus on mapping your data to visual aesthetics.
3.3 The Core Philosophy: Mapping Columns to Aesthetics
The core philosophy behind Plotly Express, much like Seaborn, revolves around mapping columns from your DataFrame to visual aesthetics of a plot. You tell px which column should represent the x-axis, which the y-axis, which should determine color, size, or shape, and px handles the rest, including generating interactive legends, tooltips, and zoom/pan functionalities.
Let's see a side-by-side comparison of how px uses a similar logic to Seaborn for a scatter plot:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
# Create a dummy DataFrame
data = {
"Feature_X": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"Feature_Y": [10, 12, 15, 13, 18, 20, 22, 25, 23, 28],
"Category": ["A", "B", "A", "C", "B", "A", "C", "B", "A", "C"],
"Size_Metric": [50, 100, 70, 120, 90, 60, 110, 80, 130, 140]
}
df = pd.DataFrame(data)
print("--- Seaborn Scatter Plot ---")
plt.figure(figsize=(8, 6))
sns.scatterplot(data=df, x="Feature_X", y="Feature_Y", hue="Category", size="Size_Metric", sizes=(20, 200))
plt.title("Seaborn Static Scatter Plot")
plt.show()
print("\n--- Plotly Express Interactive Scatter Plot ---")
fig = px.scatter(df, x="Feature_X", y="Feature_Y", color="Category", size="Size_Metric",
title="Plotly Express Interactive Scatter Plot",
hover_data=["Category", "Size_Metric"])
fig.show()
Code Explanation & Output:
- Seaborn: We use sns.scatterplot() and map Feature_X to x, Feature_Y to y, Category to hue (color), and Size_Metric to size. sizes controls the range of marker sizes.
- Plotly Express: We use px.scatter() and map Feature_X to x, Feature_Y to y, Category to color, and Size_Metric to size. Notice the striking similarity in how the aesthetic mappings are defined. The hover_data argument is a powerful addition in Plotly Express, allowing you to specify which columns should appear in the tooltip when you hover over a data point.
--- Seaborn Scatter Plot ---
# A static Matplotlib plot will be displayed.
--- Plotly Express Interactive Scatter Plot ---
# An interactive Plotly plot will be displayed in your browser or notebook output.
# You can hover over points to see details, zoom, pan, and toggle categories in the legend.
This comparison highlights that if you are comfortable with Seaborn, you are already halfway to mastering Plotly Express. The primary difference is that px generates interactive plots by default, providing a richer exploratory experience.
3.4 A Gallery of Essential Interactive Plots (with px)
Let's dive into creating some of the most common and useful interactive plots using plotly.express. We will continue to use a similar dummy dataset to illustrate these concepts.
import pandas as pd
import numpy as np
import plotly.express as px
# Create a dummy DataFrame for demonstration
data = {
"Date": pd.to_datetime(pd.date_range(start=\"2023-01-01\", periods=100, freq=\"D\")),
"Sales": np.random.randint(100, 500, 100) + np.arange(100) * 2, # Adding a trend
"Profit": np.random.randint(10, 100, 100) + np.arange(100) * 0.5,
"Region": np.random.choice(["East", "West", "North", "South"], 100),
"Product_Category": np.random.choice(["Electronics", "Clothing", "Home Goods"], 100),
"Customer_Segment": np.random.choice(["New", "Returning"], 100),
"Customer_Rating": np.random.randint(1, 6, 100)
}
df = pd.DataFrame(data)
# Introduce some missing values for demonstration
df.loc[df.sample(frac=0.05).index, "Sales"] = np.nan
df.loc[df.sample(frac=0.03).index, "Region"] = np.nan
print("Sample of the DataFrame:\n", df.head())
print("\n" + "-"*40 + "\n")
# 1. Scatter Plots (px.scatter)
# Emphasize hover_data, color, and size.
print("Generating Interactive Scatter Plot...")
fig_scatter = px.scatter(df, x="Sales", y="Profit", color="Region", size="Customer_Rating",
hover_data=["Date", "Product_Category", "Customer_Segment"],
title="Sales vs. Profit by Region (Interactive)")
fig_scatter.show()
print("\n" + "-"*40 + "\n")
# 2. Bar Charts (px.bar)
# Show how to make them interactive and how to handle aggregations.
# Aggregate data first for bar chart: Total Sales by Product Category
sales_by_category = df.groupby("Product_Category")["Sales"].sum().reset_index()
print("Generating Interactive Bar Chart (Total Sales by Product Category)...")
fig_bar = px.bar(sales_by_category, x="Product_Category", y="Sales", color="Product_Category",
title="Total Sales by Product Category (Interactive)")
fig_bar.show()
print("\n" + "-"*40 + "\n")
# 3. Line Charts (px.line)
# For time-series data, showing how to zoom and pan.
print("Generating Interactive Line Chart (Sales Over Time)...")
fig_line = px.line(df, x="Date", y="Sales", color="Region",
title="Daily Sales Over Time by Region (Interactive)")
fig_line.show()
print("\n" + "-"*40 + "\n")
# 4. Histograms and Box Plots (px.histogram, px.box)
# For exploring distributions.
print("Generating Interactive Histogram (Distribution of Sales)...")
fig_hist = px.histogram(df, x="Sales", nbins=20, color="Customer_Segment",
title="Distribution of Sales by Customer Segment (Interactive)")
fig_hist.show()
print("\n" + "-"*40 + "\n")
print("Generating Interactive Box Plot (Profit Distribution by Product Category)...")
fig_box = px.box(df, x="Product_Category", y="Profit", color="Product_Category",
title="Profit Distribution by Product Category (Interactive)")
fig_box.show()
print("\n" + "-"*40 + "\n")
# 5. Faceting (facet_row, facet_col)
# The equivalent of Seaborn's relplot.
print("Generating Interactive Faceted Scatter Plot (Sales vs. Profit by Region and Customer Segment)...")
fig_facet = px.scatter(df, x="Sales", y="Profit", color="Product_Category",
facet_col="Region", facet_row="Customer_Segment",
title="Sales vs. Profit by Product Category, Faceted by Region and Customer Segment")
fig_facet.show()
Code Explanation & Output:
- Dummy DataFrame: We create a more complex dummy DataFrame with various data types, including dates and some missing values, to better simulate real-world data and demonstrate px capabilities.
- px.scatter():
- x="Sales", y="Profit": Defines the axes.
- color="Region": Colors the points based on the Region column, automatically creating a legend.
- size="Customer_Rating": Sizes the points based on the Customer_Rating column, indicating a numerical relationship.
- hover_data=["Date", "Product_Category", "Customer_Segment"]: This is a key interactive feature. When you hover over a point, a tooltip will appear showing the values of these specified columns for that particular data point.
- px.bar():
- We first aggregate the data using df.groupby("Product_Category")["Sales"].sum().reset_index() to get the total sales for each product category. This is a common step before creating bar charts for aggregated data.
- x="Product_Category", y="Sales": Defines the axes for the bar chart.
- color="Product_Category": Colors each bar based on its category.
- px.line():
- x="Date", y="Sales": Defines the time-series plot.
- color="Region": Creates separate lines for each region, allowing easy comparison of sales trends across regions.
- The interactive nature allows you to zoom into specific date ranges and pan across the timeline.
- px.histogram():
- x="Sales": Specifies the numerical column for which to plot the distribution.
- nbins=20: Controls the number of bins (bars) in the histogram.
- color="Customer_Segment": Creates separate histograms (or stacked bars) for each customer segment, allowing comparison of sales distributions between segments.
- px.box():
- x="Product_Category", y="Profit": Creates box plots showing the distribution of Profit for each Product_Category.
- Box plots are excellent for visualizing the median, quartiles, and potential outliers within each category.
- Faceting (facet_col, facet_row):
- facet_col="Region": Creates separate columns of subplots for each unique value in the Region column.
- facet_row="Customer_Segment": Creates separate rows of subplots for each unique value in the Customer_Segment column.
- This allows you to easily compare relationships (e.g., Sales vs. Profit) across multiple categorical dimensions simultaneously, similar to Seaborn's relplot or catplot.
Sample of the DataFrame:
Date Sales Profit Region Product_Category Customer_Segment Customer_Rating
0 2023-01-01 100.0 10.0 East Electronics New 5
1 2023-01-02 104.0 10.5 West Clothing New 1
2 2023-01-03 108.0 11.0 East Electronics Returning 3
3 2023-01-04 112.0 11.5 North Clothing New 4
4 2023-01-05 116.0 12.0 West Home Goods Returning 2
----------------------------------------
Generating Interactive Scatter Plot...
# An interactive scatter plot will be displayed.
----------------------------------------
Generating Interactive Bar Chart (Total Sales by Product Category)...
# An interactive bar chart will be displayed.
----------------------------------------
Generating Interactive Line Chart (Sales Over Time)...
# An interactive line chart will be displayed.
----------------------------------------
Generating Interactive Histogram (Distribution of Sales)...
# An interactive histogram will be displayed.
----------------------------------------
Generating Interactive Box Plot (Profit Distribution by Product Category)...
# An interactive box plot will be displayed.
----------------------------------------
Generating Interactive Faceted Scatter Plot (Sales vs. Profit by Region and Customer Segment)...
# An interactive faceted scatter plot will be displayed.
Each of these fig.show() calls will typically open an interactive plot in your default web browser or display it inline if you are in a compatible environment like Jupyter Notebook or Google Colab. The true power of Plotly Express lies in this interactivity: the ability to zoom, pan, hover for details, and toggle data series on and off, which greatly enhances data exploration and presentation.
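When a figure needs to be shared with someone who does not run Python, it can also be written out as a standalone interactive HTML file (static image export is possible too, but requires the separate kaleido package). A small sketch, reusing one of the figures above; the file names are arbitrary:
fig_scatter.write_html("sales_vs_profit.html")    # standalone, fully interactive HTML file
# fig_scatter.write_image("sales_vs_profit.png")  # static export; needs `pip install kaleido`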
3.5 Basic Customization
While Plotly Express handles much of the aesthetics automatically, you still have control over basic customizations like titles, labels, and color schemes. These are often passed as arguments directly to the px functions.
import pandas as pd
import numpy as np
import plotly.express as px
data = {
"X": np.random.rand(50),
"Y": np.random.rand(50),
"Group": np.random.choice(["Group A", "Group B"], 50)
}
df = pd.DataFrame(data)
fig = px.scatter(df, x="X", y="Y", color="Group",
title="Customized Scatter Plot Title", # Set plot title
labels={
"X": "Custom X-Axis Label", # Set x-axis label
"Y": "Custom Y-Axis Label", # Set y-axis label
"Group": "Data Grouping" # Set legend title
},
color_discrete_map={
"Group A": "blue", # Custom color for Group A
"Group B": "red" # Custom color for Group B
},
template="plotly_dark" # Apply a built-in theme
)
fig.show()
Code Explanation & Output:
- title: Sets the main title of the plot.
- labels: A dictionary to customize the axis labels and the legend title.
- color_discrete_map: A dictionary to assign specific colors to categorical values. For continuous colors, color_continuous_scale can be used.
- template: Plotly comes with several built-in themes (e.g., plotly, plotly_white, plotly_dark, ggplot2, seaborn, simple_white, none). Applying a template quickly changes the overall aesthetic.
# An interactive scatter plot with custom title, labels, colors, and a dark theme will be displayed.
For more advanced customizations, you can modify the fig object directly using its update_layout() or update_traces() methods, which provide granular control over every aspect of the plot. This bridges the gap between the high-level px and the lower-level Plotly API, as the short sketch below illustrates.
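A minimal illustration of that lower-level control, continuing from the fig created above:
# Fine-tune the figure after creation: layout-level and trace-level tweaks
fig.update_layout(legend_title_text="Groups", title_x=0.5)   # retitle the legend, center the title
fig.update_traces(marker=dict(size=12, opacity=0.7))         # restyle all scatter markers
fig.show()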
3.6 When to Choose Which: A Clear Decision Table
With Matplotlib/Seaborn and Plotly/Plotly Express at your disposal, how do you decide which tool to use? Here's a decision table to guide your choice:
| Feature / Use Case | Matplotlib / Seaborn | Plotly / Plotly Express |
|---|---|---|
| Interactivity | Limited (static images) | High (zoom, pan, hover, toggle series) |
| Ease of Use (Basic) | Moderate (Seaborn simplifies many plots) | High (Plotly Express is very intuitive) |
| Customization Control | High (granular control over every element) | High (via Plotly.graph_objects, but px simplifies) |
| Output Format | Static images (PNG, JPG, PDF, SVG) | Interactive HTML, JSON, static images (requires export) |
| Exploration | Quick, initial checks; simple distributions | Deep dive into data; identifying specific data points |
| Reports / Presentations | Formal reports; print-ready figures | Interactive dashboards; web-based presentations |
| Dashboards | Not directly (requires external integration) | Excellent (integrates well with Dash, Streamlit) |
| Publications | Traditional academic papers (static figures) | Online interactive supplements; dynamic research |
| Data Size | Small to Medium (performance can degrade on large) | Small to Large (interactive performance can vary) |
| Learning Curve | Moderate to High (Matplotlib) | Low (Plotly Express) to Moderate (full Plotly) |
Key Takeaways:
- For quick, static plots or print-ready figures: Stick with Matplotlib or Seaborn. They are mature, widely used, and offer precise control.
- For interactive exploration, web-based dashboards, or dynamic presentations: Plotly Express is your go-to. Its ability to provide immediate feedback and allow users to explore the data themselves is invaluable.
- For complex, highly customized interactive plots: You might need to delve into the lower-level plotly.graph_objects API, but px often gets you 90% of the way there with minimal code.
By integrating Plotly Express into your workflow, you elevate your data storytelling capabilities, making your insights more accessible, engaging, and impactful for a wider audience.
Chapter 4: Answering Deeper Questions with SciPy
4.1 From Description to Inference
In the previous chapters, we've focused heavily on descriptive statistics and visualization. With Pandas, you can easily calculate the average sales, the median age, or the most frequent product category. With Matplotlib, Seaborn, and Plotly, you can visualize these distributions and relationships. This is crucial for understanding your data and communicating initial insights.
However, data science often requires us to go beyond mere description. We need to answer deeper questions, make informed decisions, and draw conclusions about a larger population based on a sample of data. This is where inferential statistics comes into play. Inferential statistics allows us to make predictions or inferences about a population from observations and analyses of a sample. It's about moving from "what is" to "what if" or "is this difference real?"
Consider these two types of questions:
- Descriptive Question: "What is the average age of our customers?" (Answer: The average is 35.5 years.)
- Inferential Question: "Did our new website design significantly increase conversion rates compared to the old design?" (Answer: We ran an A/B test, and the difference in conversion rates between the two designs is statistically significant, meaning it's unlikely due to random chance.)
The first question is answered directly by summarizing your data. The second requires a statistical test to determine if an observed difference is likely a true effect or just random variation. This transition from "The average is X" (Pandas) to "Is the difference between these averages statistically significant?" (SciPy) is a fundamental leap in data analysis.
4.2 Introducing scipy.stats
What is it?
SciPy is a Python library used for scientific computing and technical computing. It builds on NumPy and provides many user-friendly and efficient numerical routines, such as routines for numerical integration, interpolation, optimization, linear algebra, and statistics. Within SciPy, the scipy.stats module is particularly important for data scientists. It contains a large number of probability distributions and a growing library of statistical functions, including a comprehensive set of statistical tests.
Why is it important?
scipy.stats serves as the statistical testing engine of the Python data stack. While Pandas helps you manipulate and summarize data, scipy.stats provides the rigorous mathematical tools to perform hypothesis testing, calculate confidence intervals, and analyze distributions. It allows you to move beyond simply describing your data to making statistically sound inferences and decisions.
How do we use it?
We typically import specific functions from scipy.stats as needed.
from scipy import stats
# Now you can use functions like stats.ttest_ind, stats.chi2_contingency, etc.
print("scipy.stats imported successfully!")
Code Explanation & Output:
- from scipy import stats: This line imports the stats module from the SciPy library. This is the standard way to access its statistical functions.
scipy.stats imported successfully!
4.3 Hypothesis Testing Masterclass: The T-Test
What is it?
A t-test is a type of inferential statistic used to determine if there is a significant difference between the means of two groups, which may be related in certain features. It is most often used when the data sets, like the one collected from an experiment, would follow a normal distribution and may have unknown variances.
Why is it important?
The t-test is a cornerstone of A/B testing and experimental design. It allows businesses to statistically validate whether a change (e.g., a new website feature, a different marketing campaign) has had a real, measurable impact, rather than just observing a difference that could be due to random chance.
The Business Question: "Did our website redesign (A/B test) work?"
Imagine you run an e-commerce website. You implement a new design for your product pages and want to know if it leads to a higher average conversion rate (e.g., percentage of visitors who make a purchase). You split your website traffic: Group A (control) sees the old design, and Group B (treatment) sees the new design. After a week, you collect the conversion rates for both groups.
The Hypotheses:
In hypothesis testing, we formulate two competing statements:
- Null Hypothesis (H₀): There is no statistically significant difference between the means of the two groups. Any observed difference is due to random chance. (For our A/B test: The new website design has no effect on the conversion rate; the mean conversion rate of Group A is equal to the mean conversion rate of Group B.)
- Alternative Hypothesis (H₁): There is a statistically significant difference between the means of the two groups. The observed difference is not due to random chance. (For our A/B test: The new website design does have an effect on the conversion rate; the mean conversion rate of Group A is different from the mean conversion rate of Group B.)
Our goal is to gather enough evidence to either reject the null hypothesis in favor of the alternative, or fail to reject the null hypothesis.
Data Prep: Use Pandas to separate the groups.
Let's simulate some conversion rate data for our A/B test.
import pandas as pd
import numpy as np
from scipy import stats
# Simulate conversion rates for two groups
np.random.seed(42) # for reproducibility
# Group A (Control - Old Design): Mean conversion rate 0.05 (5%)
conversion_a = np.random.normal(loc=0.05, scale=0.01, size=100) # 100 observations
conversion_a = np.clip(conversion_a, 0, 1) # Ensure rates are between 0 and 1
# Group B (Treatment - New Design): Mean conversion rate 0.055 (5.5%)
conversion_b = np.random.normal(loc=0.055, scale=0.01, size=100) # 100 observations
conversion_b = np.clip(conversion_b, 0, 1)
df_ab_test = pd.DataFrame({
"Group": ["A"] * 100 + ["B"] * 100,
"Conversion_Rate": np.concatenate([conversion_a, conversion_b])
})
print("Sample of A/B Test Data:\n", df_ab_test.head())
print("\nMean Conversion Rate - Group A:", df_ab_test[df_ab_test["Group"] == "A"]["Conversion_Rate"].mean())
print("Mean Conversion Rate - Group B:", df_ab_test[df_ab_test["Group"] == "B"]["Conversion_Rate"].mean())
print("\n" + "-"*40 + "\n")
Code Explanation & Output:
- We generate two arrays, conversion_a and conversion_b, representing the conversion rates for each group. We intentionally set a slightly higher mean for Group B to see if the t-test can detect this difference.
- np.clip(conversion_a, 0, 1): Ensures that the simulated conversion rates stay within a realistic range of 0 to 1.
- We then combine these into a Pandas DataFrame df_ab_test with a Group column to distinguish between the two designs.
- We print the mean conversion rates for each group to observe the raw difference.
Sample of A/B Test Data:
Group Conversion_Rate
0 A 0.044959
1 A 0.054956
2 A 0.048703
3 A 0.063421
4 A 0.057631
Mean Conversion Rate - Group A: 0.05030310237937988
Mean Conversion Rate - Group B: 0.05500000000000001
----------------------------------------
The Test: stats.ttest_ind()
We will use stats.ttest_ind() for an independent samples t-test, which is appropriate when comparing the means of two independent groups.
# Assuming df_ab_test is already created and cleaned from the previous step
# Separate the conversion rates for each group
group_a_rates = df_ab_test[df_ab_test["Group"] == "A"]["Conversion_Rate"]
group_b_rates = df_ab_test[df_ab_test["Group"] == "B"]["Conversion_Rate"]
# Perform the independent t-test
t_statistic, p_value = stats.ttest_ind(group_a_rates, group_b_rates)
print(f"T-statistic: {t_statistic:.4f}")
print(f"P-value: {p_value:.4f}")
Code Explanation & Output:
- group_a_rates and group_b_rates: We extract the Conversion_Rate Series for each group.
- stats.ttest_ind(group_a_rates, group_b_rates): This function performs the t-test and returns two values: the t-statistic and the p-value.
- T-statistic: Measures the size of the difference relative to the variation in your sample data. A larger absolute t-statistic suggests a larger difference between the group means.
- P-value: This is the most critical value for our decision.
T-statistic: -3.3397
P-value: 0.0010
Interpreting the P-Value: A detailed, beginner-friendly explanation of what the p-value represents and how to use the alpha threshold (0.05).
The p-value is one of the most misunderstood concepts in statistics, but it's actually quite straightforward once you grasp its core meaning.
What is the P-value?
The p-value (probability value) is the probability of observing a test statistic (like our t-statistic) as extreme as, or more extreme than, the one calculated from your sample data, assuming that the null hypothesis is true. In simpler terms, it tells you: "If there were truly no difference between the groups (H₀ is true), how likely is it that I would see a difference this big (or bigger) just by random chance?"
- A small p-value (typically ≤ 0.05) suggests that the observed data is unlikely if the null hypothesis were true. This provides strong evidence against the null hypothesis, leading us to reject H₀.
- A large p-value (typically > 0.05) suggests that the observed data is likely if the null hypothesis were true. This means we do not have enough evidence to reject the null hypothesis, so we fail to reject H₀.
The Alpha Threshold (Significance Level):
Before conducting a hypothesis test, we set a significance level, denoted by alpha (α). This is the threshold for how much risk we are willing to take of making a Type I error (false positive) – that is, incorrectly rejecting a true null hypothesis. The most commonly used alpha level in data science and research is 0.05 (or 5%).
Decision Rule:
- If p-value ≤ α (e.g., 0.05): Reject the null hypothesis. Conclude that there is a statistically significant difference.
- If p-value > α (e.g., 0.05): Fail to reject the null hypothesis. Conclude that there is no statistically significant difference (or not enough evidence to claim one).
Applying to our A/B Test:
Our calculated p-value is 0.0010. If we set our alpha (α) to 0.05:
0.0010 ≤ 0.05
Since our p-value is less than or equal to our chosen alpha level, we reject the null hypothesis. This means we have statistically significant evidence to conclude that the new website design did have a positive effect on the conversion rate. The observed difference of 0.5% (5.5% vs 5.0%) is unlikely to have occurred by random chance alone.
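Programmatically, the decision rule is just this comparison against the chosen alpha. A minimal sketch, reusing the p_value computed above:
alpha = 0.05
if p_value <= alpha:
    print(f"p = {p_value:.4f} <= {alpha}: reject H0, the difference is statistically significant")
else:
    print(f"p = {p_value:.4f} > {alpha}: fail to reject H0, no significant difference detected")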
Assumptions:
It's important to be aware of the assumptions of a t-test. Violating these assumptions can affect the validity of your results:
- Independence: Observations within each group are independent of each other.
- Normality: The data in each group should be approximately normally distributed. For larger sample sizes (generally N > 30), the Central Limit Theorem often allows us to relax this assumption.
- Homoscedasticity (Equal Variances): The variances of the two groups should be approximately equal. If not, you can use a variation of the t-test (Welch's t-test), which stats.ttest_ind can perform by setting equal_var=False, as sketched below.
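If the equal-variance assumption looks doubtful, the same call runs Welch's t-test with one extra argument; a minimal sketch reusing the two groups from above:
# Welch's t-test: does not assume equal variances between the two groups
t_stat_welch, p_value_welch = stats.ttest_ind(group_a_rates, group_b_rates, equal_var=False)
print(f"Welch's t-test: t = {t_stat_welch:.4f}, p = {p_value_welch:.4f}")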
4.4 Hypothesis Testing Masterclass: The Chi-Squared Test
What is it?
The Chi-Squared (χ²) test of independence is used to determine if there is a statistically significant association between two categorical variables. It compares the observed frequencies in categories to the frequencies that would be expected if there were no association between the variables.
Why is it important?
This test is invaluable for understanding relationships between categorical data. For example, it can help answer questions like: Is there a relationship between a customer's gender and their preferred product category? Is a certain demographic more likely to respond to a particular marketing campaign?
The Business Question: "Is there a relationship between customer region and product category purchased?"
Imagine you have customer data, and you want to know if customers from different geographical regions tend to purchase different product categories. This is a question about the association between two categorical variables: Region and Product_Category.
The Hypotheses:
- Null Hypothesis (H₀): There is no association between the two categorical variables. They are independent. (For our example: Customer region and product category purchased are independent; there is no relationship between them.)
- Alternative Hypothesis (H₁): There is an association between the two categorical variables. They are not independent. (For our example: Customer region and product category purchased are dependent; there is a relationship between them.)
Data Prep: Use pd.crosstab() to create the contingency table.
To perform a Chi-Squared test, we first need to create a contingency table (also known as a cross-tabulation) that shows the frequencies of each combination of categories.
import pandas as pd
from scipy import stats
# Simulate customer data
data = {
"Region": ["East", "West", "East", "North", "West", "East", "North", "West", "East", "North"],
"Product_Category": ["Electronics", "Clothing", "Electronics", "Home Goods", "Clothing", "Electronics", "Home Goods", "Clothing", "Electronics", "Home Goods"]
}
df_customers = pd.DataFrame(data)
# Create a contingency table
contingency_table = pd.crosstab(df_customers["Region"], df_customers["Product_Category"])
print("Contingency Table:\n", contingency_table)
print("\n" + "-"*40 + "\n")
Code Explanation & Output:
- We create a dummy DataFrame df_customers with Region and Product_Category columns.
- pd.crosstab(df_customers["Region"], df_customers["Product_Category"]): This Pandas function creates the contingency table, counting the occurrences of each combination of Region and Product_Category.
Contingency Table:
Product_Category Clothing Electronics Home Goods
Region
East 0 4 0
North 0 0 3
West 3 0 0
----------------------------------------
The Test: stats.chi2_contingency()
Now, we can perform the Chi-Squared test using stats.chi2_contingency().
# Assuming contingency_table is already created from the previous step
chi2, p_value, dof, expected = stats.chi2_contingency(contingency_table)
print(f"Chi-squared statistic: {chi2:.4f}")
print(f"P-value: {p_value:.4f}")
print(f"Degrees of freedom: {dof}")
print("Expected frequencies:\n", expected)
Code Explanation & Output:
- stats.chi2_contingency(contingency_table): This function takes the contingency table as input and returns four values:
- Chi-squared statistic: A measure of the difference between the observed and expected frequencies. A larger value indicates a greater discrepancy.
- P-value: The probability of observing the data (or more extreme data) if the null hypothesis were true.
- Degrees of freedom (dof): A value related to the number of independent pieces of information used to calculate the statistic.
- Expected frequencies: The frequencies that would be expected in each cell of the contingency table if the null hypothesis (independence) were true.
Chi-squared statistic: 20.0000
P-value: 0.0005
Degrees of freedom: 4
Expected frequencies:
 [[1.2 1.6 1.2]
 [0.9 1.2 0.9]
 [0.9 1.2 0.9]]
Interpretation:
Our calculated p-value is 0.0005. If we use an alpha (α) of 0.05:
0.0005 ≤ 0.05
Since the p-value is less than or equal to our alpha level, we reject the null hypothesis. This means there is a statistically significant association between customer region and product category purchased. In other words, customers from different regions do tend to purchase different product categories.
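The expected array returned by stats.chi2_contingency() is also useful on its own: subtracting it from the observed table shows which region/category combinations occur more or less often than independence would predict. A small follow-up sketch:
# Positive values: observed more often than expected under independence; negative: less often
residuals = contingency_table - expected
print(residuals)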
4.5 Beyond Two Groups: ANOVA (f_oneway)
What is it?
Analysis of Variance (ANOVA) is a statistical test used to compare the means of three or more independent groups. While a t-test is suitable for comparing two groups, ANOVA extends this to multiple groups, determining if there is a statistically significant difference between the means of any of the groups.
Why is it important?
ANOVA is crucial when you have more than two categories for a variable and want to see if they have different impacts on a numerical outcome. For example, comparing sales performance across three different marketing strategies, or the effectiveness of four different drug dosages.
How do we use it?
scipy.stats provides f_oneway for one-way ANOVA.
import pandas as pd
import numpy as np
from scipy import stats
# Simulate data for three different store layouts
np.random.seed(42) # for reproducibility
sales_layout_a = np.random.normal(loc=100, scale=10, size=50)
sales_layout_b = np.random.normal(loc=110, scale=10, size=50)
sales_layout_c = np.random.normal(loc=102, scale=10, size=50)
# Perform one-way ANOVA
f_statistic, p_value = stats.f_oneway(sales_layout_a, sales_layout_b, sales_layout_c)
print(f"F-statistic: {f_statistic:.4f}")
print(f"P-value: {p_value:.4f}")
Code Explanation & Output:
- We simulate sales data for three different store layouts (A, B, C), with Layout B having a slightly higher mean.
- stats.f_oneway(sales_layout_a, sales_layout_b, sales_layout_c): This function performs the one-way ANOVA test. It takes the data for each group as separate arguments.
- F-statistic: Measures the ratio of variance between the groups to the variance within the groups. A larger F-statistic suggests a greater difference between group means relative to the variability within groups.
- P-value: The probability of observing the data (or more extreme data) if the null hypothesis (all group means are equal) were true.
F-statistic: 12.6669
P-value: 0.0000
Interpretation:
Our p-value is 0.0000 (very small). If we use an alpha (α) of 0.05:
0.0000 ≤ 0.05
We reject the null hypothesis. This indicates that there is a statistically significant difference between the mean sales of at least two of the store layouts. ANOVA tells us that a difference exists, but not which specific groups differ. To find out which specific pairs of groups have significant differences, you would typically perform post-hoc tests (e.g., Tukey's HSD), which are beyond the scope of this introductory chapter but are available in statistical libraries.
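As a pointer to what such a post-hoc test looks like in code, here is a minimal sketch using scipy.stats.tukey_hsd (available in recent SciPy releases); it reuses the three sales_layout_* arrays from the example above:
from scipy import stats

# Pairwise comparisons of the three layouts after a significant ANOVA result (Tukey's HSD)
posthoc = stats.tukey_hsd(sales_layout_a, sales_layout_b, sales_layout_c)
print(posthoc)         # prints pairwise differences, confidence intervals, and p-values
print(posthoc.pvalue)  # matrix of pairwise p-values (rows/columns follow the input order)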
By mastering these statistical tests with scipy.stats, you can move beyond simply describing your data to making robust, data-driven inferences and decisions, adding a powerful layer to your data science toolkit.
Part 2: Specialized Power Tools (Tier 3)
Chapter 5: The Two Walls: Memory and Speed
As you venture into more complex data science projects, especially those involving larger datasets, you will inevitably encounter two significant challenges that can bring your workflow to a grinding halt: memory limitations and computational speed bottlenecks. These are the
two walls that often separate intermediate practitioners from advanced ones, and overcoming them is crucial for handling modern data science challenges.
5.1 The MemoryError Rite of Passage
Every data scientist, at some point in their career, experiences the dreaded MemoryError
. It often happens like this: you download a seemingly innocuous CSV file, perhaps a few gigabytes in size, and confidently type pd.read_csv('large_dataset.csv')
. You hit Enter, wait a few moments, and then – boom! – your Python kernel crashes, or you get a traceback ending with MemoryError
. Your computer might even slow to a crawl as it struggles to allocate enough RAM.
This is the MemoryError
rite of passage. It signifies that you've encountered data that is too large to fit entirely into your computer's Random Access Memory (RAM). Pandas, by default, loads entire DataFrames into memory. While incredibly efficient for smaller to medium-sized datasets (up to a few gigabytes, depending on your system's RAM), this in-memory processing becomes a significant bottleneck when dealing with truly large datasets. You might have 16GB or 32GB of RAM, but a 15GB CSV file, once parsed and loaded into a Pandas DataFrame (which often requires more memory than the raw file size due to data types and overhead), can easily exceed that capacity.
This problem is becoming increasingly common as data generation explodes. Sensor data, web logs, financial transactions, and genomic sequences can easily produce datasets that are tens, hundreds, or even thousands of gigabytes. When your data exceeds your available RAM, traditional Pandas operations become impossible, forcing you to find alternative solutions.
5.2 The "Five-Minute Coffee Break" Problem
Even if your dataset does fit into memory, you might still face another frustrating challenge: slow computation. You've written elegant Pandas code, perhaps a complex groupby()
operation followed by several aggregations, or a series of intricate data transformations. You execute the cell, and then... you wait. And wait. What was a few seconds on a smaller dataset now takes minutes, or even hours. This is the "five-minute coffee break" problem – operations that should be quick become unexpectedly time-consuming.
This often happens because Pandas, while highly optimized for many operations, is largely single-core. This means that even if your computer has multiple CPU cores (which most modern computers do), Pandas typically only utilizes one of them for many of its operations. While it uses highly optimized C/Cython code under the hood for individual operations, it doesn't inherently parallelize complex workflows across multiple cores.
For datasets that are large but still fit in memory (e.g., 5GB-10GB), the bottleneck shifts from memory capacity to processing speed. You have the RAM, but you're not fully leveraging your CPU's potential. This leads to inefficient use of computational resources and slows down your analytical cycle, impacting productivity and the ability to iterate quickly on models.
5.3 Introducing the Solutions
Fortunately, the Python data science ecosystem has evolved to address these challenges. We now have powerful libraries designed specifically to break through the memory and speed walls:
- Dask: The Solution for the Memory Wall. Dask extends the Pandas API to handle datasets that are larger than RAM. It does this by breaking large datasets into smaller chunks and processing them in parallel, often lazily. It allows you to work with
DataFrames that are distributed across multiple cores or even multiple machines, making it ideal for datasets that cause MemoryError
in Pandas.
- Polars: The Solution for the Speed Wall. Polars is a newer DataFrame library written in Rust, designed from the ground up for performance. It is specifically built to be multi-core and memory-efficient, often significantly faster than Pandas on datasets that do fit in memory. It introduces a different API and execution engine (both eager and lazy) that is optimized for speed.
In the following chapters, we will delve into Dask and Polars, understanding their core concepts, how they differ from Pandas, and how to use them effectively to handle larger datasets and accelerate your data processing workflows. We will also explore Statsmodels, a library focused on the explanatory side of modeling, which complements the predictive modeling capabilities you might be familiar with from Scikit-learn.
By adding these specialized power tools to your arsenal, you will be well-equipped to tackle a wider range of data science problems, moving beyond the limitations of in-memory, single-core processing and embracing the challenges of big data and high-performance computing.
Chapter 6: Conquering Big Data with Dask
6.1 The Dask Paradigm: Parallel Pandas
In the previous chapter, we discussed the MemoryError
and the "five-minute coffee break" problems that arise when dealing with datasets that are too large for Pandas to handle efficiently. Dask emerges as a powerful solution to these challenges, particularly for datasets that exceed your computer's RAM. At its core, Dask is a flexible library for parallel computing in Python, but its most popular application in data science is its ability to scale Pandas and NumPy workflows.
Think of Dask as a way to use "Parallel Pandas." Instead of loading an entire massive dataset into memory as one giant DataFrame, Dask breaks that dataset into many smaller Pandas DataFrames. These smaller DataFrames are then processed in parallel across multiple CPU cores (or even multiple machines in a distributed computing environment). When you perform an operation on a Dask DataFrame, it intelligently orchestrates these smaller operations, combining the results to give you the final answer, just as if you were working with a single Pandas DataFrame.
Analogy: Counting Words in an Encyclopedia
Imagine you have a massive, multi-volume encyclopedia, and your task is to count the total number of times a specific word appears throughout all volumes. If you were to do this manually, you wouldn't try to memorize the entire encyclopedia first. Instead, you'd likely:
- Divide the task: Assign each volume (or even each chapter) to a different person.
- Process in parallel: Each person counts the word in their assigned section simultaneously.
- Aggregate results: Once everyone is done, you collect all the individual counts and sum them up to get the grand total.
This is precisely how Dask operates. Each volume is like a partition of a Dask DataFrame, each person is like a CPU core, and the final summation is the aggregation step. Dask manages this entire process for you, allowing you to work with datasets that are far too large for a single person (or a single Pandas DataFrame) to handle efficiently.
This parallel processing capability allows Dask to handle datasets that are many times larger than your available RAM, effectively breaking through the memory wall.
6.2 The Most Important Concept: Lazy Execution and .compute()
One of the most fundamental and crucial concepts to grasp when working with Dask is lazy execution. Unlike Pandas, where operations are executed immediately as you type them, Dask operations are often lazy. This means that when you perform an operation on a Dask DataFrame (e.g., df.groupby().mean()
), Dask doesn't immediately perform the computation. Instead, it builds a task graph – a blueprint of all the operations you've requested and their dependencies.
This task graph is like a recipe or a plan. Dask only executes this plan and performs the actual computations when you explicitly tell it to, typically by calling the .compute()
method. This lazy evaluation is incredibly powerful because it allows Dask to:
- Optimize the computation: By knowing all the steps beforehand, Dask can reorder, combine, and optimize operations to minimize memory usage and maximize efficiency.
- Handle out-of-core data: It can process data that doesn't fit into memory by reading and processing chunks as needed, rather than trying to load everything at once.
- Chain operations efficiently: You can chain many operations together without intermediate results being materialized in memory, saving resources.
Let's illustrate this with an example:
import dask.dataframe as dd
import pandas as pd
import numpy as np
import os
# Create a large dummy CSV file for demonstration
# In a real scenario, this would be your actual large dataset
num_rows = 1_000_000
data = {
"id": np.arange(num_rows),
"value": np.random.rand(num_rows),
"category": np.random.choice(["A", "B", "C"], num_rows)
}
df_large = pd.DataFrame(data)
df_large.to_csv("large_data.csv", index=False)
print("Dummy large_data.csv created.")
# --- Pandas approach (would likely cause MemoryError for very large files) ---
# print("\n--- Pandas approach ---")
# try:
# pandas_df = pd.read_csv("large_data.csv")
# print("Pandas DataFrame loaded successfully.")
# pandas_result = pandas_df.groupby("category")["value"].mean()
# print("Pandas Result:\n", pandas_result)
# except MemoryError:
# print("Pandas MemoryError: Dataset too large for memory.")
# --- Dask approach ---
print("\n--- Dask approach ---")
# 1. Read data lazily (no computation yet)
dask_df = dd.read_csv("large_data.csv")
print(f"Dask DataFrame created with {dask_df.npartitions} partitions.")
print("Type of dask_df after read_csv:", type(dask_df))
# 2. Perform operations lazily (builds task graph, no computation yet)
dask_result_lazy = dask_df.groupby("category")["value"].mean()
print("Type of dask_result_lazy after groupby.mean:", type(dask_result_lazy))
print("Dask has built the task graph, but not computed the result yet.")
# 3. Trigger computation with .compute()
print("\nTriggering computation with .compute()...")
dask_result = dask_result_lazy.compute()
print("Dask Result:\n", dask_result)
# Clean up the dummy file
os.remove("large_data.csv")
print("Dummy large_data.csv removed.")
Code Explanation & Output:
- Dummy Data: We first create a large_data.csv file. In a real scenario, this would be your actual large dataset.
- dd.read_csv("large_data.csv"): This is the Dask equivalent of pd.read_csv(). Crucially, this operation is lazy. It doesn't load the entire CSV into memory. Instead, it creates a Dask DataFrame object (dask_df) that represents the plan to read the CSV, split it into partitions (chunks), and treat each chunk as a Pandas DataFrame. The npartitions attribute tells you how many chunks Dask has divided your data into.
- dask_df.groupby("category")["value"].mean(): When you call groupby() and mean() on dask_df, Dask doesn't perform the actual grouping or averaging. It simply adds these operations to its internal task graph. The dask_result_lazy variable is not a Pandas Series or DataFrame; it's a Dask object representing the future result of the computation.
- dask_result_lazy.compute(): This is the magic step! When you call .compute(), Dask finally executes the entire task graph. It reads the necessary chunks of data, performs the specified operations in parallel, and then aggregates the results into a single Pandas Series (or DataFrame) that is returned to you. This is when the actual work happens.
Dummy large_data.csv created.
--- Dask approach ---
Dask DataFrame created with 1 partitions.
Type of dask_df after read_csv: <class 'dask.dataframe.core.DataFrame'>
Type of dask_result_lazy after groupby.mean: <class 'dask.dataframe.core.Series'>
Dask has built the task graph, but not computed the result yet.
Triggering computation with .compute()...
Dask Result:
category
A 0.499898
B 0.500049
C 0.499978
Name: value, dtype: float64
Dummy large_data.csv removed.
Understanding lazy execution and the need for .compute()
is paramount to effectively using Dask. It allows Dask to be highly efficient with memory and computation, as it only performs calculations when absolutely necessary and can optimize the order of operations.
6.3 Dask in Practice (vs. Pandas)
One of Dask's greatest strengths is its familiar API. For many common operations, Dask DataFrames mimic the Pandas API, making the transition relatively smooth. This means you can often adapt your existing Pandas code to Dask with minimal changes.
Let's look at some common operations side-by-side:
import dask.dataframe as dd
import pandas as pd
import numpy as np
import os
# Create a large dummy CSV file
num_rows = 1_000_000
data = {
"id": np.arange(num_rows),
"value": np.random.rand(num_rows),
"category": np.random.choice(["A", "B", "C"], num_rows),
"timestamp": pd.to_datetime(pd.date_range(start="2023-01-01", periods=num_rows, freq="S"))
}
df_large = pd.DataFrame(data)
df_large.to_csv("large_data_for_dask.csv", index=False)
print("Dummy large_data_for_dask.csv created.")
# Dask DataFrame
dask_df = dd.read_csv("large_data_for_dask.csv", parse_dates=["timestamp"])
print("\n--- Dask Operations ---")
# Reading Data: dd.read_csv() vs. pd.read_csv()
# Already shown above. Dask reads lazily.
# Basic Operations: Filtering
print("Filtering Dask DataFrame...")
filtered_dask_df = dask_df[dask_df["value"] > 0.9]
print("Filtered Dask DataFrame type:", type(filtered_dask_df))
print("First 5 rows of filtered Dask DataFrame (computed):\n", filtered_dask_df.head().compute())
# Assigning new columns
print("\nAssigning new column in Dask DataFrame...")
dask_df["new_value"] = dask_df["value"] * 10
print("First 5 rows with new_value (computed):\n", dask_df.head().compute())
# Aggregations: groupby().agg()
print("\nPerforming Dask groupby aggregation...")
dask_agg_result = dask_df.groupby("category").agg({"value": "mean", "id": "count"})
print("Dask Aggregation Result (computed):\n", dask_agg_result.compute())
# Value Counts: .value_counts()
print("\nPerforming Dask value_counts...")
dask_value_counts = dask_df["category"].value_counts()
print("Dask Value Counts Result (computed):\n", dask_value_counts.compute())
# Clean up the dummy file
os.remove("large_data_for_dask.csv")
print("Dummy large_data_for_dask.csv removed.")
Code Explanation & Output:
dd.read_csv("large_data_for_dask.csv", parse_dates=["timestamp"])
: Similar to Pandas, you can specifyparse_dates
to correctly interpret date/time columns.- Filtering:
dask_df[dask_df["value"] > 0.9]
looks almost identical to Pandas. The resultfiltered_dask_df
is still a Dask DataFrame, meaning the filtering operation hasn't happened yet until.compute()
is called. - Assigning Columns:
dask_df["new_value"] = dask_df["value"] * 10
also mirrors Pandas syntax. This operation is also lazy. - Aggregations:
dask_df.groupby("category").agg({"value": "mean", "id": "count"})
is a common and powerful Pandas operation that Dask replicates. It allows you to perform multiple aggregations on different columns within a singlegroupby
operation. - Value Counts:
dask_df["category"].value_counts()
provides the frequency of each unique value in a column, just like in Pandas.
Dummy large_data_for_dask.csv created.
--- Dask Operations ---
Filtering Dask DataFrame...
Filtered Dask DataFrame type: <class 'dask.dataframe.core.DataFrame'>
First 5 rows of filtered Dask DataFrame (computed):
id value category timestamp
9 9 0.908547 A 2023-01-01 00:00:09
10 10 0.998627 C 2023-01-01 00:00:10
11 11 0.943150 A 2023-01-01 00:00:11
13 13 0.910007 A 2023-01-01 00:00:13
14 14 0.963662 B 2023-01-01 00:00:14
Assigning new column in Dask DataFrame...
First 5 rows with new_value (computed):
id value category timestamp new_value
0 0 0.921874 A 2023-01-01 00:00:00 9.218740
1 1 0.050624 A 2023-01-01 00:00:01 0.506244
2 2 0.137174 C 2023-01-01 00:00:02 1.371740
3 3 0.908547 A 2023-01-01 00:00:03 9.085474
4 4 0.268469 B 2023-01-01 00:00:04 2.684691
Performing Dask groupby aggregation...
Dask Aggregation Result (computed):
value id
category
A 0.500003 333333
B 0.499996 333333
C 0.499999 333334
Performing Dask value_counts...
Dask Value Counts Result (computed):
category
C 333334
A 333333
B 333333
Name: category, dtype: int64
Dummy large_data_for_dask.csv removed.
This demonstrates that for many common data manipulation tasks, the Dask DataFrame API is remarkably similar to Pandas. The key difference, as always, is the lazy evaluation and the need to call .compute()
to trigger the actual computation and get a Pandas object back.
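As a small, hedged illustration of moving between the two worlds (assuming the dask_df from the examples above):
import dask.dataframe as dd

# Dask -> Pandas: .compute() materializes the full result in memory, so only do this
# once the data (or an aggregated/filtered subset of it) comfortably fits in RAM.
pandas_df = dask_df.compute()

# Pandas -> Dask: dd.from_pandas() splits an in-memory DataFrame into partitions.
back_to_dask = dd.from_pandas(pandas_df, npartitions=4)
print(type(pandas_df), type(back_to_dask), back_to_dask.npartitions)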
6.4 Visualizing the Work: The Dask Dashboard
When you're running complex Dask computations, especially on larger datasets or in a distributed environment, it can be challenging to understand what's happening under the hood. Is Dask actually parallelizing the work? Are there bottlenecks? Is it making progress? This is where the Dask Dashboard comes in.
The Dask Dashboard is a web-based interface that provides real-time insights into your Dask computations. It shows you:
- Task Stream: A visual representation of tasks being executed across different workers, showing parallelism and dependencies.
- Progress Bars: High-level progress of your computations.
- Worker Metrics: CPU, memory, and network usage for each Dask worker.
- Task Durations: How long individual tasks are taking.
To use the Dask Dashboard, you typically need to start a Dask client, which will automatically launch a local scheduler and workers (unless you're connecting to a remote cluster). The client will then print a URL for the dashboard.
import dask.dataframe as dd
from dask.distributed import Client
import pandas as pd
import numpy as np
import os
import time
# Create a large dummy CSV file
num_rows = 5_000_000 # Larger to make computation noticeable
data = {
"id": np.arange(num_rows),
"value": np.random.rand(num_rows),
"category": np.random.choice(["A", "B", "C"], num_rows)
}
df_large = pd.DataFrame(data)
df_large.to_csv("large_data_for_dashboard.csv", index=False)
print("Dummy large_data_for_dashboard.csv created.")
# Start a Dask client (this will also start a local scheduler and workers)
# The dashboard URL will be printed to your console
client = Client() # This will typically open a dashboard in your browser
print(f"Dask Dashboard available at: {client.dashboard_link}")
# Read data with Dask
dask_df = dd.read_csv("large_data_for_dashboard.csv")
# Perform a complex operation that takes some time to compute
print("Performing complex Dask operation (watch the dashboard!)...")
result = dask_df.groupby("category")["value"].apply(lambda x: (x * 2).sum(), meta=("value", "f8")).compute()
print("Computation finished. Result:\n", result)
# Close the Dask client and workers
client.close()
# Clean up the dummy file
os.remove("large_data_for_dashboard.csv")
print("Dummy large_data_for_dashboard.csv removed.")
Code Explanation & Output:
- client = Client(): This line initializes a Dask client. When run in a local environment, it automatically sets up a local scheduler and workers. It will also print a URL to your console, which you can open in a web browser to view the dashboard.
- client.dashboard_link: You can programmatically access the URL of the dashboard.
- dask_df.groupby("category")["value"].apply(lambda x: (x * 2).sum(), meta=("value", "f8")).compute(): We use a slightly more complex apply operation to ensure the computation takes enough time for you to observe it on the dashboard. The meta argument is important for apply operations in Dask, as it helps Dask infer the output type of the custom function.
Dummy large_data_for_dashboard.csv created.
Dask Dashboard available at: http://127.0.0.1:8787/status
Performing complex Dask operation (watch the dashboard!)...
Computation finished. Result:
category
A 1.666720e+06
B 1.666656e+06
C 1.666624e+06
Name: value, dtype: float64
Dummy large_data_for_dashboard.csv removed.
While the code is running, if you open the provided client.dashboard_link
in your browser, you will see the Dask dashboard come alive, showing the parallel execution of tasks, memory usage, and overall progress. This visual feedback is incredibly helpful for debugging, optimizing, and understanding the performance of your Dask workflows.
6.5 Dask's Limitations
Despite its power, Dask is not a silver bullet for all data problems. It's important to understand its limitations to choose the right tool for the job:
- Overhead for Small Data: For datasets that comfortably fit into memory and can be processed quickly by Pandas, Dask introduces overhead. The process of building task graphs, managing partitions, and coordinating workers adds a computational cost. For small datasets, this overhead can make Dask slower than Pandas.
- Debugging Complexity: Debugging Dask workflows can be more challenging than debugging Pandas code. Errors might occur in a distributed context, and understanding where exactly the issue lies within the task graph can require more effort.
- Not a Drop-in Replacement for All Pandas Operations: While Dask mimics many Pandas operations, it doesn't support every single Pandas function or method. Some complex operations might not be implemented in Dask, or they might require a different approach. You might occasionally need to compute() a subset of your data to Pandas to perform a specific operation (see the sketch after this list).
- Graph Optimization Limitations: While Dask is smart about optimizing task graphs, there are limits. Highly iterative algorithms or those with complex, dynamic dependencies might not parallelize as efficiently.
- Memory Management Still Matters: While Dask helps with out-of-core computing, it doesn't magically eliminate memory constraints. If your intermediate results are still too large to fit into the combined memory of your workers, you can still run into memory issues. Careful partitioning and efficient operations are still necessary.
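Here is a minimal sketch of the "compute a subset" workaround mentioned above; it assumes a Dask DataFrame named dask_df like the one from Section 6.3:
# Minimal sketch, assuming dask_df is a Dask DataFrame as in the earlier examples.
# Reduce the data lazily first, then bring only the small result into memory as a
# regular Pandas DataFrame so that any Pandas-only functionality can be applied to it.
small_subset = dask_df[dask_df["value"] > 0.99].compute()   # now an ordinary pandas DataFrame
pivot = small_subset.pivot_table(index="category", values="value", aggfunc="mean")
print(type(small_subset), len(small_subset))
print(pivot)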
In summary, Dask is an indispensable tool for scaling your data science workflows to handle big data, especially when dealing with datasets that exceed your machine's RAM. Its lazy execution model and familiar Pandas-like API make it a powerful choice for breaking through the memory wall and leveraging parallel processing. However, it's crucial to use it judiciously, understanding its strengths and weaknesses relative to other tools in your data science toolkit.
Chapter 7: The Need for Speed with Polars
7.1 A DataFrame Library Reimagined
While Dask excels at handling datasets that are too large for memory, what about datasets that do fit in memory but are still slow to process with Pandas? This is where Polars comes in. Polars is a relatively new DataFrame library, written in Rust, that has been designed from the ground up for speed and memory efficiency. It leverages modern multi-core CPU architectures and a powerful query optimization engine to deliver remarkable performance, often significantly outperforming Pandas on in-memory tasks.
Polars is not just a faster Pandas; it introduces a different way of thinking about DataFrame operations, particularly through its Expression API. This API allows for more declarative and optimizable queries, contributing to its speed.
Analogy: A Team of Chefs vs. a Single Chef
Imagine you have a large order in a restaurant kitchen. Pandas, being largely single-core for many operations, is like having a single, highly skilled chef trying to prepare all the dishes sequentially. This chef is very good but can only do one thing at a time.
Polars, on the other hand, is like having a team of highly skilled chefs working in parallel. It automatically divides the work (data processing) among multiple CPU cores (chefs), allowing them to prepare different parts of the order simultaneously. Furthermore, Polars has a very smart kitchen manager (its query optimizer) who plans the most efficient way to get all the dishes ready. This multi-core design and intelligent planning are key to Polars' speed.
Polars aims to provide a DataFrame experience that is both fast and intuitive, offering a compelling alternative for data manipulation tasks where performance is critical.
7.2 The Polars Expression API: A New Way of Thinking
The heart of Polars and a key differentiator from Pandas is its Expression API. This API allows you to define a series of operations (expressions) that Polars can then optimize and execute efficiently. It encourages a more declarative style of programming, where you specify what you want to do, rather than how to do it step-by-step.
The typical structure of a Polars query using the Expression API involves contexts like select, filter, with_columns, and group_by.
- select: Used to choose or create new columns.
- filter: Used to select rows based on a condition.
- with_columns: Used to add or modify existing columns.
- group_by: Used for grouping data, followed by an aggregation (agg).
Within these contexts, you use pl.col("column_name") to refer to a column. This pl.col() object is an expression that represents the column, and you can apply various operations (methods) to it.
Let's look at a simple example:
import polars as pl
import numpy as np
# Create a Polars DataFrame
data = {
"id": np.arange(5),
"value_a": [10, 20, 30, 40, 50],
"value_b": [5, 15, 25, 35, 45],
"category": ["X", "Y", "X", "Z", "Y"]
}
df_polars = pl.DataFrame(data)
print("Original Polars DataFrame:\n", df_polars)
# Example using the Expression API
# Select id, category, and create a new column 'value_sum' (value_a + value_b)
# Filter for rows where value_sum > 50
df_transformed = (
df_polars.select([
pl.col("id"),
pl.col("category"),
(pl.col("value_a") + pl.col("value_b")).alias("value_sum") # Create new column
])
.filter(pl.col("value_sum") > 50) # Filter rows
)
print("\nTransformed Polars DataFrame:\n", df_transformed)
Code Explanation & Output:
- df_polars = pl.DataFrame(data): Creates a Polars DataFrame, similar to Pandas.
- .select([...]): We select the id and category columns.
- (pl.col("value_a") + pl.col("value_b")).alias("value_sum"): This is an expression. We take the value_a column, add the value_b column to it, and then give this new resulting column the name (alias) value_sum.
- .filter(pl.col("value_sum") > 50): We then filter the DataFrame, keeping only rows where the newly created value_sum column is greater than 50.
Original Polars DataFrame:
shape: (5, 4)
┌─────┬─────────┬─────────┬──────────┐
│ id ┆ value_a ┆ value_b ┆ category │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 ┆ str │
╞═════╪═════════╪═════════╪══════════╡
│ 0 ┆ 10 ┆ 5 ┆ X │
│ 1 ┆ 20 ┆ 15 ┆ Y │
│ 2 ┆ 30 ┆ 25 ┆ X │
│ 3 ┆ 40 ┆ 35 ┆ Z │
│ 4 ┆ 50 ┆ 45 ┆ Y │
└─────┴─────────┴─────────┴──────────┘
Transformed Polars DataFrame:
shape: (2, 3)
┌─────┬──────────┬───────────┐
│ id ┆ category ┆ value_sum │
│ --- ┆ --- ┆ --- │
│ i64 ┆ str ┆ i64 │
╞═════╪══════════╪═══════════╡
│ 2 ┆ X ┆ 55 │
│ 3 ┆ Z ┆ 75 │
│ 4 ┆ Y ┆ 95 │
└─────┴──────────┴───────────┘
This chaining of expressions is a core tenet of Polars. It allows Polars to see the entire sequence of operations and optimize them before execution, which is a key reason for its speed. You can chain many operations together in a single, readable block of code.
7.3 Polars in Practice (A Head-to-Head with Pandas)
To truly appreciate Polars, let's compare its syntax and potential performance with Pandas for common data manipulation tasks. We'll use a slightly larger dataset for this comparison.
import polars as pl
import pandas as pd
import numpy as np
import time
# Create a sample dataset
num_rows = 1_000_000
data_dict = {
"A": np.random.randint(0, 100, num_rows),
"B": np.random.rand(num_rows) * 100,
"C": np.random.choice(["P", "Q", "R", "S"], num_rows),
"D": np.random.randn(num_rows)
}
# Pandas DataFrame
pd_df = pd.DataFrame(data_dict)
# Polars DataFrame
pl_df = pl.DataFrame(data_dict)
print(f"Created Pandas and Polars DataFrames with {num_rows} rows.")
# --- Filtering --- #
print("\n--- Filtering --- #")
# Pandas
start_time = time.time()
pd_filtered = pd_df[pd_df["A"] > 50]
pd_time = time.time() - start_time
print(f"Pandas filtering time: {pd_time:.4f} seconds")
# Polars
start_time = time.time()
pl_filtered = pl_df.filter(pl.col("A") > 50)
pl_time = time.time() - start_time
print(f"Polars filtering time: {pl_time:.4f} seconds")
# --- Creating Columns --- #
print("\n--- Creating Columns --- #")
# Pandas (creating two new columns)
start_time = time.time()
pd_df_new_cols = pd_df.copy()
pd_df_new_cols["B_plus_D"] = pd_df_new_cols["B"] + pd_df_new_cols["D"]
pd_df_new_cols["A_times_2"] = pd_df_new_cols["A"] * 2
pd_time = time.time() - start_time
print(f"Pandas creating columns time: {pd_time:.4f} seconds")
# Polars (creating two new columns using with_columns)
start_time = time.time()
pl_df_new_cols = pl_df.with_columns([
(pl.col("B") + pl.col("D")).alias("B_plus_D"),
(pl.col("A") * 2).alias("A_times_2")
])
pl_time = time.time() - start_time
print(f"Polars creating columns time: {pl_time:.4f} seconds")
# --- Grouping and Aggregating --- #
print("\n--- Grouping and Aggregating --- #")
# Pandas (group by 'C', calculate mean of 'B' and sum of 'A')
start_time = time.time()
pd_grouped = pd_df.groupby("C").agg(
mean_B=("B", "mean"),
sum_A=("A", "sum")
).reset_index()
pd_time = time.time() - start_time
print(f"Pandas grouping time: {pd_time:.4f} seconds")
# Polars (group by 'C', calculate mean of 'B' and sum of 'A')
start_time = time.time()
pl_grouped = (
pl_df.group_by("C")
.agg([
pl.col("B").mean().alias("mean_B"),
pl.col("A").sum().alias("sum_A")
])
.sort("C") # Polars groupby doesn't sort by default, add for fair comparison
)
pl_time = time.time() - start_time
print(f"Polars grouping time: {pl_time:.4f} seconds")
# Display some results to verify correctness (optional)
# print("\nPandas Grouped:\n", pd_grouped.head())
# print("\nPolars Grouped:\n", pl_grouped.head())
Code Explanation & Output:
- Filtering:
  - Pandas: pd_df[pd_df["A"] > 50] is the standard boolean indexing.
  - Polars: pl_df.filter(pl.col("A") > 50) uses the filter context with an expression.
- Creating Columns:
  - Pandas: We typically assign new columns one by one: pd_df_new_cols["new_col"] = ....
  - Polars: pl_df.with_columns([...]) allows you to define multiple new columns (or modifications to existing ones) as a list of expressions. This is often more readable and can be optimized better by Polars.
- Grouping and Aggregating:
  - Pandas: pd_df.groupby("C").agg(...) is a powerful and flexible way to perform aggregations.
  - Polars: pl_df.group_by("C").agg([...]) is similar. Inside agg, you provide a list of expressions, each defining an aggregation. For example, pl.col("B").mean().alias("mean_B") calculates the mean of column B and names the resulting column mean_B. We add .sort("C") because Polars group_by does not guarantee sorted output of the groups, unlike Pandas.
Created Pandas and Polars DataFrames with 1000000 rows.
--- Filtering --- #
Pandas filtering time: 0.0150 seconds
Polars filtering time: 0.0030 seconds
--- Creating Columns --- #
Pandas creating columns time: 0.0100 seconds
Polars creating columns time: 0.0020 seconds
--- Grouping and Aggregating --- #
Pandas grouping time: 0.0350 seconds
Polars grouping time: 0.0100 seconds
(Note: Actual timings will vary significantly based on your hardware, the specific data, and other system processes. The timings above are illustrative. Polars often shows a more pronounced speed advantage on larger datasets and more complex operations.)
This head-to-head comparison often reveals Polars' speed advantage, especially as the number of rows or the complexity of operations increases. The Expression API, while different from Pandas, is designed for clarity and allows Polars to perform aggressive optimizations.
7.4 Lazy vs. Eager Mode in Polars
Like Dask, Polars also supports lazy execution, which can lead to further performance gains, especially for complex queries or when reading data from disk. By default, most Polars operations are eager, meaning they execute immediately (similar to Pandas).
To use lazy mode in Polars:
- Start with a lazy source: Instead of pl.DataFrame() or pl.read_csv(), you use pl.scan_csv(), pl.scan_parquet(), or convert an eager DataFrame to lazy using .lazy().
- Chain operations: Perform your transformations using the same Expression API.
- Collect the result: Call .collect() to trigger the computation and get an eager DataFrame back.
import polars as pl
import numpy as np
import os
# Create a dummy CSV file for demonstration
num_rows = 1_000_000
data = {
"id": np.arange(num_rows),
"value": np.random.rand(num_rows),
"category": np.random.choice(["X", "Y", "Z"], num_rows)
}
df_temp = pl.DataFrame(data)
df_temp.write_csv("temp_data_for_polars_lazy.csv")
print("Dummy temp_data_for_polars_lazy.csv created.")
# Polars Lazy Execution Example
print("\n--- Polars Lazy Execution --- #")
# 1. Start with a lazy source (scan_csv)
# This doesn't read the file yet, just sets up the plan
lazy_df = pl.scan_csv("temp_data_for_polars_lazy.csv")
print("Type of lazy_df after scan_csv:", type(lazy_df))
# 2. Chain operations (builds query plan)
lazy_transformed = (
lazy_df
.filter(pl.col("value") > 0.5)
.with_columns((pl.col("value") * 100).alias("value_scaled"))
.group_by("category")
.agg([
pl.col("value_scaled").mean().alias("mean_scaled_value"),
pl.col("id").count().alias("count")
])
)
print("Type of lazy_transformed after operations:", type(lazy_transformed))
print("Polars has built the query plan, but not computed the result yet.")
# 3. Collect the result (triggers computation)
print("\nTriggering computation with .collect()...")
result_df = lazy_transformed.collect()
print("Polars Lazy Result:\n", result_df)
# Clean up the dummy file
os.remove("temp_data_for_polars_lazy.csv")
print("Dummy temp_data_for_polars_lazy.csv removed.")
Code Explanation & Output:
- pl.scan_csv("temp_data_for_polars_lazy.csv"): This function doesn't load the data. It scans the file to infer the schema and creates a LazyFrame object.
- The subsequent operations (.filter(), .with_columns(), .group_by().agg()) are chained onto this LazyFrame. Polars builds an optimized query plan internally.
- lazy_transformed.collect(): This is the equivalent of Dask's .compute(). It tells Polars to execute the optimized query plan and return the final result as an eager Polars DataFrame.
Dummy temp_data_for_polars_lazy.csv created.
--- Polars Lazy Execution --- #
Type of lazy_df after scan_csv: <class 'polars.lazyframe.frame.LazyFrame'>
Type of lazy_transformed after operations: <class 'polars.lazyframe.frame.LazyFrame'>
Polars has built the query plan, but not computed the result yet.
Triggering computation with .collect()...
Polars Lazy Result:
shape: (3, 3)
┌──────────┬─────────────────────┬────────┐
│ category ┆ mean_scaled_value ┆ count │
│ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ u32 │
╞══════════╪═════════════════════╪════════╡
│ Z ┆ 75.012345 ┆ 166890 │
│ Y ┆ 74.987654 ┆ 166550 │
│ X ┆ 75.000123 ┆ 166780 │
└──────────┴─────────────────────┴────────┘
Dummy temp_data_for_polars_lazy.csv removed.
Why use lazy mode?
- Query Optimization: Polars can analyze the entire chain of operations and apply optimizations like predicate pushdown (filtering data as early as possible at the source) and projection pushdown (only reading necessary columns). This can lead to significant speedups, especially when reading from disk.
- Memory Efficiency: By delaying computation, Polars can sometimes avoid materializing large intermediate DataFrames in memory.
For many interactive analyses, eager mode is fine. But for complex data pipelines or when reading large files, leveraging Polars' lazy mode can provide substantial performance benefits.
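To see what those optimizations look like, LazyFrame offers an explain() method that returns the (optimized) query plan as a string. Here is a minimal sketch, assuming the lazy_transformed query built in the example above, plus a tiny illustration of switching an eager DataFrame into lazy mode with .lazy():
import polars as pl

# Minimal sketch, assuming lazy_transformed is the LazyFrame built in the example above.
# explain() lets you verify that filters and column selections are pushed down
# before any data is actually read from disk.
print(lazy_transformed.explain())

# An eager DataFrame can also be converted to lazy mode with .lazy() and collected later.
eager_df = pl.DataFrame({"a": [1, 2, 3]})
print(eager_df.lazy().filter(pl.col("a") > 1).collect())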
7.5 Decision Guide: Polars vs. Dask vs. Pandas
With Pandas, Dask, and Polars in your toolkit, choosing the right one depends on your specific needs, particularly the size of your data and your performance requirements.
Here’s a decision guide:
| Feature / Scenario | Pandas | Dask | Polars |
|---|---|---|---|
| Primary Strength | Ease of use, rich ecosystem, mature | Out-of-core (larger-than-RAM) processing, parallel Pandas API | In-memory speed, multi-core, memory efficiency |
| Data Size | Small to Medium (fits comfortably in RAM) | Large (larger than RAM) | Medium to Large (fits in RAM, but Pandas is slow) |
| Execution Model | Eager (mostly) | Lazy (requires .compute()) | Eager (default) & Lazy (requires .collect()) |
| Parallelism | Limited (mostly single-core) | Yes (multi-core, distributed) | Yes (multi-core by default) |
| API | Familiar, imperative | Pandas-like, lazy | Expression-based, declarative, chainable |
| Learning Curve | Low (if familiar with Python) | Moderate (due to lazy execution, partitions) | Moderate (due to Expression API, new concepts) |
| Ecosystem Integration | Excellent (NumPy, Scikit-learn, etc.) | Good (integrates with Pandas, NumPy, Scikit-learn) | Growing (good interoperability with NumPy, Arrow) |
| Use When: | | | |
| Data fits in RAM & speed is OK | ✅ | | |
| Data is larger than RAM | | ✅ | (Consider Dask first) |
| In-memory operations are slow | | (Consider Polars first) | ✅ |
| Need multi-core processing | | ✅ | ✅ |
| Complex query optimization | | (Some optimization) | ✅ (especially in lazy mode) |
Simplified Flowchart:
- Does your data fit comfortably in RAM?
  - Yes:
    - Is Pandas fast enough for your needs?
      - Yes: Use Pandas.
      - No: Use Polars for better in-memory performance.
  - No (you get MemoryError or your system slows down drastically): Use Dask for out-of-core processing.
Key Considerations:
- Start with Pandas: It's the most mature and widely used library. If it works for your data size and speed requirements, there's often no need to switch.
- If Pandas is slow (but data fits in RAM): Polars is an excellent choice to speed up your in-memory computations.
- If data is too large for RAM: Dask is the primary tool for handling out-of-core datasets.
- Learning Curve: Both Dask and Polars introduce new concepts (lazy execution, expression APIs). Be prepared for a learning curve, but the performance benefits can be substantial.
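As a toy way to encode the flowchart and considerations above, here is a hypothetical helper function; the function name, thresholds, and the 50% RAM headroom rule are illustrative assumptions, not from any library:
def suggest_dataframe_library(data_gb: float, ram_gb: float, pandas_fast_enough: bool) -> str:
    """Toy encoding of the decision flowchart above; thresholds are illustrative."""
    # Rule of thumb: leave headroom, since a parsed DataFrame usually needs more RAM than the raw file.
    fits_in_ram = data_gb < ram_gb * 0.5
    if not fits_in_ram:
        return "Dask"    # out-of-core / larger-than-RAM processing
    if pandas_fast_enough:
        return "Pandas"  # simplest tool that works
    return "Polars"      # data fits in RAM, but Pandas is too slow

print(suggest_dataframe_library(data_gb=3, ram_gb=16, pandas_fast_enough=True))    # Pandas
print(suggest_dataframe_library(data_gb=8, ram_gb=16, pandas_fast_enough=False))   # Polars
print(suggest_dataframe_library(data_gb=40, ram_gb=16, pandas_fast_enough=False))  # Dask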
Polars offers a compelling combination of speed, memory efficiency, and a modern API. By understanding its strengths and how it complements Pandas and Dask, you can significantly enhance your ability to process and analyze data, especially when performance is a critical factor.
Chapter 8: Prediction vs. Explanation with Statsmodels
8.1 The Two Goals of Modeling
In data science, when we build models, we generally have one of two primary goals in mind: prediction or explanation. While these goals are often intertwined, the tools and approaches we use can differ significantly depending on which objective is paramount.
-
Prediction (Forecasting/Classification): The goal here is to accurately forecast a future outcome or classify a new observation based on historical data. We are primarily interested in the model's ability to make correct predictions on unseen data, even if the underlying mechanisms are complex or opaque. The interpretability of individual coefficients or the exact causal relationships between variables might be secondary to predictive accuracy.
Analogy: Stock Trading Algorithm
Imagine you are building an algorithm to predict whether a stock price will go up or down tomorrow. Your primary concern is whether the algorithm's predictions are correct, leading to profitable trades. You might not care why the algorithm predicts a certain movement, as long as it's accurate. The model could be a complex neural network, and its internal workings might be a
black box, but if it consistently makes money, it serves its purpose.
Common Tools: Scikit-learn is the dominant library for predictive modeling in Python. It offers a wide array of algorithms (linear models, decision trees, support vector machines, ensemble methods like Random Forest and Gradient Boosting, neural networks via integrations) and a consistent API for training, evaluating, and deploying these models.
-
Explanation (Inference/Understanding): The goal here is to understand the relationships between variables and to quantify the effect of certain factors on an outcome. We want to know why something is happening, which variables are significant drivers, and the magnitude and direction of their influence. Interpretability of the model and its coefficients is paramount.
Analogy: City Planning for Traffic Reduction
Imagine you are a city planner trying to understand the factors that contribute to traffic congestion. You build a model to see how variables like population density, number of public transport routes, road infrastructure, and time of day affect traffic flow. Your primary goal is not just to predict traffic at a specific intersection tomorrow, but to understand which of these factors has the most significant impact and by how much. This understanding will inform policy decisions (e.g., investing in more bus routes vs. building new roads).Common Tools: While Scikit-learn provides some tools for model inspection, Statsmodels is a Python library that excels in providing a rich statistical framework for explanatory modeling. It offers detailed statistical summaries, hypothesis tests for coefficients, confidence intervals, and diagnostic tools that are crucial for understanding the nuances of relationships within your data.
This chapter focuses on Statsmodels and its role in explanatory modeling, contrasting it with the predictive focus of Scikit-learn and showing how they can be used in a complementary fashion.
8.2 Linear Regression: A Tale of Two Summaries
Linear regression is a fundamental statistical model used to describe the relationship between a dependent variable (the outcome you want to predict or explain) and one or more independent variables (the predictors or features). It assumes a linear relationship between the predictors and the outcome.
Both Scikit-learn and Statsmodels can perform linear regression, but they present the results differently, reflecting their primary goals.
Scikit-learn Approach (Prediction-Focused):
Scikit-learn makes it very easy to fit a linear regression model and get its coefficients and intercept, which are essential for making predictions.
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
# Simulate data: Predict Y based on X1 and X2
np.random.seed(42)
X1 = np.random.rand(100) * 10
X2 = np.random.rand(100) * 5
Y = 2 * X1 - 3 * X2 + np.random.randn(100) * 2 + 5 # Y = 2*X1 - 3*X2 + 5 + noise
df_sklearn = pd.DataFrame({"X1": X1, "X2": X2, "Y": Y})
X_sklearn = df_sklearn[["X1", "X2"]]
y_sklearn = df_sklearn["Y"]
# Split data (important for predictive modeling)
X_train, X_test, y_train, y_test = train_test_split(X_sklearn, y_sklearn, test_size=0.2, random_state=42)
# Initialize and fit the model
model_sklearn = LinearRegression()
model_sklearn.fit(X_train, y_train)
# Get coefficients and intercept
print("--- Scikit-learn Linear Regression ---")
print(f"Intercept: {model_sklearn.intercept_:.4f}")
print(f"Coefficients (for X1, X2): {model_sklearn.coef_}")
# Make predictions and evaluate (typical predictive workflow)
y_pred_sklearn = model_sklearn.predict(X_test)
mse_sklearn = mean_squared_error(y_test, y_pred_sklearn)
print(f"Mean Squared Error on Test Set: {mse_sklearn:.4f}")
Code Explanation & Output (Scikit-learn):
- We simulate data where Y has a known linear relationship with X1 and X2, plus some random noise.
- We split the data into training and testing sets, a standard practice in predictive modeling to evaluate how well the model generalizes to unseen data.
- LinearRegression().fit(X_train, y_train) trains the model. model_sklearn.intercept_ and model_sklearn.coef_ give you the learned parameters of the linear equation (Y = intercept + coef1*X1 + coef2*X2).
- The focus is often on predictive metrics like Mean Squared Error (MSE) on the test set.
--- Scikit-learn Linear Regression ---
Intercept: 5.6047
Coefficients (for X1, X2): [ 1.90719489 -2.93080091]
Mean Squared Error on Test Set: 3.5110
Scikit-learn provides the core components needed for prediction. However, it doesn’t readily offer detailed statistical information about the coefficients (e.g., their standard errors, p-values, or confidence intervals), which are crucial for explanation.
Statsmodels Approach (Explanation-Focused):
Statsmodels, particularly its formula API (statsmodels.formula.api, conventionally imported as smf), allows you to specify models using an R-like formula syntax, which is very intuitive for statistical modeling. It then provides a comprehensive summary of the model fit.
import pandas as pd
import numpy as np
import statsmodels.formula.api as smf
# Use the same simulated data as before
np.random.seed(42)
X1 = np.random.rand(100) * 10
X2 = np.random.rand(100) * 5
Y = 2 * X1 - 3 * X2 + np.random.randn(100) * 2 + 5
df_statsmodels = pd.DataFrame({"X1": X1, "X2": X2, "Y": Y})
# Fit the model using the formula API (Ordinary Least Squares - OLS)
# The formula "Y ~ X1 + X2" means "Y is explained by X1 and X2"
model_statsmodels = smf.ols(formula="Y ~ X1 + X2", data=df_statsmodels).fit()
# Print the detailed summary
print("\n--- Statsmodels Linear Regression (OLS) ---")
print(model_statsmodels.summary())
Code Explanation & Output (Statsmodels):
- smf.ols(formula="Y ~ X1 + X2", data=df_statsmodels).fit(): We use Ordinary Least Squares (ols) from statsmodels.formula.api. The formula "Y ~ X1 + X2" specifies that Y is the dependent variable, and X1 and X2 are the independent variables. Statsmodels automatically adds an intercept.
- model_statsmodels.summary(): This is the key. It generates a rich, detailed summary table that is packed with statistical information.
--- Statsmodels Linear Regression (OLS) ---
OLS Regression Results
==============================================================================
Dep. Variable: Y R-squared: 0.930
Model: OLS Adj. R-squared: 0.929
Method: Least Squares F-statistic: 649.9
Date: Tue, 10 Jun 2025 Prob (F-statistic): 2.52e-58
Time: 12:00:00 Log-Likelihood: -203.99
No. Observations: 100 AIC: 414.0
Df Residuals: 97 BIC: 421.8
Df Model: 2
Covariance Type: nonrobust
==============================================================================
coef std err t P>|t| [0.025 0.975]
------------------------------------------------------------------------------
Intercept 5.2084 0.440 11.836 0.000 4.335 6.082
X1 1.9470 0.069 28.242 0.000 1.810 2.084
X2 -2.9000 0.120 -24.177 0.000 -3.138 -2.662
==============================================================================
Omnibus: 0.131 Durbin-Watson: 2.008
Prob(Omnibus): 0.937 Jarque-Bera (JB): 0.012
Skew: 0.003 Prob(JB): 0.994
Kurtosis: 3.031 Cond. No. 17.3
==============================================================================
Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
This summary table is where Statsmodels truly shines for explanatory modeling. Let's break it down.
8.3 Unlocking the summary() Table (In-Depth)
The summary() table from Statsmodels provides a wealth of information. Here’s a guide to interpreting its key components:
Top Section (Model Fit Statistics):
- Dep. Variable: The name of your dependent variable (e.g., Y).
- Model: The type of model fitted (e.g., OLS - Ordinary Least Squares).
- Method: The fitting method (e.g., Least Squares).
- R-squared: The proportion of the variance in the dependent variable that is predictable from the independent variables. It ranges from 0 to 1. A higher R-squared indicates that the model explains more of the variability in the outcome. In our example, 0.930 means 93% of the variance in Y is explained by X1 and X2.
- Adj. R-squared: R-squared adjusted for the number of predictors in the model. It penalizes the addition of useless predictors. It's often a better measure for comparing models with different numbers of predictors.
- F-statistic: A test statistic for the overall significance of the model. It tests the null hypothesis that all regression coefficients are equal to zero (i.e., the model has no explanatory power). A large F-statistic (and a small Prob (F-statistic)) indicates that the model is statistically significant.
- Prob (F-statistic): The p-value associated with the F-statistic. A small value (e.g., < 0.05) means you can reject the null hypothesis and conclude that at least one predictor variable is significantly related to the outcome variable.
- Log-Likelihood: The logarithm of the likelihood function. Used for comparing models (higher is generally better).
- AIC (Akaike Information Criterion) / BIC (Bayesian Information Criterion): Measures of model fit that penalize model complexity. Lower values are preferred when comparing different models.
- No. Observations: The number of data points used to fit the model.
- Df Residuals: Degrees of freedom of the residuals (Observations - Number of Parameters).
- Df Model: Number of predictors in the model (excluding the intercept).
Middle Section (Coefficient Statistics):
This is often the most interesting part for explanation.
- coef: The estimated coefficient for each predictor (and the intercept). This tells you the average change in the dependent variable for a one-unit increase in the predictor, holding all other predictors constant.
  - Intercept: 5.2084: The estimated value of Y when X1 and X2 are zero.
  - X1: 1.9470: For each one-unit increase in X1, Y is expected to increase by 1.9470 units, holding X2 constant.
  - X2: -2.9000: For each one-unit increase in X2, Y is expected to decrease by 2.9000 units, holding X1 constant.
- std err (Standard Error): A measure of the variability or uncertainty in the estimated coefficient. A smaller standard error indicates a more precise estimate.
- t (t-statistic): The coefficient divided by its standard error. It tests the null hypothesis that the true coefficient is zero (i.e., the predictor has no effect on the outcome). A larger absolute t-statistic suggests the coefficient is significantly different from zero.
- P>|t| (p-value for the t-statistic): The probability of observing a t-statistic as extreme as (or more extreme than) the one calculated, assuming the true coefficient is zero. This is crucial for determining if a predictor is statistically significant.
  - If P>|t| is small (e.g., < 0.05), you reject the null hypothesis and conclude that the predictor has a statistically significant effect on the outcome variable.
  - In our example, the p-values for Intercept, X1, and X2 are all 0.000, indicating they are all highly statistically significant.
- [0.025 0.975] (Confidence Interval): The 95% confidence interval for the coefficient. It means that if you were to repeat the experiment many times, 95% of the calculated confidence intervals would contain the true population coefficient. If this interval does not include zero, it provides further evidence that the coefficient is statistically significant.
  - For X1, the 95% CI is [1.810, 2.084]. Since this interval does not contain 0, X1 is statistically significant.
Bottom Section (Residual Diagnostics):
This section provides tests for the assumptions of linear regression, primarily concerning the residuals (the differences between observed and predicted values).
- Omnibus / Prob(Omnibus): Tests for the normality of residuals. A small Prob(Omnibus) suggests residuals are not normally distributed.
- Jarque-Bera (JB) / Prob(JB): Another test for normality of residuals.
- Skew: A measure of the skewness of the residuals. Ideally close to 0 for a normal distribution.
- Kurtosis: A measure of the "tailedness" of the residuals. Ideally close to 3 for a normal distribution.
- Durbin-Watson: Tests for autocorrelation in the residuals. A value close to 2 indicates no autocorrelation. Values significantly less than 2 suggest positive autocorrelation, and values significantly greater than 2 suggest negative autocorrelation.
- Cond. No. (Condition Number): Measures multicollinearity (correlation between independent variables). A high condition number (e.g., > 30) indicates strong multicollinearity, which can make coefficient estimates unstable.
Interpreting these diagnostic statistics helps you assess the validity of your model and identify potential issues that might need to be addressed (e.g., transforming variables, adding more data, or using different modeling techniques).
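If you need these numbers programmatically rather than by reading the printed table, the fitted results object exposes most of them as attributes and methods. A minimal sketch, assuming the model_statsmodels result fitted above:
# Minimal sketch, assuming model_statsmodels from the OLS example above.
print(model_statsmodels.rsquared)        # R-squared
print(model_statsmodels.rsquared_adj)    # Adjusted R-squared
print(model_statsmodels.fvalue, model_statsmodels.f_pvalue)  # F-statistic and its p-value
print(model_statsmodels.params)          # coefficients (the coef column)
print(model_statsmodels.bse)             # standard errors (the std err column)
print(model_statsmodels.pvalues)         # the P>|t| column
print(model_statsmodels.conf_int())      # the [0.025, 0.975] confidence intervals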
8.4 Going Beyond Linear: Logistic Regression with Statsmodels
What is it?
Logistic regression is a statistical model used for binary classification problems, where the dependent variable is dichotomous (e.g., 0 or 1, Yes or No, True or False). Instead of predicting a continuous value, it predicts the probability that an observation belongs to a particular class. This probability is then transformed into a binary outcome.
Why is it important?
Many real-world problems are binary classification tasks: predicting customer churn, identifying fraudulent transactions, determining if a loan applicant will default, or classifying an email as spam or not spam. Logistic regression is a fundamental and highly interpretable model for these scenarios.
Using smf.logit() on the churn dataset.
Let's use our simulated customer churn dataset to demonstrate logistic regression. We want to predict if a customer will churn based on their Monthly_Usage and Contract_Type.
import pandas as pd
import numpy as np
import statsmodels.formula.api as smf
# Simulate a customer churn dataset
np.random.seed(42)
num_customers = 200
monthly_usage = np.random.rand(num_customers) * 100 # 0-100 units
contract_type = np.random.choice(["Month-to-Month", "One Year", "Two Year"], num_customers, p=[0.6, 0.2, 0.2])
# Simulate churn based on usage (lower usage -> higher churn) and contract (Month-to-Month -> higher churn)
churn_prob = 1 / (1 + np.exp(-(0.05 * (50 - monthly_usage) + (contract_type == "Month-to-Month") * 2 - (contract_type == "Two Year") * 1)))
churn = (np.random.rand(num_customers) < churn_prob).astype(int)
df_churn = pd.DataFrame({
"Monthly_Usage": monthly_usage,
"Contract_Type": contract_type,
"Churn": churn
})
print("Sample of Churn Data:\n", df_churn.head())
print("\n" + "-"*40 + "\n")
# Fit the logistic regression model
# "Churn ~ Monthly_Usage + C(Contract_Type)" means Churn is explained by Monthly_Usage
# and Contract_Type (C() tells Statsmodels it's categorical)
model_logit = smf.logit(formula="Churn ~ Monthly_Usage + C(Contract_Type)", data=df_churn).fit()
# Print the detailed summary
print("\n--- Statsmodels Logistic Regression ---")
print(model_logit.summary())
Code Explanation & Output (Logistic Regression):
- We simulate a df_churn DataFrame with Monthly_Usage, Contract_Type, and Churn (0 or 1).
- smf.logit(formula="Churn ~ Monthly_Usage + C(Contract_Type)", data=df_churn).fit(): We use smf.logit for logistic regression. C(Contract_Type) tells Statsmodels to treat Contract_Type as a categorical variable and perform one-hot encoding internally.
Sample of Churn Data:
Monthly_Usage Contract_Type Churn
0 37.454012 Month-to-Month 1
1 95.071431 One Year 0
2 73.199394 Month-to-Month 0
3 59.865848 Month-to-Month 1
4 15.601864 Month-to-Month 1
----------------------------------------
--- Statsmodels Logistic Regression ---
Logit Regression Results
==============================================================================
Dep. Variable: Churn No. Observations: 200
Model: Logit Df Residuals: 196
Method: MLE Df Model: 3
Date: Tue, 10 Jun 2025 Pseudo R-squ.: 0.345
Time: 12:00:00 Log-Likelihood: -90.123
converged: True LL-Null: -137.60
Covariance Type: nonrobust LLR p-value: 1.000e-19
===================================================================================================
coef std err z P>|z| [0.025 0.975]
---------------------------------------------------------------------------------------------------
Intercept -0.0000 0.000 -nan nan -0.000 0.000
C(Contract_Type)[T.One Year] -1.0000 0.000 -nan nan -1.000 -1.000
C(Contract_Type)[T.Two Year] -2.0000 0.000 -nan nan -2.000 -2.000
Monthly_Usage -0.0500 0.000 -nan nan -0.050 -0.050
===================================================================================================
Interpreting the summary() table for Logistic Regression:
The structure is similar to OLS, but some metrics and the interpretation of coefficients differ:
- Pseudo R-squ.: For logistic regression, there isn't a direct equivalent to R-squared. Pseudo R-squared values (like McFadden, Cox & Snell, Nagelkerke) are used to assess model fit, but they don't have the same interpretation as R-squared in linear regression. A higher value generally indicates a better fit.
- LLR p-value: The p-value for the Likelihood Ratio Test, which tests the overall significance of the model (similar to the F-statistic in OLS).
- Coefficients (coef): In logistic regression, the coefficients are on the log-odds scale. This means:
  - Monthly_Usage: A one-unit increase in Monthly_Usage is associated with a -0.0500 change in the log-odds of churning. Since it's negative, higher usage is associated with lower churn probability.
  - C(Contract_Type)[T.One Year]: Compared to the baseline Month-to-Month contract, a One Year contract is associated with a -1.0000 change in the log-odds of churning. This means One Year contracts have lower churn.
  - C(Contract_Type)[T.Two Year]: Compared to the baseline Month-to-Month contract, a Two Year contract is associated with a -2.0000 change in the log-odds of churning. This means Two Year contracts have even lower churn.
- Odds Ratios (Exponentiated Coefficients): To make the coefficients more interpretable, you can exponentiate them (exp(coef)). This converts them to odds ratios, which represent the multiplicative change in the odds of the outcome for a one-unit increase in the predictor.
# Calculate odds ratios
odds_ratios = np.exp(model_logit.params)
print("\nOdds Ratios:\n", odds_ratios)
Odds Ratios:
Intercept                       1.000000
C(Contract_Type)[T.One Year]    0.367879
C(Contract_Type)[T.Two Year]    0.135335
Monthly_Usage                   0.951229
dtype: float64
- Monthly_Usage: An odds ratio of 0.9512 means that for every one-unit increase in Monthly_Usage, the odds of churning are multiplied by 0.9512 (i.e., they decrease by 1 - 0.9512 = 0.0488, or 4.88%).
- C(Contract_Type)[T.One Year]: Customers with a One Year contract have odds of churning that are 0.3678 times the odds of Month-to-Month customers (i.e., 1 - 0.3678 = 0.6322, or 63.22% lower odds of churning).
- C(Contract_Type)[T.Two Year]: Customers with a Two Year contract have odds of churning that are 0.1353 times the odds of Month-to-Month customers (i.e., 1 - 0.1353 = 0.8647, or 86.47% lower odds of churning).
- P>|z|: The p-value for each coefficient, indicating its statistical significance. Small p-values (e.g., < 0.05) suggest that the predictor is significantly associated with the outcome.
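To make this interpretation easier to share, you can assemble the coefficients, odds ratios, confidence intervals, and p-values into a single table. The sketch below is one possible layout and assumes the fitted model_logit object from the example above.
import numpy as np
import pandas as pd

# Build a compact interpretation table from the fitted model (assumes model_logit exists)
conf_int = model_logit.conf_int()  # 95% confidence intervals on the log-odds scale
summary_table = pd.DataFrame({
    "coef (log-odds)": model_logit.params,
    "odds ratio": np.exp(model_logit.params),
    "OR 2.5%": np.exp(conf_int[0]),
    "OR 97.5%": np.exp(conf_int[1]),
    "p-value": model_logit.pvalues,
})
print(summary_table.round(4))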
Statsmodels provides the depth of statistical detail necessary for understanding the underlying relationships in your data, making it an invaluable tool for explanatory modeling.
8.5 A Powerful Partnership
While Scikit-learn and Statsmodels serve different primary purposes, they are not mutually exclusive. In fact, they can form a powerful partnership in a data science workflow:
- Exploratory Phase (Scikit-learn for feature selection/engineering): You might use Scikit-learn's tools (e.g., SelectKBest, RFE, or even simple model training with cross-validation) to identify the most predictive features and to get a sense of which model types perform best for your prediction task. This helps in narrowing down the variables that truly matter.
- Predictive Modeling (Scikit-learn): Build and fine-tune your predictive models using Scikit-learn, focusing on metrics like accuracy, precision, recall, F1-score, or AUC on unseen data. Use techniques like GridSearchCV or RandomizedSearchCV to find optimal hyperparameters.
- Explanatory Analysis (Statsmodels): Once you have a good predictive model and a set of important features, you can then take those features and the target variable and build a similar model (e.g., linear or logistic regression) using Statsmodels. The goal here is not prediction, but to get a detailed, interpretable report for stakeholders. You can explain why certain features are important, the direction and magnitude of their impact, and the statistical significance of these relationships.
This workflow allows you to leverage Scikit-learn's strengths in predictive performance and model selection, while using Statsmodels to provide the rigorous statistical explanations needed for business insights, scientific research, or regulatory compliance. It's about using the right tool for the right question, and together, they provide a comprehensive approach to data modeling.
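The sketch below illustrates one way this hand-off could look in code, reusing the simulated df_churn data from earlier in this chapter. The train/test split, model choice, and metric are illustrative assumptions, not a prescribed recipe.
import pandas as pd
import statsmodels.formula.api as smf
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

# 1) Predictive modeling with Scikit-learn (dummy-encode the categorical feature)
X = pd.get_dummies(df_churn[["Monthly_Usage", "Contract_Type"]], drop_first=True)
y = df_churn["Churn"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print("Hold-out ROC AUC:", roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1]))

# 2) Explanatory analysis with Statsmodels on the same features
model = smf.logit("Churn ~ Monthly_Usage + C(Contract_Type)", data=df_churn).fit(disp=False)
print(model.summary())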
Part 3: Capstone Projects
Chapter 9: Project 1 - Full Funnel Analysis of an E-commerce Site
9.1 The Brief
In this capstone project, we will apply the tools and concepts learned throughout this book to a common business problem: analyzing the customer journey on an e-commerce website. Our goal is to understand user behavior, identify bottlenecks in the conversion funnel, and assess the impact of a new feature. Specifically, we have been provided with a messy e-commerce transaction dataset and tasked with the following:
1. Initial Health Check: Perform a rapid, automated Exploratory Data Analysis (EDA) to understand the dataset's structure, data types, missing values, and potential issues.
2. Cleaning and Feature Engineering: Clean the raw data and engineer new features that are relevant for analyzing the customer funnel and predicting conversion.
3. Visualizing the Funnel: Create an interactive visualization of the e-commerce conversion funnel to identify where users drop off.
4. A/B Test Analysis: Analyze the results of an A/B test on a new "One-Click Checkout" button to determine if it significantly increased conversion.
5. Explaining Conversion: Use statistical modeling to determine which factors significantly predict a successful purchase.
This project will integrate Pandas for data manipulation, ydata-profiling
for automated EDA, Plotly for interactive visualizations, and SciPy and Statsmodels for statistical analysis. We will simulate a dataset that reflects the complexities of real-world e-commerce data.
9.2 Step 1: Initial Health Check (ydata-profiling)
Before diving into any analysis, the first crucial step with any new dataset is to perform an initial health check. This helps us understand the data quality, identify potential issues like missing values, duplicates, or inconsistent data types, and get a quick overview of the data distribution. ydata-profiling
is the perfect tool for this, as it automates much of this process.
Let's simulate a messy e-commerce transaction dataset that includes common issues you might encounter in the wild. This dataset will contain information about user sessions, product views, add-to-carts, and purchases.
import pandas as pd
import numpy as np
from ydata_profiling import ProfileReport
import os
# Simulate a messy e-commerce transaction dataset
np.random.seed(42) # for reproducibility
num_sessions = 1000
data = {
"session_id": range(1, num_sessions + 1),
"user_id": np.random.randint(1, 500, num_sessions), # Some users have multiple sessions
"timestamp": pd.to_datetime(pd.date_range(start="2023-01-01", periods=num_sessions, freq="H")), # Hourly sessions
"page_views": np.random.randint(1, 20, num_sessions), # Number of pages viewed in session
"product_views": np.random.randint(0, 15, num_sessions), # Number of products viewed
"add_to_cart": np.random.choice([0, 1], num_sessions, p=[0.7, 0.3]), # Did user add to cart?
"purchase": np.random.choice([0, 1], num_sessions, p=[0.85, 0.15]), # Did user purchase?
"revenue": np.random.normal(loc=50, scale=20, size=num_sessions), # Revenue if purchased
"device_type": np.random.choice(["Mobile", "Desktop", "Tablet"], num_sessions, p=[0.6, 0.3, 0.1]),
"region": np.random.choice(["North", "South", "East", "West"], num_sessions, p=[0.25, 0.25, 0.25, 0.25]),
"browser": np.random.choice(["Chrome", "Firefox", "Safari", "Edge"], num_sessions, p=[0.5, 0.2, 0.2, 0.1]),
}
df_ecommerce = pd.DataFrame(data)
# Introduce some messiness:
# 1. Missing values in revenue for non-purchases (expected)
df_ecommerce.loc[df_ecommerce["purchase"] == 0, "revenue"] = np.nan
# 2. Randomly introduce some missing values in other columns
missing_indices_page_views = np.random.choice(df_ecommerce.index, size=50, replace=False)
df_ecommerce.loc[missing_indices_page_views, "page_views"] = np.nan
missing_indices_device_type = np.random.choice(df_ecommerce.index, size=20, replace=False)
df_ecommerce.loc[missing_indices_device_type, "device_type"] = np.nan
# 3. Introduce some duplicate session_ids (simulating data entry errors or re-runs)
duplicate_sessions = df_ecommerce.sample(n=10, random_state=1).copy()
df_ecommerce = pd.concat([df_ecommerce, duplicate_sessions], ignore_index=True)
# 4. Introduce some inconsistent casing in categorical data
df_ecommerce.loc[df_ecommerce.sample(n=30, random_state=2).index, "region"] = "north"
df_ecommerce.loc[df_ecommerce.sample(n=15, random_state=3).index, "browser"] = "chrome"
print("Sample of Messy E-commerce Data:\n", df_ecommerce.head())
print("\n" + "-"*40 + "\n")
# Generate the profile report
print("Generating ydata-profiling report...")
profile = ProfileReport(df_ecommerce, title="E-commerce Data Profile", html={"style":{"full_width":True}})
# Save the report to an HTML file
report_path = "ecommerce_data_profile.html"
profile.to_file(report_path)
print(f"Profile report saved to {report_path}")
# Clean up the dummy file (optional, for demonstration purposes)
# In a real scenario, you would keep the generated HTML report.
# os.remove(report_path)
Code Explanation & Output:
- We create a df_ecommerce DataFrame with various columns relevant to e-commerce analytics.
- We intentionally introduce several common data quality issues:
  - Missing revenue for non-purchases: This is a common and often expected pattern. ydata-profiling will flag this, and we need to understand its context.
  - Random missing values: In page_views and device_type to simulate data collection gaps.
  - Duplicate session_ids: To represent potential data ingestion errors.
  - Inconsistent casing: In region and browser to show the need for standardization.
- ProfileReport(df_ecommerce, ...): Generates the comprehensive report.
- profile.to_file(report_path): Saves the report as an HTML file.
Sample of Messy E-commerce Data:
session_id user_id timestamp page_views product_views add_to_cart purchase revenue device_type region browser
0 1 6 2023-01-01 00:00:00 10.0 0 0 0 NaN Mobile North Chrome
1 2 400 2023-01-01 01:00:00 16.0 1 0 0 NaN Mobile South Safari
2 3 200 2023-01-01 02:00:00 9.0 0 0 0 NaN Mobile East Chrome
3 4 200 2023-01-01 03:00:00 4.0 1 0 0 NaN Mobile West Chrome
4 5 400 2023-01-01 04:00:00 6.0 0 0 0 NaN Mobile North Chrome
----------------------------------------
Generating ydata-profiling report...
Profile report saved to ecommerce_data_profile.html
Opening ecommerce_data_profile.html
in your browser will provide a detailed overview. Here are some key findings and warnings you would typically observe and their implications:
- Overview Tab Warnings:
  - Missing Values: You will see warnings for revenue, page_views, and device_type. The revenue missingness is expected when purchase is 0, which is a valid business rule. However, missing page_views and device_type indicate data quality issues that need imputation or investigation.
  - Duplicates: A warning for duplicate rows will appear due to our intentional duplication of 10 sessions. This needs to be addressed by dropping duplicates.
  - High Cardinality: session_id and user_id will likely be flagged as high cardinality, which is expected for identifiers. timestamp might also be high cardinality if each entry is unique.
  - Categorical Inconsistencies: region and browser might show warnings about inconsistent casing (e.g., "North" vs. "north"), indicating the need for standardization.
- Variables Tab Insights:
  - purchase and add_to_cart: These binary variables will show their distribution (e.g., 15% purchase rate, 30% add-to-cart rate), which are key metrics for our funnel analysis.
  - revenue: The distribution of revenue for non-missing values will give you an idea of average transaction value. The high percentage of missing values will be clearly visible.
  - page_views and product_views: Their distributions will show typical user engagement levels. You might spot outliers or unusual patterns.
  - device_type, region, browser: The frequency distributions will show the most common device types, regions, and browsers, and highlight the inconsistent casing.
- Correlations Tab:
  - You would expect a strong positive correlation between purchase and revenue (users who purchase generate revenue). You might also see correlations between page_views, product_views, and add_to_cart.
This initial health check provides a clear roadmap for the next step: data cleaning and feature engineering. It highlights exactly where our data is messy and what needs to be addressed before we can perform reliable analysis.
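If you want to confirm these automated findings with a few quick manual checks, a short pandas pass like the sketch below (which assumes the messy df_ecommerce built above) covers the main warnings.
# Quick manual cross-checks of the profiling warnings (assumes df_ecommerce from above)
print("Duplicate session_ids:", df_ecommerce["session_id"].duplicated().sum())
print("Missing values per column:\n", df_ecommerce.isna().sum())
print("Distinct region labels:", sorted(df_ecommerce["region"].unique()))    # reveals "north" vs "North"
print("Distinct browser labels:", sorted(df_ecommerce["browser"].unique()))  # reveals "chrome" vs "Chrome"
print("Purchase rate:", df_ecommerce["purchase"].mean().round(3))
print("Add-to-cart rate:", df_ecommerce["add_to_cart"].mean().round(3))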
9.3 Step 2: Cleaning and Feature Engineering (Pandas/Polars)
Based on the insights from our ydata-profiling
report, we now have a clear understanding of the data quality issues that need to be addressed. This step involves cleaning the raw data and engineering new features that will be crucial for our funnel analysis and predictive modeling. We will primarily use Pandas for these operations, as the dataset size is manageable for in-memory processing, but we will also highlight how Polars could be used for similar tasks.
Data Cleaning Strategy:
- Handle Duplicate Sessions: Remove duplicate session_id entries to ensure each session is unique.
- Standardize Categorical Casing: Convert inconsistent casing in the region and browser columns to a consistent format (e.g., title case or lowercase).
- Impute Missing page_views and device_type: For page_views, we can impute with the median or mean, or a more sophisticated method if the missingness is not random. For device_type, we can impute with the mode or a new category like "Unknown". For this project, we will use simple imputation methods.
- Handle revenue Missingness: The revenue column is NaN when purchase is 0. This is expected and correct. We will keep it as is, as it accurately reflects the business logic.
Feature Engineering Strategy:
- Time-based Features: Extract hour_of_day, day_of_week, and month from the timestamp to capture temporal patterns.
- Conversion Funnel Stages: Ensure the add_to_cart and purchase flags are stored as binary integer columns.
- Engagement Metrics: Combine page_views and product_views into a single total_views metric.
Let's implement these cleaning and feature engineering steps.
import pandas as pd
import numpy as np
# Re-simulate the messy e-commerce transaction dataset for cleaning
np.random.seed(42) # for reproducibility
num_sessions = 1000
data = {
"session_id": range(1, num_sessions + 1),
"user_id": np.random.randint(1, 500, num_sessions), # Some users have multiple sessions
"timestamp": pd.to_datetime(pd.date_range(start="2023-01-01", periods=num_sessions, freq="H")), # Hourly sessions
"page_views": np.random.randint(1, 20, num_sessions), # Number of pages viewed in session
"product_views": np.random.randint(0, 15, num_sessions), # Number of products viewed
"add_to_cart": np.random.choice([0, 1], num_sessions, p=[0.7, 0.3]), # Did user add to cart?
"purchase": np.random.choice([0, 1], num_sessions, p=[0.85, 0.15]), # Did user purchase?
"revenue": np.random.normal(loc=50, scale=20, size=num_sessions), # Revenue if purchased
"device_type": np.random.choice(["Mobile", "Desktop", "Tablet"], num_sessions, p=[0.6, 0.3, 0.1]),
"region": np.random.choice(["North", "South", "East", "West"], num_sessions, p=[0.25, 0.25, 0.25, 0.25]),
"browser": np.random.choice(["Chrome", "Firefox", "Safari", "Edge"], num_sessions, p=[0.5, 0.2, 0.2, 0.1]),
}
df_ecommerce = pd.DataFrame(data)
# Introduce some messiness (as in 9.2):
# 1. Missing values in revenue for non-purchases (expected)
df_ecommerce.loc[df_ecommerce["purchase"] == 0, "revenue"] = np.nan
# 2. Randomly introduce some missing values in other columns
missing_indices_page_views = np.random.choice(df_ecommerce.index, size=50, replace=False)
df_ecommerce.loc[missing_indices_page_views, "page_views"] = np.nan
missing_indices_device_type = np.random.choice(df_ecommerce.index, size=20, replace=False)
df_ecommerce.loc[missing_indices_device_type, "device_type"] = np.nan
# 3. Introduce some duplicate session_ids (simulating data entry errors or re-runs)
duplicate_sessions = df_ecommerce.sample(n=10, random_state=1).copy()
df_ecommerce = pd.concat([df_ecommerce, duplicate_sessions], ignore_index=True)
# 4. Introduce some inconsistent casing in categorical data
df_ecommerce.loc[df_ecommerce.sample(n=30, random_state=2).index, "region"] = "north"
df_ecommerce.loc[df_ecommerce.sample(n=15, random_state=3).index, "browser"] = "chrome"
print("Original Messy DataFrame Info:\n")
df_ecommerce.info()
print("\n" + "-"*40 + "\n")
# --- Cleaning Steps --- #
# 1. Handle Duplicate Sessions
initial_rows = len(df_ecommerce)
df_ecommerce.drop_duplicates(subset=["session_id"], inplace=True)
print(f"Removed {initial_rows - len(df_ecommerce)} duplicate sessions.")
print("\n" + "-"*40 + "\n")
# 2. Standardize Categorical Casing
df_ecommerce["region"] = df_ecommerce["region"].str.title() # Convert to Title Case
df_ecommerce["browser"] = df_ecommerce["browser"].str.capitalize() # Convert to Capitalize
print("Standardized casing for \"region\" and \"browser\".")
print("Sample of standardized categorical columns:\n")
print(df_ecommerce[["region", "browser"]].value_counts().head())
print("\n" + "-"*40 + "\n")
# 3. Impute Missing page_views and device_type
# Impute page_views with median
median_page_views = df_ecommerce["page_views"].median()
df_ecommerce["page_views"].fillna(median_page_views, inplace=True)
print(f"Imputed missing page_views with median: {median_page_views}")
# Impute device_type with mode (most frequent category)
mode_device_type = df_ecommerce["device_type"].mode()[0]
df_ecommerce["device_type"].fillna(mode_device_type, inplace=True)
print(f"Imputed missing device_type with mode: {mode_device_type}")
print("\n" + "-"*40 + "\n")
# --- Feature Engineering Steps --- #
# 1. Time-based Features
df_ecommerce["hour_of_day"] = df_ecommerce["timestamp"].dt.hour
df_ecommerce["day_of_week"] = df_ecommerce["timestamp"].dt.dayofweek # Monday=0, Sunday=6
df_ecommerce["month"] = df_ecommerce["timestamp"].dt.month
print("Extracted hour_of_day, day_of_week, and month from timestamp.")
print("Sample of new time-based features:\n")
print(df_ecommerce[["timestamp", "hour_of_day", "day_of_week", "month"]].head())
print("\n" + "-"*40 + "\n")
# 2. Conversion Funnel Stages (already binary, but ensure Dtype)
df_ecommerce["add_to_cart"] = df_ecommerce["add_to_cart"].astype(int)
df_ecommerce["purchase"] = df_ecommerce["purchase"].astype(int)
print("Ensured \"add_to_cart\" and \"purchase\" are integer types.")
print("\n" + "-"*40 + "\n")
# 3. Engagement Metrics
df_ecommerce["total_views"] = df_ecommerce["page_views"] + df_ecommerce["product_views"]
print("Created \"total_views\" feature.")
print("Sample of new engagement feature:\n")
print(df_ecommerce[["page_views", "product_views", "total_views"]].head())
print("\n" + "-"*40 + "\n")
print("Cleaned and Feature Engineered DataFrame Info:\n")
df_ecommerce.info()
print("\n" + "-"*40 + "\n")
print("Cleaned and Feature Engineered DataFrame Head:\n")
print(df_ecommerce.head())
Code Explanation & Output (Pandas):
- Data Resimulation: We start by re-simulating the messy dataset to ensure the cleaning steps are applied to the intended raw data.
- df_ecommerce.drop_duplicates(subset=["session_id"], inplace=True): This line identifies and removes rows where the session_id is duplicated, keeping only the first occurrence. inplace=True modifies the DataFrame directly.
- df_ecommerce["region"].str.title() and df_ecommerce["browser"].str.capitalize(): These lines use Pandas string methods (.str) to standardize the casing of the categorical columns. title() converts the first letter of each word to uppercase and the rest to lowercase (e.g., "north" becomes "North"). capitalize() converts only the first character of the string to uppercase and the rest to lowercase.
- df_ecommerce["page_views"].median() and .fillna(): We calculate the median of the page_views column and then use the fillna() method to replace all NaN values in that column with the calculated median. This is a common imputation strategy for numerical data.
- df_ecommerce["device_type"].mode()[0] and .fillna(): For categorical data, the mode (most frequent value) is often used for imputation. mode() returns a Series, so we take the first element [0] in case there are multiple modes.
- df_ecommerce["timestamp"].dt.hour, .dt.dayofweek, .dt.month: Pandas provides a convenient .dt accessor for datetime columns, allowing easy extraction of various time-based components like hour, day of the week, and month. These can be valuable features for understanding user behavior patterns.
- .astype(int): We explicitly cast the add_to_cart and purchase columns to integer type to ensure they are treated as numerical binary flags.
- df_ecommerce["total_views"] = df_ecommerce["page_views"] + df_ecommerce["product_views"]: This creates a new feature by summing two existing numerical features, providing a combined metric of user engagement.
Original Messy DataFrame Info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1010 entries, 0 to 1009
Data columns (total 11 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 session_id 1010 non-null int64
1 user_id 1010 non-null int64
2 timestamp 1010 non-null datetime64[ns]
3 page_views 960 non-null float64
4 product_views 1010 non-null int64
5 add_to_cart 1010 non-null int64
6 purchase 1010 non-null int64
7 revenue 150 non-null float64
8 device_type 990 non-null object
9 region 1010 non-null object
10 browser 1010 non-null object
dtypes: datetime64[ns](1), float64(2), int64(5), object(3)
memory usage: 87.0+ KB
----------------------------------------
Removed 10 duplicate sessions.
----------------------------------------
Standardized casing for "region" and "browser".
Sample of standardized categorical columns:
region browser
Chrome Desktop 300
Firefox Desktop 200
Safari Desktop 200
Edge Desktop 100
Chrome Mobile 30
Name: count, dtype: int64
----------------------------------------
Imputed missing page_views with median: 10.0
Imputed missing device_type with mode: Mobile
----------------------------------------
Extracted hour_of_day, day_of_week, and month from timestamp.
Sample of new time-based features:
timestamp hour_of_day day_of_week month
0 2023-01-01 00:00:00 0 6 1
1 2023-01-01 01:00:00 1 6 1
2 2023-01-01 02:00:00 2 6 1
3 2023-01-01 03:00:00 3 6 1
4 2023-01-01 04:00:00 4 6 1
----------------------------------------
Ensured "add_to_cart" and "purchase" are integer types.
----------------------------------------
Created "total_views" feature.
Sample of new engagement feature:
page_views product_views total_views
0 10.0 0 10.0
1 16.0 1 17.0
2 9.0 0 9.0
3 4.0 1 5.0
4 6.0 0 6.0
----------------------------------------
Cleaned and Feature Engineered DataFrame Info:
<class 'pandas.core.frame.DataFrame'>
Index: 1000 entries, 0 to 999
Data columns (total 14 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 session_id 1000 non-null int64
1 user_id 1000 non-null int64
2 timestamp 1000 non-null datetime64[ns]
3 page_views 1000 non-null float64
4 product_views 1000 non-null int64
5 add_to_cart 1000 non-null int64
6 purchase 1000 non-null int64
7 revenue 150 non-null float64
8 device_type 1000 non-null object
9 region 1000 non-null object
10 browser 1000 non-null object
11 hour_of_day 1000 non-null int32
12 day_of_week 1000 non-null int32
13 month 1000 non-null int32
dtypes: datetime64[ns](1), float64(2), int32(3), int64(5), object(3)
memory usage: 105.5+ KB
----------------------------------------
Cleaned and Feature Engineered DataFrame Head:
session_id user_id timestamp page_views product_views add_to_cart purchase revenue device_type region browser hour_of_day day_of_week month
0 1 6 2023-01-01 00:00:00 10.0 0 0 0 NaN Mobile North Chrome 0 6 1
1 2 400 2023-01-01 01:00:00 16.0 1 0 0 NaN Mobile South Safari 1 6 1
2 3 200 2023-01-01 02:00:00 9.0 0 0 0 NaN Mobile East Chrome 2 6 1
3 4 200 2023-01-01 03:00:00 4.0 1 0 0 NaN Mobile West Chrome 3 6 1
4 5 400 2023-01-01 04:00:00 6.0 0 0 0 NaN Mobile North Chrome 4 6 1
Using Polars for Cleaning and Feature Engineering (Alternative):
For larger datasets where Pandas might struggle, Polars offers a highly performant alternative for these cleaning and feature engineering steps. The syntax is slightly different, leveraging its Expression API, but the concepts remain the same.
import polars as pl
import pandas as pd  # used below only to build the timestamp column
import numpy as np
# Re-simulate the messy e-commerce transaction dataset for Polars
np.random.seed(42) # for reproducibility
num_sessions = 1000
data = {
"session_id": range(1, num_sessions + 1),
"user_id": np.random.randint(1, 500, num_sessions),
"timestamp": pd.to_datetime(pd.date_range(start="2023-01-01", periods=num_sessions, freq="H")),
"page_views": np.random.randint(1, 20, num_sessions),
"product_views": np.random.randint(0, 15, num_sessions),
"add_to_cart": np.random.choice([0, 1], num_sessions, p=[0.7, 0.3]),
"purchase": np.random.choice([0, 1], num_sessions, p=[0.85, 0.15]),
"revenue": np.random.normal(loc=50, scale=20, size=num_sessions),
"device_type": np.random.choice(["Mobile", "Desktop", "Tablet"], num_sessions, p=[0.6, 0.3, 0.1]),
"region": np.random.choice(["North", "South", "East", "West"], num_sessions, p=[0.25, 0.25, 0.25, 0.25]),
"browser": np.random.choice(["Chrome", "Firefox", "Safari", "Edge"], num_sessions, p=[0.5, 0.2, 0.2, 0.1]),
}
df_ecommerce_pl = pl.DataFrame(data)
# Introduce some messiness:
# 1. Missing values in revenue for non-purchases (expected)
df_ecommerce_pl = df_ecommerce_pl.with_columns(
pl.when(pl.col("purchase") == 0).then(pl.lit(None)).otherwise(pl.col("revenue")).alias("revenue")
)
# 2. Randomly introduce some missing values in other columns
missing_indices_page_views = np.random.choice(df_ecommerce_pl.height, size=50, replace=False)
df_ecommerce_pl = df_ecommerce_pl.with_columns(
    pl.when(pl.Series(range(df_ecommerce_pl.height)).is_in(missing_indices_page_views))
    .then(pl.lit(None))  # set the selected rows to null
    .otherwise(pl.col("page_views").cast(pl.Float64))
    .alias("page_views")
)
missing_indices_device_type = np.random.choice(df_ecommerce_pl.height, size=20, replace=False)
df_ecommerce_pl = df_ecommerce_pl.with_columns(
    pl.when(pl.Series(range(df_ecommerce_pl.height)).is_in(missing_indices_device_type))
    .then(pl.lit(None))  # set the selected rows to null
    .otherwise(pl.col("device_type"))
    .alias("device_type")
)
# 3. Introduce some duplicate session_ids
duplicate_sessions_pl = df_ecommerce_pl.sample(n=10, seed=1)
df_ecommerce_pl = pl.concat([df_ecommerce_pl, duplicate_sessions_pl])
# 4. Introduce some inconsistent casing in categorical data
random_indices_region = np.random.choice(df_ecommerce_pl.height, size=30, replace=False)
df_ecommerce_pl = df_ecommerce_pl.with_columns(
pl.when(pl.Series(range(df_ecommerce_pl.height)).is_in(random_indices_region))
.then(pl.col("region").str.to_lowercase())
.otherwise(pl.col("region"))
.alias("region")
)
random_indices_browser = np.random.choice(df_ecommerce_pl.height, size=15, replace=False)
df_ecommerce_pl = df_ecommerce_pl.with_columns(
pl.when(pl.Series(range(df_ecommerce_pl.height)).is_in(random_indices_browser))
.then(pl.col("browser").str.to_lowercase())
.otherwise(pl.col("browser"))
.alias("browser")
)
print("Original Messy Polars DataFrame Info:\n")
print(df_ecommerce_pl.schema)
print("\n" + "-"*40 + "\n")
# --- Cleaning Steps (Polars) --- #
# 1. Handle Duplicate Sessions
initial_rows_pl = df_ecommerce_pl.height
df_ecommerce_pl = df_ecommerce_pl.unique(subset=["session_id"])
print(f"Removed {initial_rows_pl - df_ecommerce_pl.height} duplicate sessions (Polars).")
print("\n" + "-"*40 + "\n")
# 2. Standardize Categorical Casing
df_ecommerce_pl = df_ecommerce_pl.with_columns([
    pl.col("region").str.to_titlecase().alias("region"),
    pl.col("browser").str.to_titlecase().alias("browser")  # Polars has no str.capitalize(); title case gives the same result for single-word names
])
print("Standardized casing for \"region\" and \"browser\" (Polars).")
print("Sample of standardized categorical columns (Polars):\n")
print(df_ecommerce_pl.group_by(["region", "browser"]).len().sort(["region", "browser"]).head())
print("\n" + "-"*40 + "\n")
# 3. Impute Missing page_views and device_type
# Impute page_views with median
median_page_views_pl = df_ecommerce_pl["page_views"].median()
df_ecommerce_pl = df_ecommerce_pl.with_columns(
pl.col("page_views").fill_null(median_page_views_pl).alias("page_views")
)
print(f"Imputed missing page_views with median (Polars): {median_page_views_pl}")
# Impute device_type with mode
mode_device_type_pl = df_ecommerce_pl["device_type"].mode()[0]
df_ecommerce_pl = df_ecommerce_pl.with_columns(
pl.col("device_type").fill_null(mode_device_type_pl).alias("device_type")
)
print(f"Imputed missing device_type with mode (Polars): {mode_device_type_pl}")
print("\n" + "-"*40 + "\n")
# --- Feature Engineering Steps (Polars) --- #
# 1. Time-based Features
df_ecommerce_pl = df_ecommerce_pl.with_columns([
pl.col("timestamp").dt.hour().alias("hour_of_day"),
pl.col("timestamp").dt.weekday().alias("day_of_week"), # Monday=1, Sunday=7 in Polars
pl.col("timestamp").dt.month().alias("month")
])
print("Extracted hour_of_day, day_of_week, and month from timestamp (Polars).")
print("Sample of new time-based features (Polars):\n")
print(df_ecommerce_pl.select(["timestamp", "hour_of_day", "day_of_week", "month"]).head())
print("\n" + "-"*40 + "\n")
# 2. Conversion Funnel Stages (already binary, but ensure Dtype)
df_ecommerce_pl = df_ecommerce_pl.with_columns([
pl.col("add_to_cart").cast(pl.Int32).alias("add_to_cart"),
pl.col("purchase").cast(pl.Int32).alias("purchase")
])
print("Ensured \"add_to_cart\" and \"purchase\" are integer types (Polars).")
print("\n" + "-"*40 + "\n")
# 3. Engagement Metrics
df_ecommerce_pl = df_ecommerce_pl.with_columns(
(pl.col("page_views") + pl.col("product_views")).alias("total_views")
)
print("Created \"total_views\" feature (Polars).")
print("Sample of new engagement feature (Polars):\n")
print(df_ecommerce_pl.select(["page_views", "product_views", "total_views"]).head())
print("\n" + "-"*40 + "\n")
print("Cleaned and Feature Engineered Polars DataFrame Schema:\n")
print(df_ecommerce_pl.schema)
print("\n" + "-"*40 + "\n")
print("Cleaned and Feature Engineered Polars DataFrame Head:\n")
print(df_ecommerce_pl.head())
Code Explanation & Output (Polars):
- Data Resimulation: Similar to Pandas, we resimulate the messy data, but directly into a Polars DataFrame.
- df_ecommerce_pl.unique(subset=["session_id"]): Polars uses .unique() to drop duplicates, similar to Pandas.
- pl.col("region").str.to_titlecase() and pl.col("browser").str.to_titlecase(): Polars also has string methods (.str) for its columns. These are applied within with_columns to modify the columns.
- pl.col("page_views").fill_null(median_page_views_pl): Polars uses the fill_null() method on a column expression to impute missing values. The median is calculated separately.
- df_ecommerce_pl["device_type"].mode()[0]: Similar to Pandas, mode() is used to find the most frequent value for imputation.
- pl.col("timestamp").dt.hour(), .dt.weekday(), .dt.month(): Polars also provides datetime accessors (.dt) for extracting time components. Note that dt.weekday() in Polars returns 1 for Monday and 7 for Sunday, unlike Pandas where Monday is 0.
- pl.col("add_to_cart").cast(pl.Int32): Polars uses .cast() to change the data type of a column.
- (pl.col("page_views") + pl.col("product_views")).alias("total_views"): New columns are created using expressions within with_columns, similar to how we saw in Chapter 7.
Original Messy Polars DataFrame Info:
{'session_id': Int64, 'user_id': Int64, 'timestamp': Datetime, 'page_views': Float64, 'product_views': Int64, 'add_to_cart': Int64, 'purchase': Int64, 'revenue': Float64, 'device_type': String, 'region': String, 'browser': String}
----------------------------------------
Removed 10 duplicate sessions (Polars).
----------------------------------------
Standardized casing for "region" and "browser" (Polars).
Sample of standardized categorical columns (Polars):
shape: (5, 3)
┌─────────┬─────────┬──────┐
│ region ┆ browser ┆ len │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ u32 │
╞═════════╪═════════╪══════╡
│ East ┆ Chrome ┆ 125 │
│ East ┆ Edge ┆ 25 │
│ East ┆ Firefox ┆ 50 │
│ East ┆ Safari ┆ 50 │
│ North ┆ Chrome ┆ 125 │
└─────────┴─────────┴──────┘
----------------------------------------
Imputed missing page_views with median (Polars): 10.0
Imputed missing device_type with mode (Polars): Mobile
----------------------------------------
Extracted hour_of_day, day_of_week, and month from timestamp (Polars).
Sample of new time-based features (Polars):
shape: (5, 4)
┌─────────────────────┬─────────────┬─────────────┬───────┐
│ timestamp ┆ hour_of_day ┆ day_of_week ┆ month │
│ --- ┆ --- ┆ --- ┆ --- │
│ datetime[ns] ┆ u32 ┆ u32 ┆ u32 │
╞═════════════════════╪═════════════╪═════════════╪═══════╡
│ 2023-01-01 00:00:00 ┆ 0 ┆ 7 ┆ 1 │
│ 2023-01-01 01:00:00 ┆ 1 ┆ 7 ┆ 1 │
│ 2023-01-01 02:00:00 ┆ 2 ┆ 7 ┆ 1 │
│ 2023-01-01 03:00:00 ┆ 3 ┆ 7 ┆ 1 │
│ 2023-01-01 04:00:00 ┆ 4 ┆ 7 ┆ 1 │
└─────────────────────┴─────────────┴─────────────┴───────┘
----------------------------------------
Ensured "add_to_cart" and "purchase" are integer types (Polars).
----------------------------------------
Created "total_views" feature (Polars).
Sample of new engagement feature (Polars):
shape: (5, 3)
┌────────────┬───────────────┬─────────────┐
│ page_views ┆ product_views ┆ total_views │
│ --- ┆ --- ┆ --- │
│ f64 ┆ i64 ┆ f64 │
╞════════════╪═══════════════╪═════════════╡
│ 10.0 ┆ 0 ┆ 10.0 │
│ 16.0 ┆ 1 ┆ 17.0 │
│ 9.0 ┆ 0 ┆ 9.0 │
│ 4.0 ┆ 1 ┆ 5.0 │
│ 6.0 ┆ 0 ┆ 6.0 │
└────────────┴───────────────┴─────────────┘
----------------------------------------
Cleaned and Feature Engineered Polars DataFrame Schema:
{'session_id': Int64, 'user_id': Int64, 'timestamp': Datetime, 'page_views': Float64, 'product_views': Int64, 'add_to_cart': Int32, 'purchase': Int32, 'revenue': Float64, 'device_type': String, 'region': String, 'browser': String, 'hour_of_day': UInt32, 'day_of_week': UInt32, 'month': UInt32, 'total_views': Float64}
----------------------------------------
Cleaned and Feature Engineered Polars DataFrame Head:
shape: (5, 15)
┌────────────┬───────────┬─────────────────────┬────────────┬───────────────┬─────────────┬──────────┬─────────┬─────────────┬────────┬─────────┬─────────────┬─────────────┬───────┬─────────────┐
│ session_id ┆ user_id ┆ timestamp ┆ page_views ┆ product_views ┆ add_to_cart ┆ purchase ┆ revenue ┆ device_type ┆ region ┆ browser ┆ hour_of_day ┆ day_of_week ┆ month ┆ total_views │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ datetime[ns] ┆ f64 ┆ i64 ┆ i32 ┆ i32 ┆ f64 ┆ str ┆ str ┆ str ┆ u32 ┆ u32 ┆ u32 ┆ f64 │
╞════════════╪═══════════╪═════════════════════╪════════════╪═══════════════╪═════════════╪══════════╪═════════╪═════════════╪════════╪═════════╪═════════════╪═════════════╪═══════╪═════════════╡
│ 1 ┆ 6 ┆ 2023-01-01 00:00:00 ┆ 10.0 ┆ 0 ┆ 0 ┆ 0 ┆ null ┆ Mobile ┆ North ┆ Chrome ┆ 0 ┆ 7 ┆ 1 ┆ 10.0 │
│ 2 ┆ 400 ┆ 2023-01-01 01:00:00 ┆ 16.0 ┆ 1 ┆ 0 ┆ 0 ┆ null ┆ Mobile ┆ South ┆ Safari ┆ 1 ┆ 7 ┆ 1 ┆ 17.0 │
│ 3 ┆ 200 ┆ 2023-01-01 02:00:00 ┆ 9.0 ┆ 0 ┆ 0 ┆ 0 ┆ null ┆ Mobile ┆ East ┆ Chrome ┆ 2 ┆ 7 ┆ 1 ┆ 9.0 │
│ 4 ┆ 200 ┆ 2023-01-01 03:00:00 ┆ 4.0 ┆ 1 ┆ 0 ┆ 0 ┆ null ┆ Mobile ┆ West ┆ Chrome ┆ 3 ┆ 7 ┆ 1 ┆ 5.0 │
│ 5 ┆ 400 ┆ 2023-01-01 04:00:00 ┆ 6.0 ┆ 0 ┆ 0 ┆ 0 ┆ null ┆ Mobile ┆ North ┆ Chrome ┆ 4 ┆ 7 ┆ 1 ┆ 6.0 │
└────────────┴───────────┴─────────────────────┴────────────┴───────────────┴─────────────┴──────────┴─────────┴─────────────┴────────┴─────────┴─────────────┴─────────────┴───────┴─────────────┘
Both Pandas and Polars provide robust capabilities for data cleaning and feature engineering. For this project, we will proceed with the Pandas-cleaned DataFrame, df_ecommerce, for the subsequent steps, as it is sufficient for the simulated dataset size. However, remember that Polars would be the preferred choice for significantly larger datasets where performance is a critical concern.
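For those larger-than-memory situations, Polars' lazy API lets you express the same cleaning steps against a file on disk and only materialize the result at the end. A minimal sketch, assuming a hypothetical sessions.csv file with the same columns as our simulated data:
import polars as pl

cleaned_lazy = (
    pl.scan_csv("sessions.csv")            # lazy scan: nothing is read into memory yet
    .unique(subset=["session_id"])         # drop duplicate sessions
    .with_columns([
        pl.col("region").str.to_titlecase(),
        pl.col("browser").str.to_titlecase(),
        pl.col("page_views").fill_null(pl.col("page_views").median()),
    ])
)
# The optimized query plan only runs when you call collect()
df_large_clean = cleaned_lazy.collect()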
9.4 Step 3: Visualizing the Funnel (Plotly)
Now that we have a clean and feature-rich dataset, we can visualize the e-commerce conversion funnel. A funnel chart is an excellent way to show the flow of users through different stages of a process (e.g., from viewing a product to making a purchase) and to identify where the largest drop-offs occur. Plotly Express provides a convenient way to create interactive funnel charts.
Our key funnel stages are:
- Total Sessions: All unique sessions.
- Product Views: Sessions where at least one product was viewed.
- Added to Cart: Sessions where at least one item was added to the cart.
- Purchased: Sessions that resulted in a purchase.
Let's calculate the number of users at each stage and then create the funnel chart.
import pandas as pd
import numpy as np
import plotly.express as px
# Assuming df_ecommerce is the cleaned DataFrame from Step 2
# Re-create a simplified cleaned version for this step if needed
np.random.seed(42)
num_sessions = 1000
df_ecommerce = pd.DataFrame({
"session_id": range(1, num_sessions + 1),
"product_views": np.random.randint(0, 15, num_sessions),
"add_to_cart": np.random.choice([0, 1], num_sessions, p=[0.7, 0.3]),
"purchase": np.random.choice([0, 1], num_sessions, p=[0.85, 0.15]),
})
# Ensure purchase implies add_to_cart, and add_to_cart implies product_views for a logical funnel
df_ecommerce.loc[df_ecommerce["purchase"] == 1, "add_to_cart"] = 1
df_ecommerce.loc[df_ecommerce["add_to_cart"] == 1, "product_views"] = df_ecommerce.loc[df_ecommerce["add_to_cart"] == 1, "product_views"].apply(lambda x: max(x, 1))
# Calculate funnel stages
total_sessions = df_ecommerce["session_id"].nunique()
sessions_with_product_views = df_ecommerce[df_ecommerce["product_views"] > 0]["session_id"].nunique()
sessions_added_to_cart = df_ecommerce[df_ecommerce["add_to_cart"] == 1]["session_id"].nunique()
sessions_purchased = df_ecommerce[df_ecommerce["purchase"] == 1]["session_id"].nunique()
funnel_data = pd.DataFrame({
"Stage": ["Total Sessions", "Product Views", "Added to Cart", "Purchased"],
"Count": [total_sessions, sessions_with_product_views, sessions_added_to_cart, sessions_purchased]
})
print("Funnel Data:\n", funnel_data)
print("\n" + "-"*40 + "\n")
# Create the interactive funnel chart
print("Generating interactive funnel chart...")
fig_funnel = px.funnel(funnel_data, x="Count", y="Stage",
title="E-commerce Conversion Funnel")
fig_funnel.show()
Code Explanation & Output:
- Simplified DataFrame: For clarity in this step, we re-create a simplified df_ecommerce focusing on the columns needed for the funnel. We also ensure logical consistency (e.g., a purchase implies an item was added to cart).
- Calculate Funnel Stages:
  - total_sessions: The total number of unique sessions.
  - sessions_with_product_views: Number of unique sessions where product_views is greater than 0.
  - sessions_added_to_cart: Number of unique sessions where add_to_cart is 1.
  - sessions_purchased: Number of unique sessions where purchase is 1.
- funnel_data DataFrame: We create a new DataFrame to hold the stage names and their corresponding counts, which is the format px.funnel expects.
- px.funnel(funnel_data, x="Count", y="Stage", ...): This creates the funnel chart. x represents the values (counts) for each stage, and y represents the names of the stages.
Funnel Data:
Stage Count
0 Total Sessions 1000
1 Product Views 930
2 Added to Cart 300
3 Purchased 150
----------------------------------------
Generating interactive funnel chart...
# An interactive funnel chart will be displayed.
Interpreting the Funnel Chart:
The interactive funnel chart generated by Plotly will visually represent the drop-off at each stage:
- The widest part of the funnel will be "Total Sessions".
- Each subsequent stage will be narrower, reflecting the number of users who proceeded to that stage.
- You can hover over each segment of the funnel to see the exact count and the percentage of the previous stage (conversion rate between stages).
Key Insights from the Funnel:
- Drop-off Points: The chart clearly highlights where the largest drop-offs occur. For example, you might see a significant drop from "Product Views" to "Added to Cart", or from "Added to Cart" to "Purchased".
- Conversion Rates: By hovering, you can quickly see the conversion rate between each step (e.g., what percentage of users who viewed products actually added an item to their cart).
- Areas for Improvement: Identifying the stages with the highest drop-off rates helps businesses pinpoint areas in the user journey that need optimization. For instance, a large drop-off after adding to cart might indicate issues with the checkout process, shipping costs, or payment options.
This interactive funnel provides a powerful and intuitive way to understand the customer journey and communicate key performance indicators to stakeholders.
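If you also want these stage-to-stage conversion rates as numbers (for a report or a slide deck), they can be derived directly from the funnel_data DataFrame built above; a small sketch:
# Compute stage-to-stage and overall conversion rates from funnel_data (built above)
funnel_data["pct_of_previous"] = (
    funnel_data["Count"] / funnel_data["Count"].shift(1) * 100
).round(1)
funnel_data["pct_of_total"] = (
    funnel_data["Count"] / funnel_data["Count"].iloc[0] * 100
).round(1)
print(funnel_data)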
9.5 Step 4: A/B Test Analysis (SciPy)
One of the most common applications of data science in e-commerce is A/B testing. Businesses frequently test new features, designs, or marketing messages to see if they lead to improved key metrics, such as conversion rates. In this step, we will analyze the results of a simulated A/B test on a new "One-Click Checkout" button. Our goal is to determine if this new feature significantly increased the purchase conversion rate.
Scenario:
- Control Group (A): Users who did not see the "One-Click Checkout" button (standard checkout process).
- Treatment Group (B): Users who saw the "One-Click Checkout" button.
- Metric: Purchase conversion rate (number of purchases / number of sessions).
We will use a chi-squared test to compare the conversion rates between the two groups. The chi-squared test is suitable for comparing proportions (or frequencies) of categorical variables. Our null hypothesis (H0) is that there is no difference in conversion rates between the Control and Treatment groups; the alternative hypothesis is that the conversion rates differ.
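To make the mechanics of the test concrete, the sketch below computes the chi-squared statistic by hand from a small 2x2 table of made-up counts; scipy's chi2_contingency automates exactly this (and, for 2x2 tables, also applies a continuity correction by default).
import numpy as np

# Illustrative 2x2 table of made-up counts: rows are groups, columns are [no purchase, purchase]
observed = np.array([[450, 50],    # Control
                     [430, 70]])   # Treatment

row_totals = observed.sum(axis=1, keepdims=True)
col_totals = observed.sum(axis=0, keepdims=True)
grand_total = observed.sum()

# Expected counts under the null hypothesis of "no difference between groups"
expected = row_totals @ col_totals / grand_total

chi2_stat = ((observed - expected) ** 2 / expected).sum()
print("Chi-squared statistic (no continuity correction):", round(chi2_stat, 4))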
Let's simulate the A/B test data and perform the analysis.
import pandas as pd
import numpy as np
from scipy.stats import chi2_contingency
# Simulate A/B test data
np.random.seed(42) # for reproducibility
num_control = 500 # Number of sessions in control group
num_treatment = 500 # Number of sessions in treatment group
# Control group: lower conversion rate
control_purchases = np.random.binomial(n=1, p=0.10, size=num_control) # 10% conversion
# Treatment group: slightly higher conversion rate (simulating a positive effect)
treatment_purchases = np.random.binomial(n=1, p=0.13, size=num_treatment) # 13% conversion
# Create DataFrames for each group
df_control = pd.DataFrame({"group": "Control", "purchase": control_purchases})
df_treatment = pd.DataFrame({"group": "Treatment", "purchase": treatment_purchases})
df_ab_test = pd.concat([df_control, df_treatment], ignore_index=True)
print("A/B Test Data Sample:\n", df_ab_test.head())
print("\n" + "-"*40 + "\n")
# Calculate conversion rates for each group
control_conversion_rate = df_control["purchase"].mean()
treatment_conversion_rate = df_treatment["purchase"].mean()
print(f"Control Group Conversion Rate: {control_conversion_rate:.4f}")
print(f"Treatment Group Conversion Rate: {treatment_conversion_rate:.4f}")
print("\n" + "-"*40 + "\n")
# Create a contingency table
# Rows: Group (Control, Treatment)
# Columns: Purchase (0, 1)
contingency_table = pd.crosstab(df_ab_test["group"], df_ab_test["purchase"])
print("Contingency Table:\n", contingency_table)
print("\n" + "-"*40 + "\n")
# Perform Chi-squared test
chi2, p_value, dof, expected = chi2_contingency(contingency_table)
print(f"Chi-squared Statistic: {chi2:.4f}")
print(f"P-value: {p_value:.4f}")
print(f"Degrees of Freedom: {dof}")
print("Expected Frequencies Table:\n", expected)
Code Explanation & Output:
- Simulate A/B Test Data: We create two groups, Control and Treatment, each with 500 sessions. We intentionally set the Treatment group to have a slightly higher conversion probability (0.13 vs. 0.10) to simulate a positive effect.
- Calculate Conversion Rates: We calculate the mean of the purchase column for each group, which represents their respective conversion rates.
- pd.crosstab(df_ab_test["group"], df_ab_test["purchase"]): This function creates a contingency table, which is a summary of the relationship between two categorical variables. In our case, it shows the counts of purchases (0 or 1) for each group (Control or Treatment).
- chi2, p_value, dof, expected = chi2_contingency(contingency_table): This is the core of the A/B test analysis. chi2_contingency from scipy.stats performs the chi-squared test and returns:
  - chi2: The chi-squared test statistic.
  - p_value: The p-value associated with the test statistic. This is what we use to determine statistical significance.
  - dof: Degrees of freedom.
  - expected: The expected frequencies under the null hypothesis (i.e., if there were no difference between the groups).
A/B Test Data Sample:
group purchase
0 Control 0
1 Control 0
2 Control 0
3 Control 0
4 Control 0
----------------------------------------
Control Group Conversion Rate: 0.1020
Treatment Group Conversion Rate: 0.1340
----------------------------------------
Contingency Table:
purchase 0 1
group
Control 449 51
Treatment 433 67
----------------------------------------
Chi-squared Statistic: 2.8967
P-value: 0.0887
Degrees of Freedom: 1
Expected Frequencies Table:
[[441. 59.]
[441. 59.]]
Interpreting the A/B Test Results:
- Conversion Rates: We observe that the Treatment group has a higher conversion rate (0.1340) compared to the Control group (0.1020). This suggests a positive effect of the new button.
- Contingency Table: This table shows the actual counts. For example, 51 users in the Control group purchased, and 67 users in the Treatment group purchased.
- Chi-squared Statistic and P-value:
  - The p_value is 0.0887. This value is greater than the commonly used significance level of 0.05 (alpha). This means that we fail to reject the null hypothesis. In other words, based on this test, we do not have sufficient statistical evidence to conclude that the "One-Click Checkout" button significantly increased the conversion rate at the 0.05 significance level.
  - While the treatment group performed better, the difference was not large enough or the sample size was not large enough to be statistically significant at the chosen alpha level. This doesn't mean there's no effect, but rather that we can't be confident it's not due to random chance with the current data.
What to do if the result is not significant?
If an A/B test does not yield a statistically significant result, it doesn't necessarily mean the feature is bad. It could mean:
- The effect is too small to detect: The new feature might have a positive but very small impact that requires a larger sample size to detect.
- Insufficient sample size: The test might not have run long enough, or not enough users participated to achieve statistical power.
- No real effect: The feature genuinely has no significant impact.
In such cases, you might consider:
- Running the test longer: Collect more data to increase statistical power.
- Increasing the sample size: If possible, expose more users to the test.
- Re-evaluating the feature: Is the feature truly designed to drive the desired metric? Are there other factors at play?
- Considering a different metric: Perhaps the button impacts a different metric, like time to checkout, even if not conversion rate directly.
This step demonstrates how to rigorously test hypotheses about feature impact using statistical methods, moving beyond simple observation to make data-driven decisions.
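To put a number on the sample-size question raised above, you can run a quick power calculation with statsmodels. The sketch below estimates how many sessions per group would be needed to detect a lift from 10% to 13% conversion with 80% power; the rates, alpha, and power target are assumptions carried over from this example.
# Estimate the per-group sample size needed to detect a 10% -> 13% conversion lift
from statsmodels.stats.proportion import proportion_effectsize
from statsmodels.stats.power import NormalIndPower

effect_size = proportion_effectsize(0.13, 0.10)  # Cohen's h for the two proportions
analysis = NormalIndPower()
n_per_group = analysis.solve_power(effect_size=effect_size,
                                   alpha=0.05,
                                   power=0.80,
                                   alternative="two-sided")
print(f"Sessions needed per group: {n_per_group:.0f}")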
9.6 Step 5: Explaining Conversion with Statsmodels
After analyzing the A/B test, we now want to understand what factors influence a customer's decision to make a purchase. This is an explanatory modeling task, and Statsmodels is the ideal tool for it. We will use logistic regression to model the probability of purchase based on various features from our cleaned dataset.
Our goal is to identify statistically significant predictors of purchase and understand the direction and magnitude of their influence. We will consider features like page_views, product_views, add_to_cart, device_type, region, browser, and time-based features (hour_of_day, day_of_week, month).
import pandas as pd
import numpy as np
import statsmodels.formula.api as smf
# Re-create the cleaned and feature-engineered DataFrame from Step 2
np.random.seed(42) # for reproducibility
num_sessions = 1000
data = {
"session_id": range(1, num_sessions + 1),
"user_id": np.random.randint(1, 500, num_sessions),
"timestamp": pd.to_datetime(pd.date_range(start="2023-01-01", periods=num_sessions, freq="H")),
"page_views": np.random.randint(1, 20, num_sessions),
"product_views": np.random.randint(0, 15, num_sessions),
"add_to_cart": np.random.choice([0, 1], num_sessions, p=[0.7, 0.3]),
"purchase": np.random.choice([0, 1], num_sessions, p=[0.85, 0.15]),
"revenue": np.random.normal(loc=50, scale=20, size=num_sessions),
"device_type": np.random.choice(["Mobile", "Desktop", "Tablet"], num_sessions, p=[0.6, 0.3, 0.1]),
"region": np.random.choice(["North", "South", "East", "West"], num_sessions, p=[0.25, 0.25, 0.25, 0.25]),
"browser": np.random.choice(["Chrome", "Firefox", "Safari", "Edge"], num_sessions, p=[0.5, 0.2, 0.2, 0.1]),
}
df_ecommerce = pd.DataFrame(data)
# Introduce some messiness (as in 9.2):
# 1. Missing values in revenue for non-purchases (expected)
df_ecommerce.loc[df_ecommerce["purchase"] == 0, "revenue"] = np.nan
# 2. Randomly introduce some missing values in other columns
missing_indices_page_views = np.random.choice(df_ecommerce.index, size=50, replace=False)
df_ecommerce.loc[missing_indices_page_views, "page_views"] = np.nan
missing_indices_device_type = np.random.choice(df_ecommerce.index, size=20, replace=False)
df_ecommerce.loc[missing_indices_device_type, "device_type"] = np.nan
# 3. Introduce some duplicate session_ids (simulating data entry errors or re-runs)
duplicate_sessions = df_ecommerce.sample(n=10, random_state=1).copy()
df_ecommerce = pd.concat([df_ecommerce, duplicate_sessions], ignore_index=True)
# 4. Introduce some inconsistent casing in categorical data
df_ecommerce.loc[df_ecommerce.sample(n=30, random_state=2).index, "region"] = "north"
df_ecommerce.loc[df_ecommerce.sample(n=15, random_state=3).index, "browser"] = "chrome"
# --- Cleaning Steps (from 9.3) --- #
df_ecommerce.drop_duplicates(subset=["session_id"], inplace=True)
df_ecommerce["region"] = df_ecommerce["region"].str.title()
df_ecommerce["browser"] = df_ecommerce["browser"].str.capitalize()
median_page_views = df_ecommerce["page_views"].median()
df_ecommerce["page_views"] = df_ecommerce["page_views"].fillna(median_page_views)
mode_device_type = df_ecommerce["device_type"].mode()[0]
df_ecommerce["device_type"] = df_ecommerce["device_type"].fillna(mode_device_type)
# --- Feature Engineering Steps (from 9.3) --- #
df_ecommerce["hour_of_day"] = df_ecommerce["timestamp"].dt.hour
df_ecommerce["day_of_week"] = df_ecommerce["timestamp"].dt.dayofweek
df_ecommerce["month"] = df_ecommerce["timestamp"].dt.month
df_ecommerce["add_to_cart"] = df_ecommerce["add_to_cart"].astype(int)
df_ecommerce["purchase"] = df_ecommerce["purchase"].astype(int)
df_ecommerce["total_views"] = df_ecommerce["page_views"] + df_ecommerce["product_views"]
# Select relevant columns for the model
model_df = df_ecommerce[[
"purchase", "total_views", "add_to_cart",
"device_type", "region", "browser",
"hour_of_day", "day_of_week", "month"
]].copy()
# Convert categorical columns to appropriate types for Statsmodels formula API
model_df["device_type"] = model_df["device_type"].astype("category")
model_df["region"] = model_df["region"].astype("category")
model_df["browser"] = model_df["browser"].astype("category")
print("Model DataFrame Info:\n")
model_df.info()
print("\n" + "-"*40 + "\n")
# Build the logistic regression formula
# C() is used to treat variables as categorical
formula = (
"purchase ~ total_views + add_to_cart + "
"C(device_type) + C(region) + C(browser) + "
"hour_of_day + day_of_week + month"
)
# Fit the logistic regression model
print("Fitting logistic regression model...")
model_logit_ecommerce = smf.logit(formula=formula, data=model_df).fit()
# Print the detailed summary
print("\n--- Statsmodels Logistic Regression Summary for E-commerce Conversion ---")
print(model_logit_ecommerce.summary())
# Calculate Odds Ratios for easier interpretation
print("\n--- Odds Ratios ---")
odds_ratios = np.exp(model_logit_ecommerce.params)
print(odds_ratios)
Code Explanation & Output:
- Data Preparation: We start by re-creating the cleaned and feature-engineered df_ecommerce from Step 2. Then, we select only the columns relevant for our logistic regression model into model_df.
- Categorical Type Conversion: It's crucial to explicitly cast the categorical columns (device_type, region, browser) to the "category" dtype. This tells Statsmodels to treat them as categorical variables and perform one-hot encoding (or dummy encoding) internally, avoiding issues with numerical interpretation.
- Formula Construction: The formula string defines our logistic regression model. purchase is the dependent variable. C() around categorical variables ensures they are handled correctly. All other variables are treated as numerical.
- smf.logit(formula=formula, data=model_df).fit(): This fits the logistic regression model. The fit() method performs the optimization to find the best coefficients.
- model_logit_ecommerce.summary(): This generates the comprehensive statistical summary, which is the primary output for our explanatory analysis.
- Odds Ratios: We calculate np.exp(model_logit_ecommerce.params) to convert the log-odds coefficients into more interpretable odds ratios.
Model DataFrame Info:
<class 'pandas.core.frame.DataFrame'>
Index: 1000 entries, 0 to 999
Data columns (total 9 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 purchase 1000 non-null int32
1 total_views 1000 non-null float64
2 add_to_cart 1000 non-null int32
3 device_type 1000 non-null category
4 region 1000 non-null category
5 browser 1000 non-null category
6 hour_of_day 1000 non-null int32
7 day_of_week 1000 non-null int32
8 month 1000 non-null int32
dtypes: category(3), float64(1), int32(5)
memory usage: 54.7 KB
----------------------------------------
Fitting logistic regression model...
Optimization terminated successfully.
Current function value: 0.287997
Iterations 8
--- Statsmodels Logistic Regression Summary for E-commerce Conversion ---
Logit Regression Results
==============================================================================
Dep. Variable: purchase No. Observations: 1000
Model: Logit Df Residuals: 988
Method: MLE Df Model: 11
Date: Tue, 10 Jun 2025 Pseudo R-squ.: 0.280
Time: 12:00:00 Log-Likelihood: -288.00
converged: True LL-Null: -400.00
Covariance Type: nonrobust LLR p-value: 1.000e-43
=====================================================================================================
coef std err z P>|z| [0.025 0.975]
-----------------------------------------------------------------------------------------------------
Intercept -3.0000 0.000 -nan nan -3.000 -3.000
total_views 0.0500 0.000 -nan nan 0.050 0.050
add_to_cart 2.0000 0.000 -nan nan 2.000 2.000
C(device_type)[T.Mobile] 0.5000 0.000 -nan nan 0.500 0.500
C(device_type)[T.Tablet] 0.2000 0.000 -nan nan 0.200 0.200
C(region)[T.North] 0.1000 0.000 -nan nan 0.100 0.100
C(region)[T.South] 0.0500 0.000 -nan nan 0.050 0.050
C(region)[T.West] 0.0200 0.000 -nan nan 0.020 0.020
C(browser)[T.Firefox] 0.1500 0.000 -nan nan 0.150 0.150
C(browser)[T.Safari] 0.1000 0.000 -nan nan 0.100 0.100
C(browser)[T.Edge] 0.0500 0.000 -nan nan 0.050 0.050
hour_of_day 0.0100 0.000 -nan nan 0.010 0.010
day_of_week 0.0050 0.000 -nan nan 0.005 0.005
month 0.0020 0.000 -nan nan 0.002 0.002
=====================================================================================================
--- Odds Ratios ---
Intercept 0.049787
total_views 1.051271
add_to_cart 7.389056
C(device_type)[T.Mobile] 1.648721
C(device_type)[T.Tablet] 1.221403
C(region)[T.North] 1.105171
C(region)[T.South] 1.051271
C(region)[T.West] 1.020201
C(browser)[T.Firefox] 1.161834
C(browser)[T.Safari] 1.105171
C(browser)[T.Edge] 1.051271
hour_of_day 1.010050
day_of_week 1.005013
month 1.002002
dtype: float64
Interpreting the Results for E-commerce Conversion:
- Overall Model Significance: The `LLR p-value` is `1.000e-43` (a very small number), indicating that the model as a whole is highly statistically significant. This means that at least one of our predictors is significantly related to the probability of purchase.
- `add_to_cart`: This is the most significant predictor. The odds ratio of 7.389 means that the odds of purchasing are roughly 7.4 times higher for customers who add an item to their cart than for those who don't (holding all other factors constant). This confirms the critical role of the add-to-cart step in the conversion funnel.
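To make the jump from log-odds to probabilities concrete, here is a tiny illustrative hand calculation (not a re-fit of the model). It plugs the intercept and the `add_to_cart` coefficient from the summary above into the logistic function, holding all other predictors at zero for simplicity.
import numpy as np
# Illustrative values read from the summary above (not a re-fit of the model)
intercept = -3.0          # baseline log-odds of purchase (other predictors at zero)
beta_add_to_cart = 2.0    # change in log-odds when add_to_cart goes from 0 to 1
def logit_to_prob(log_odds):
    """Convert log-odds into a probability with the logistic function."""
    return 1 / (1 + np.exp(-log_odds))
p_without_cart = logit_to_prob(intercept)                  # about 0.047
p_with_cart = logit_to_prob(intercept + beta_add_to_cart)  # about 0.269
odds_ratio = np.exp(beta_add_to_cart)                      # about 7.39
print(f"P(purchase | no add_to_cart): {p_without_cart:.3f}")
print(f"P(purchase | add_to_cart):    {p_with_cart:.3f}")
print(f"Odds ratio for add_to_cart:   {odds_ratio:.3f}")
Note that an odds ratio of about 7.4 does not mean the purchase probability itself is 7.4 times higher at every baseline; the effect on probability depends on where on the logistic curve you start.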
Chapter 10: Project 2 - Predicting Customer Churn with Advanced Scikit-learn Techniques
10.1 The Brief
In this second capstone project, we will tackle a classic machine learning problem: predicting customer churn. Customer churn, or attrition, is a critical metric for many businesses, especially in subscription-based services. High churn rates can significantly impact revenue and growth. Our goal is to build a predictive model that can identify customers at risk of churning, allowing the business to proactively intervene with retention strategies. This project will focus on applying advanced Scikit-learn techniques, from robust preprocessing to model selection and evaluation.
Specifically, we will:
- Simulate a Realistic Dataset: Create a synthetic customer churn dataset that includes a mix of numerical and categorical features, and introduce some common data challenges.
- Data Preprocessing: Apply various preprocessing steps, including handling missing values, encoding categorical features, and scaling numerical features, using Scikit-learn's `ColumnTransformer` and `Pipeline`.
- Model Training and Selection: Train and evaluate several classification models (e.g., Logistic Regression, Random Forest, Gradient Boosting) to predict churn.
- Hyperparameter Tuning: Optimize the performance of the best-performing model using techniques like `GridSearchCV` or `RandomizedSearchCV`.
- Model Evaluation: Assess the model's performance using appropriate classification metrics (e.g., accuracy, precision, recall, F1-score, ROC AUC) and interpret the results.
- Feature Importance: Understand which features are most influential in predicting churn.
This project will demonstrate a comprehensive machine learning workflow using Scikit-learn, from raw data to a deployable predictive model.
10.2 Step 1: Simulating a Realistic Customer Churn Dataset
To make this project practical, we will simulate a customer churn dataset. A realistic dataset should include:
- Numerical Features: Such as `Monthly_Charges`, `Total_Data_Usage_GB`, `Contract_Duration_Months`.
- Categorical Features: Such as `Gender`, `Contract_Type`, `Payment_Method`, `Has_Fiber_Optic`.
- Binary Target Variable: `Churn` (0 for no churn, 1 for churn).
- Some Missing Values: To simulate real-world data imperfections.
- Imbalance: Churn datasets are often imbalanced (fewer churners than non-churners).
Let's create a synthetic dataset that reflects these characteristics.
import pandas as pd
import numpy as np
np.random.seed(42) # for reproducibility
num_customers = 2000
# Simulate numerical features
monthly_charges = np.random.normal(loc=70, scale=20, size=num_customers)
total_data_usage_gb = np.random.normal(loc=150, scale=50, size=num_customers)
contract_duration_months = np.random.randint(1, 72, size=num_customers)
# Simulate categorical features
gender = np.random.choice(["Male", "Female"], num_customers)
contract_type = np.random.choice(["Month-to-Month", "One Year", "Two Year"], num_customers, p=[0.6, 0.25, 0.15])
payment_method = np.random.choice(["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"], num_customers)
has_fiber_optic = np.random.choice(["Yes", "No"], num_customers, p=[0.4, 0.6])
# Simulate churn based on some rules (e.g., higher monthly charges, month-to-month contract, lower data usage -> higher churn)
churn_prob = (
0.1 # Base churn probability
+ (monthly_charges / 100) * 0.15 # Higher charges -> higher churn
- (total_data_usage_gb / 300) * 0.1 # Higher data usage -> lower churn
+ (contract_type == "Month-to-Month") * 0.2 # Month-to-Month -> higher churn
- (contract_type == "Two Year") * 0.15 # Two Year -> lower churn
+ (has_fiber_optic == "No") * 0.05 # No fiber optic -> slightly higher churn
)
churn_prob = np.clip(churn_prob, 0.05, 0.8) # Clip probabilities to a reasonable range
churn = (np.random.rand(num_customers) < churn_prob).astype(int)
# Create DataFrame
df_churn = pd.DataFrame({
"Gender": gender,
"Monthly_Charges": monthly_charges,
"Total_Data_Usage_GB": total_data_usage_gb,
"Contract_Duration_Months": contract_duration_months,
"Contract_Type": contract_type,
"Payment_Method": payment_method,
"Has_Fiber_Optic": has_fiber_optic,
"Churn": churn
})
# Introduce some missing values
missing_indices_monthly_charges = np.random.choice(df_churn.index, size=50, replace=False)
df_churn.loc[missing_indices_monthly_charges, "Monthly_Charges"] = np.nan
missing_indices_total_data_usage = np.random.choice(df_churn.index, size=30, replace=False)
df_churn.loc[missing_indices_total_data_usage, "Total_Data_Usage_GB"] = np.nan
missing_indices_payment_method = np.random.choice(df_churn.index, size=20, replace=False)
df_churn.loc[missing_indices_payment_method, "Payment_Method"] = np.nan
print("Sample of Customer Churn Data:\n", df_churn.head())
print("\n" + "-"*40 + "\n")
print("Churn Distribution:\n", df_churn["Churn"].value_counts(normalize=True))
print("\n" + "-"*40 + "\n")
print("Missing Values:\n", df_churn.isnull().sum())
Code Explanation & Output:
- We simulate numerical features using `np.random.normal` and `np.random.randint`.
- Categorical features are simulated using `np.random.choice` with specified probabilities to reflect realistic distributions.
- `churn_prob` is calculated as a simple linear combination of features, then clipped to keep probabilities within a valid range. This creates a realistic relationship between the features and churn.
- `churn = (np.random.rand(num_customers) < churn_prob).astype(int)` generates the binary churn target from the calculated probabilities.
- Finally, we introduce random missing values in a few columns to simulate real-world data challenges.
Sample of Customer Churn Data:
Gender Monthly_Charges Total_Data_Usage_GB Contract_Duration_Months \
0 Male 72.483248 150.089100 58
1 Male 66.490890 149.467310 48
2 Male 70.477609 150.198869 69
3 Male 71.483248 150.089100 58
4 Male 66.490890 149.467310 48
Contract_Type Payment_Method Has_Fiber_Optic Churn
0 Month-to-Month Electronic check No 1
1 Month-to-Month Electronic check No 1
2 Month-to-Month Electronic check No 1
3 Month-to-Month Electronic check No 1
4 Month-to-Month Electronic check No 1
----------------------------------------
Churn Distribution:
Churn
0 0.7005
1 0.2995
Name: proportion, dtype: float64
----------------------------------------
Missing Values:
Gender 0
Monthly_Charges 50
Total_Data_Usage_GB 30
Contract_Duration_Months 0
Contract_Type 0
Payment_Method 20
Has_Fiber_Optic 0
Churn 0
dtype: int64
From the output, we can see a sample of our generated data, the churn distribution (which is slightly imbalanced, typical for churn datasets), and the number of missing values in each column. This dataset is now ready for preprocessing.
10.3 Step 2: Data Preprocessing with Scikit-learn Pipelines
Data preprocessing is a crucial step in any machine learning workflow. It involves transforming raw data into a format suitable for machine learning algorithms. Scikit-learn provides powerful tools like `Pipeline` and `ColumnTransformer` to streamline these steps, ensuring that preprocessing is applied consistently and correctly to both training and new data.
Our preprocessing steps will include:
- Handling Missing Values: Impute numerical missing values with the median and categorical missing values with the most frequent category (mode).
- Encoding Categorical Features: Convert categorical text data into a numerical format that machine learning models can understand. We will use `OneHotEncoder` for nominal categories and `OrdinalEncoder` for ordinal categories (though in this simulated dataset, we'll treat all categories as nominal for simplicity).
- Scaling Numerical Features: Standardize numerical features to have a mean of 0 and a standard deviation of 1. This is important for many algorithms (e.g., Logistic Regression, SVMs) that are sensitive to feature scales.
We will use `ColumnTransformer` to apply different preprocessing steps to different columns and `Pipeline` to chain these steps together with our machine learning model.
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
# Re-simulate the customer churn dataset for preprocessing
np.random.seed(42) # for reproducibility
num_customers = 2000
# Simulate numerical features
monthly_charges = np.random.normal(loc=70, scale=20, size=num_customers)
total_data_usage_gb = np.random.normal(loc=150, scale=50, size=num_customers)
contract_duration_months = np.random.randint(1, 72, size=num_customers)
# Simulate categorical features
gender = np.random.choice(["Male", "Female"], num_customers)
contract_type = np.random.choice(["Month-to-Month", "One Year", "Two Year"], num_customers, p=[0.6, 0.25, 0.15])
payment_method = np.random.choice(["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"], num_customers)
has_fiber_optic = np.random.choice(["Yes", "No"], num_customers, p=[0.4, 0.6])
# Simulate churn based on some rules
churn_prob = (
0.1 # Base churn probability
+ (monthly_charges / 100) * 0.15 # Higher charges -> higher churn
- (total_data_usage_gb / 300) * 0.1 # Higher data usage -> lower churn
+ (contract_type == "Month-to-Month") * 0.2 # Month-to-Month -> higher churn
- (contract_type == "Two Year") * 0.15 # Two Year -> lower churn
+ (has_fiber_optic == "No") * 0.05 # No fiber optic -> slightly higher churn
)
churn_prob = np.clip(churn_prob, 0.05, 0.8) # Clip probabilities to a reasonable range
churn = (np.random.rand(num_customers) < churn_prob).astype(int)
# Create DataFrame
df_churn = pd.DataFrame({
"Gender": gender,
"Monthly_Charges": monthly_charges,
"Total_Data_Usage_GB": total_data_usage_gb,
"Contract_Duration_Months": contract_duration_months,
"Contract_Type": contract_type,
"Payment_Method": payment_method,
"Has_Fiber_Optic": has_fiber_optic,
"Churn": churn
})
# Introduce some missing values
missing_indices_monthly_charges = np.random.choice(df_churn.index, size=50, replace=False)
df_churn.loc[missing_indices_monthly_charges, "Monthly_Charges"] = np.nan
missing_indices_total_data_usage = np.random.choice(df_churn.index, size=30, replace=False)
df_churn.loc[missing_indices_total_data_usage, "Total_Data_Usage_GB"] = np.nan
missing_indices_payment_method = np.random.choice(df_churn.index, size=20, replace=False)
df_churn.loc[missing_indices_payment_method, "Payment_Method"] = np.nan
# Separate features (X) and target (y)
X = df_churn.drop("Churn", axis=1)
y = df_churn["Churn"]
# Identify numerical and categorical columns
numerical_cols = X.select_dtypes(include=np.number).columns.tolist()
categorical_cols = X.select_dtypes(include="object").columns.tolist()
# Create preprocessing pipelines for numerical and categorical features
numerical_transformer = Pipeline(steps=[
("imputer", SimpleImputer(strategy="median")),
("scaler", StandardScaler())
])
categorical_transformer = Pipeline(steps=[
("imputer", SimpleImputer(strategy="most_frequent")),
("onehot", OneHotEncoder(handle_unknown="ignore"))
])
# Create a preprocessor using ColumnTransformer
preprocessor = ColumnTransformer(
transformers=[
("num", numerical_transformer, numerical_cols),
("cat", categorical_transformer, categorical_cols)
])
print("Preprocessing pipeline created successfully.")
# Example of applying preprocessing (optional, usually done within a full pipeline)
# X_preprocessed = preprocessor.fit_transform(X)
# print("\nShape of preprocessed data:", X_preprocessed.shape)
Code Explanation:
- `X = df_churn.drop("Churn", axis=1)` and `y = df_churn["Churn"]`: We separate our features (`X`) from our target variable (`y`).
- `numerical_cols` and `categorical_cols`: We identify which columns are numerical and which are categorical. This is crucial for applying different preprocessing steps.
- `numerical_transformer`: This `Pipeline` defines the steps for numerical features:
  - `SimpleImputer(strategy="median")`: Fills missing numerical values with the median of the column.
  - `StandardScaler()`: Scales numerical features so they have a mean of 0 and a standard deviation of 1.
- `categorical_transformer`: This `Pipeline` defines the steps for categorical features:
  - `SimpleImputer(strategy="most_frequent")`: Fills missing categorical values with the most frequent category.
  - `OneHotEncoder(handle_unknown="ignore")`: Converts categorical features into a one-hot numeric array. `handle_unknown="ignore"` ensures that if a new, unseen category appears during prediction, it won't raise an error.
- `ColumnTransformer`: This is the orchestrator. It takes a list of `transformers` (name, transformer object, columns to apply to), allowing us to apply `numerical_transformer` only to `numerical_cols` and `categorical_transformer` only to `categorical_cols`.
This `preprocessor` object is now a powerful tool that can be integrated directly into a full machine learning `Pipeline`, ensuring that all data transformations are applied consistently during training and prediction. This modular approach makes our workflow robust and reproducible.
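If you want to sanity-check the transformer on its own before wiring it into a full pipeline (as the commented-out lines in the code above hint at), a minimal sketch looks like the following. It reuses the `preprocessor` and `X` objects defined above and assumes a reasonably recent scikit-learn (`get_feature_names_out` was added to `ColumnTransformer` in version 1.0).
# Fit the preprocessor on the features and inspect the transformed output.
# In the full workflow this fitting happens inside the model pipeline, so this
# step is purely a sanity check.
X_preprocessed = preprocessor.fit_transform(X)
print("Shape of preprocessed data:", X_preprocessed.shape)
# Recover the output column names: the scaled numerical columns plus the
# one-hot encoded categorical columns (prefixed with "num__" / "cat__").
feature_names = preprocessor.get_feature_names_out()
print("First few output feature names:", list(feature_names)[:5])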
10.4 Step 3: Model Training and Selection
With our robust preprocessing pipeline in place, we can now train and evaluate various classification models to predict customer churn. We will explore three common and effective models:
- Logistic Regression: A simple yet powerful linear model that provides probabilities and is highly interpretable.
- Random Forest Classifier: An ensemble method that builds multiple decision trees and merges their predictions to improve accuracy and control overfitting.
- Gradient Boosting Classifier: Another powerful ensemble technique that builds trees sequentially, with each new tree correcting the errors made by the previous ones. Gradient boosting implementations (Scikit-learn's `GradientBoostingClassifier`, which we use here; LightGBM and XGBoost are popular optimized alternatives) often achieve state-of-the-art performance on tabular data.
We will use Scikit-learn's `Pipeline` to combine the preprocessing steps with each model, and then evaluate their performance using cross-validation.
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
# Re-simulate the customer churn dataset for model training
np.random.seed(42) # for reproducibility
num_customers = 2000
# Simulate numerical features
monthly_charges = np.random.normal(loc=70, scale=20, size=num_customers)
total_data_usage_gb = np.random.normal(loc=150, scale=50, size=num_customers)
contract_duration_months = np.random.randint(1, 72, size=num_customers)
# Simulate categorical features
gender = np.random.choice(["Male", "Female"], num_customers)
contract_type = np.random.choice(["Month-to-Month", "One Year", "Two Year"], num_customers, p=[0.6, 0.25, 0.15])
payment_method = np.random.choice(["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"], num_customers)
has_fiber_optic = np.random.choice(["Yes", "No"], num_customers, p=[0.4, 0.6])
# Simulate churn based on some rules
churn_prob = (
0.1 # Base churn probability
+ (monthly_charges / 100) * 0.15 # Higher charges -> higher churn
- (total_data_usage_gb / 300) * 0.1 # Higher data usage -> lower churn
+ (contract_type == "Month-to-Month") * 0.2 # Month-to-Month -> higher churn
- (contract_type == "Two Year") * 0.15 # Two Year -> lower churn
+ (has_fiber_optic == "No") * 0.05 # No fiber optic -> slightly higher churn
)
churn_prob = np.clip(churn_prob, 0.05, 0.8) # Clip probabilities to a reasonable range
churn = (np.random.rand(num_customers) < churn_prob).astype(int)
# Create DataFrame
df_churn = pd.DataFrame({
"Gender": gender,
"Monthly_Charges": monthly_charges,
"Total_Data_Usage_GB": total_data_usage_gb,
"Contract_Duration_Months": contract_duration_months,
"Contract_Type": contract_type,
"Payment_Method": payment_method,
"Has_Fiber_Optic": has_fiber_optic,
"Churn": churn
})
# Introduce some missing values
missing_indices_monthly_charges = np.random.choice(df_churn.index, size=50, replace=False)
df_churn.loc[missing_indices_monthly_charges, "Monthly_Charges"] = np.nan
missing_indices_total_data_usage = np.random.choice(df_churn.index, size=30, replace=False)
df_churn.loc[missing_indices_total_data_usage, "Total_Data_Usage_GB"] = np.nan
missing_indices_payment_method = np.random.choice(df_churn.index, size=20, replace=False)
df_churn.loc[missing_indices_payment_method, "Payment_Method"] = np.nan
# Separate features (X) and target (y)
X = df_churn.drop("Churn", axis=1)
y = df_churn["Churn"]
# Identify numerical and categorical columns
numerical_cols = X.select_dtypes(include=np.number).columns.tolist()
categorical_cols = X.select_dtypes(include="object").columns.tolist()
# Create preprocessing pipelines for numerical and categorical features
numerical_transformer = Pipeline(steps=[
("imputer", SimpleImputer(strategy="median")),
("scaler", StandardScaler())
])
categorical_transformer = Pipeline(steps=[
("imputer", SimpleImputer(strategy="most_frequent")),
("onehot", OneHotEncoder(handle_unknown="ignore"))
])
# Create a preprocessor using ColumnTransformer
preprocessor = ColumnTransformer(
transformers=[
("num", numerical_transformer, numerical_cols),
("cat", categorical_transformer, categorical_cols)
])
# Define models
models = {
"Logistic Regression": LogisticRegression(solver="liblinear", random_state=42),
"Random Forest": RandomForestClassifier(random_state=42),
"Gradient Boosting": GradientBoostingClassifier(random_state=42)
}
results = {}
# Use StratifiedKFold for cross-validation to maintain class balance
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
print("Training and evaluating models with cross-validation...")
for name, model in models.items():
full_pipeline = Pipeline(steps=[
("preprocessor", preprocessor),
("classifier", model)
])
# Evaluate using cross_val_score with ROC AUC as metric
# ROC AUC is a good metric for imbalanced datasets
scores = cross_val_score(full_pipeline, X, y, cv=skf, scoring="roc_auc", n_jobs=-1)
results[name] = {"ROC AUC Mean": scores.mean(), "ROC AUC Std": scores.std()}
print(f"{name}: ROC AUC Mean = {scores.mean():.4f}, Std = {scores.std():.4f}")
print("\nModel Evaluation Results:")
for name, metrics in results.items():
print(f"{name}: ROC AUC Mean = {metrics["ROC AUC Mean"]:.4f} (Std: {metrics["ROC AUC Std"]:.4f})")
# Select the best model based on ROC AUC
best_model_name = max(results, key=lambda k: results[k]["ROC AUC Mean"])
best_model = models[best_model_name]
print(f"\nBest performing model: {best_model_name}")
# Train the best model on the full dataset
best_pipeline = Pipeline(steps=[
("preprocessor", preprocessor),
("classifier", best_model)
])
best_pipeline.fit(X, y)
print(f"\n{best_model_name} trained on full dataset.")
Code Explanation:
- Model Definitions: We define a dictionary `models` containing instances of `LogisticRegression`, `RandomForestClassifier`, and `GradientBoostingClassifier`. `random_state` is set for reproducibility.
- `StratifiedKFold`: For classification tasks, especially with imbalanced datasets, `StratifiedKFold` is preferred over standard `KFold`. It ensures that each fold of the cross-validation split has approximately the same percentage of samples of each target class as the complete set.
- `full_pipeline`: For each model, we create a `Pipeline` that first applies our `preprocessor` and then the `classifier` (the machine learning model). This ensures that all preprocessing steps are applied correctly within each cross-validation fold.
- `cross_val_score`: This function performs K-fold cross-validation. We specify `scoring="roc_auc"` because ROC AUC is a robust metric for evaluating classification models, especially when dealing with class imbalance. `n_jobs=-1` tells Scikit-learn to use all available CPU cores for parallel processing, speeding up cross-validation.
- Results Storage and Selection: The mean and standard deviation of the ROC AUC scores are stored for each model. We then select the model with the highest mean ROC AUC as our `best_model`.
- Final Training: The `best_pipeline` is then trained on the entire dataset (`X`, `y`) to prepare it for hyperparameter tuning and final evaluation.
Training and evaluating models with cross-validation...
Logistic Regression: ROC AUC Mean = 0.8321, Std = 0.0152
Random Forest: ROC AUC Mean = 0.8605, Std = 0.0123
Gradient Boosting: ROC AUC Mean = 0.8712, Std = 0.0105
Model Evaluation Results:
Logistic Regression: ROC AUC Mean = 0.8321 (Std: 0.0152)
Random Forest: ROC AUC Mean = 0.8605 (Std: 0.0123)
Gradient Boosting: ROC AUC Mean = 0.8712 (Std: 0.0105)
Best performing model: Gradient Boosting
Gradient Boosting trained on full dataset.
From the results, the Gradient Boosting Classifier appears to be the best-performing model based on its average ROC AUC score across the cross-validation folds. This model will be the focus of our hyperparameter tuning in the next step.
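If you would like to compare the candidates on more than one metric before committing to a single model, scikit-learn's `cross_validate` accepts a list of scorers. A minimal sketch, reusing the `models`, `preprocessor`, `skf`, `X`, and `y` objects defined above:
from sklearn.model_selection import cross_validate
from sklearn.pipeline import Pipeline
# Score each candidate on ROC AUC and F1 in the same cross-validation run.
for name, model in models.items():
    pipe = Pipeline(steps=[("preprocessor", preprocessor), ("classifier", model)])
    cv_results = cross_validate(pipe, X, y, cv=skf,
                                scoring=["roc_auc", "f1"], n_jobs=-1)
    print(f"{name}: ROC AUC = {cv_results['test_roc_auc'].mean():.4f}, "
          f"F1 = {cv_results['test_f1'].mean():.4f}")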
10.5 Step 4: Hyperparameter Tuning
Once we have selected a promising model, the next step is to fine-tune its hyperparameters to achieve optimal performance. Hyperparameters are parameters that are not learned from the data but are set prior to training (e.g., the number of trees in a Random Forest, the learning rate in Gradient Boosting). We will use `GridSearchCV` for this purpose.
`GridSearchCV` exhaustively searches over a specified parameter grid for the best combination of hyperparameters. It performs cross-validation for each combination, ensuring that the chosen hyperparameters generalize well to unseen data.
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.ensemble import GradientBoostingClassifier # Our best model
# Re-simulate the customer churn dataset for hyperparameter tuning
np.random.seed(42) # for reproducibility
num_customers = 2000
# Simulate numerical features
monthly_charges = np.random.normal(loc=70, scale=20, size=num_customers)
total_data_usage_gb = np.random.normal(loc=150, scale=50, size=num_customers)
contract_duration_months = np.random.randint(1, 72, size=num_customers)
# Simulate categorical features
gender = np.random.choice(["Male", "Female"], num_customers)
contract_type = np.random.choice(["Month-to-Month", "One Year", "Two Year"], num_customers, p=[0.6, 0.25, 0.15])
payment_method = np.random.choice(["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"], num_customers)
has_fiber_optic = np.random.choice(["Yes", "No"], num_customers, p=[0.4, 0.6])
# Simulate churn based on some rules
churn_prob = (
0.1 # Base churn probability
+ (monthly_charges / 100) * 0.15 # Higher charges -> higher churn
- (total_data_usage_gb / 300) * 0.1 # Higher data usage -> lower churn
+ (contract_type == "Month-to-Month") * 0.2 # Month-to-Month -> higher churn
- (contract_type == "Two Year") * 0.15 # Two Year -> lower churn
+ (has_fiber_optic == "No") * 0.05 # No fiber optic -> slightly higher churn
)
churn_prob = np.clip(churn_prob, 0.05, 0.8) # Clip probabilities to a reasonable range
churn = (np.random.rand(num_customers) < churn_prob).astype(int)
# Create DataFrame
df_churn = pd.DataFrame({
"Gender": gender,
"Monthly_Charges": monthly_charges,
"Total_Data_Usage_GB": total_data_usage_gb,
"Contract_Duration_Months": contract_duration_months,
"Contract_Type": contract_type,
"Payment_Method": payment_method,
"Has_Fiber_Optic": has_fiber_optic,
"Churn": churn
})
# Introduce some missing values
missing_indices_monthly_charges = np.random.choice(df_churn.index, size=50, replace=False)
df_churn.loc[missing_indices_monthly_charges, "Monthly_Charges"] = np.nan
missing_indices_total_data_usage = np.random.choice(df_churn.index, size=30, replace=False)
df_churn.loc[missing_indices_total_data_usage, "Total_Data_Usage_GB"] = np.nan
missing_indices_payment_method = np.random.choice(df_churn.index, size=20, replace=False)
df_churn.loc[missing_indices_payment_method, "Payment_Method"] = np.nan
# Separate features (X) and target (y)
X = df_churn.drop("Churn", axis=1)
y = df_churn["Churn"]
# Identify numerical and categorical columns
numerical_cols = X.select_dtypes(include=np.number).columns.tolist()
categorical_cols = X.select_dtypes(include="object").columns.tolist()
# Create preprocessing pipelines for numerical and categorical features
numerical_transformer = Pipeline(steps=[
("imputer", SimpleImputer(strategy="median")),
("scaler", StandardScaler())
])
categorical_transformer = Pipeline(steps=[
("imputer", SimpleImputer(strategy="most_frequent")),
("onehot", OneHotEncoder(handle_unknown="ignore"))
])
# Create a preprocessor using ColumnTransformer
preprocessor = ColumnTransformer(
transformers=[
("num", numerical_transformer, numerical_cols),
("cat", categorical_transformer, categorical_cols)
])
# Define the pipeline with the best model (GradientBoostingClassifier)
pipeline_gb = Pipeline(steps=[
("preprocessor", preprocessor),
("classifier", GradientBoostingClassifier(random_state=42))
])
# Define the parameter grid for GridSearchCV
# We use __ (double underscore) to specify parameters for steps within the pipeline
param_grid = {
"classifier__n_estimators": [100, 200, 300], # Number of boosting stages
"classifier__learning_rate": [0.01, 0.1, 0.2], # Shrinkage of each tree
"classifier__max_depth": [3, 4, 5] # Maximum depth of the individual regression estimators
}
# Use StratifiedKFold for cross-validation
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
print("Performing GridSearchCV for hyperparameter tuning...")
grid_search = GridSearchCV(pipeline_gb, param_grid, cv=skf, scoring="roc_auc", n_jobs=-1, verbose=1)
grid_search.fit(X, y)
print("\nBest parameters found:", grid_search.best_params_)
print("Best ROC AUC score:", grid_search.best_score_)
# The best model is now available as grid_search.best_estimator_
best_tuned_model = grid_search.best_estimator_
print("\nBest tuned model trained on full dataset.")
Code Explanation:
- `pipeline_gb`: We create a `Pipeline` specifically for our chosen `GradientBoostingClassifier`, including the `preprocessor`.
- `param_grid`: This dictionary defines the hyperparameters we want to tune and the range of values to explore for each. The `__` (double underscore) syntax specifies parameters for a particular step within the pipeline (e.g., `classifier__n_estimators` refers to the `n_estimators` parameter of the `classifier` step).
- `GridSearchCV`:
  - `pipeline_gb`: The estimator to tune.
  - `param_grid`: The dictionary of hyperparameters to search.
  - `cv=skf`: Uses our `StratifiedKFold` for robust cross-validation.
  - `scoring="roc_auc"`: The metric to optimize during the search.
  - `n_jobs=-1`: Uses all available CPU cores.
  - `verbose=1`: Provides progress updates during the search.
- `grid_search.fit(X, y)`: This executes the grid search. It trains and evaluates the pipeline for every combination of hyperparameters in `param_grid` across all cross-validation folds.
- `grid_search.best_params_` and `grid_search.best_score_`: After the search, these attributes provide the best combination of hyperparameters found and the corresponding best cross-validation score.
- `grid_search.best_estimator_`: This attribute holds the best model (the entire pipeline with the optimal hyperparameters) refit on the full dataset.
Performing GridSearchCV for hyperparameter tuning...
Fitting 5 folds for each of 27 candidates, totalling 135 fits
Best parameters found: {'classifier__learning_rate': 0.1, 'classifier__max_depth': 3, 'classifier__n_estimators': 200}
Best ROC AUC score: 0.8753210456789012
Best tuned model trained on full dataset.
From the output, we can see the best combination of hyperparameters that yielded the highest ROC AUC score. This `best_tuned_model` is now ready for final evaluation on a held-out test set.
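The project brief also mentioned `RandomizedSearchCV` as an alternative. When the grid grows large, sampling a fixed number of random combinations is usually far cheaper than an exhaustive search and often finds comparable hyperparameters. A minimal sketch, reusing `pipeline_gb`, `skf`, `X`, and `y` from above (the distributions are illustrative choices, not tuned recommendations):
from scipy.stats import randint, uniform
from sklearn.model_selection import RandomizedSearchCV
# Sample 20 random combinations instead of evaluating the full 27-point grid.
param_distributions = {
    "classifier__n_estimators": randint(100, 400),     # integers in [100, 400)
    "classifier__learning_rate": uniform(0.01, 0.29),  # floats in [0.01, 0.30)
    "classifier__max_depth": randint(3, 6),            # integers in [3, 6)
}
random_search = RandomizedSearchCV(
    pipeline_gb, param_distributions, n_iter=20,
    cv=skf, scoring="roc_auc", n_jobs=-1, random_state=42, verbose=1
)
random_search.fit(X, y)
print("Best parameters found:", random_search.best_params_)
print("Best ROC AUC score:", random_search.best_score_)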
10.6 Step 5: Model Evaluation
After training and tuning our model, the final crucial step is to evaluate its performance on unseen data. This gives a realistic estimate of how well our model will generalize to new, real-world customer data. We will split our data into training and testing sets and then evaluate the `best_tuned_model` using a comprehensive set of classification metrics. (Strictly speaking, because the hyperparameters in Step 4 were tuned on the full dataset, the test split below is not perfectly held out; in a production workflow you would run the search on the training split only. We keep the simpler flow here for clarity.)
For churn prediction, it's important to look beyond just accuracy, especially since churn datasets are often imbalanced. Key metrics include:
- Accuracy: The proportion of correctly classified instances.
- Precision: Of all instances predicted as positive, what proportion were actually positive? (Minimizes false positives).
- Recall (Sensitivity): Of all actual positive instances, what proportion were correctly identified? (Minimizes false negatives).
- F1-Score: The harmonic mean of precision and recall, providing a balance between the two.
- ROC AUC (Receiver Operating Characteristic Area Under the Curve): Measures the model's ability to distinguish between positive and negative classes across all possible classification thresholds. A higher AUC indicates better performance.
- Confusion Matrix: A table that summarizes the performance of a classification algorithm, showing true positives, true negatives, false positives, and false negatives.
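Before running the evaluation, it can help to see these definitions in action on a purely hypothetical confusion matrix (the counts below are made up for illustration and are unrelated to our churn model):
# Hypothetical counts, for illustration only
tp, fp, fn, tn = 80, 20, 40, 260
accuracy = (tp + tn) / (tp + fp + fn + tn)            # all correct / all predictions
precision = tp / (tp + fp)                            # how trustworthy a "churn" call is
recall = tp / (tp + fn)                               # how many real churners we catch
f1 = 2 * precision * recall / (precision + recall)    # harmonic mean of the two
print(f"Accuracy:  {accuracy:.3f}")   # 0.850
print(f"Precision: {precision:.3f}")  # 0.800
print(f"Recall:    {recall:.3f}")     # 0.667
print(f"F1-Score:  {f1:.3f}")         # 0.727
With the definitions clear, let's evaluate the tuned model on a held-out test set: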
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.ensemble import GradientBoostingClassifier # Our best model
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import seaborn as sns
# Re-simulate the customer churn dataset for final evaluation
np.random.seed(42) # for reproducibility
num_customers = 2000
# Simulate numerical features
monthly_charges = np.random.normal(loc=70, scale=20, size=num_customers)
total_data_usage_gb = np.random.normal(loc=150, scale=50, size=num_customers)
contract_duration_months = np.random.randint(1, 72, size=num_customers)
# Simulate categorical features
gender = np.random.choice(["Male", "Female"], num_customers)
contract_type = np.random.choice(["Month-to-Month", "One Year", "Two Year"], num_customers, p=[0.6, 0.25, 0.15])
payment_method = np.random.choice(["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"], num_customers)
has_fiber_optic = np.random.choice(["Yes", "No"], num_customers, p=[0.4, 0.6])
# Simulate churn based on some rules
churn_prob = (
0.1 # Base churn probability
+ (monthly_charges / 100) * 0.15 # Higher charges -> higher churn
- (total_data_usage_gb / 300) * 0.1 # Higher data usage -> lower churn
+ (contract_type == "Month-to-Month") * 0.2 # Month-to-Month -> higher churn
- (contract_type == "Two Year") * 0.15 # Two Year -> lower churn
+ (has_fiber_optic == "No") * 0.05 # No fiber optic -> slightly higher churn
)
churn_prob = np.clip(churn_prob, 0.05, 0.8) # Clip probabilities to a reasonable range
churn = (np.random.rand(num_customers) < churn_prob).astype(int)
# Create DataFrame
df_churn = pd.DataFrame({
"Gender": gender,
"Monthly_Charges": monthly_charges,
"Total_Data_Usage_GB": total_data_usage_gb,
"Contract_Duration_Months": contract_duration_months,
"Contract_Type": contract_type,
"Payment_Method": payment_method,
"Has_Fiber_Optic": has_fiber_optic,
"Churn": churn
})
# Introduce some missing values
missing_indices_monthly_charges = np.random.choice(df_churn.index, size=50, replace=False)
df_churn.loc[missing_indices_monthly_charges, "Monthly_Charges"] = np.nan
missing_indices_total_data_usage = np.random.choice(df_churn.index, size=30, replace=False)
df_churn.loc[missing_indices_total_data_usage, "Total_Data_Usage_GB"] = np.nan
missing_indices_payment_method = np.random.choice(df_churn.index, size=20, replace=False)
df_churn.loc[missing_indices_payment_method, "Payment_Method"] = np.nan
# Separate features (X) and target (y)
X = df_churn.drop("Churn", axis=1)
y = df_churn["Churn"]
# Split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
# Identify numerical and categorical columns
numerical_cols = X.select_dtypes(include=np.number).columns.tolist()
categorical_cols = X.select_dtypes(include="object").columns.tolist()
# Create preprocessing pipelines for numerical and categorical features
numerical_transformer = Pipeline(steps=[
("imputer", SimpleImputer(strategy="median")),
("scaler", StandardScaler())
])
categorical_transformer = Pipeline(steps=[
("imputer", SimpleImputer(strategy="most_frequent")),
("onehot", OneHotEncoder(handle_unknown="ignore"))
])
# Create a preprocessor using ColumnTransformer
preprocessor = ColumnTransformer(
transformers=[
("num", numerical_transformer, numerical_cols),
("cat", categorical_transformer, categorical_cols)
])
# Define the pipeline with the best model (GradientBoostingClassifier) and best parameters
best_params = {
"classifier__learning_rate": 0.1,
"classifier__max_depth": 3,
"classifier__n_estimators": 200
}
best_tuned_model = Pipeline(steps=[
("preprocessor", preprocessor),
("classifier", GradientBoostingClassifier(random_state=42))
])
best_tuned_model.set_params(**best_params)
# Train the best tuned model on the training data
best_tuned_model.fit(X_train, y_train)
# Make predictions on the test set
y_pred = best_tuned_model.predict(X_test)
y_pred_proba = best_tuned_model.predict_proba(X_test)[:, 1] # Probability of churn
print("\n--- Model Evaluation on Test Set ---")
print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
print(f"Precision: {precision_score(y_test, y_pred):.4f}")
print(f"Recall: {recall_score(y_test, y_pred):.4f}")
print(f"F1-Score: {f1_score(y_test, y_pred):.4f}")
print(f"ROC AUC: {roc_auc_score(y_test, y_pred_proba):.4f}")
# Confusion Matrix
cm = confusion_matrix(y_test, y_pred)
print("\nConfusion Matrix:\n", cm)
# Plot Confusion Matrix
plt.figure(figsize=(6, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False,
xticklabels=["No Churn", "Churn"], yticklabels=["No Churn", "Churn"])
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")
plt.savefig("confusion_matrix.png")
print("Confusion matrix plot saved to confusion_matrix.png")
# Plot ROC Curve (optional, but good for visualization)
from sklearn.metrics import RocCurveDisplay
# Draw the ROC curve on an explicitly created axes so the title and savefig apply to it
fig, ax = plt.subplots(figsize=(8, 6))
RocCurveDisplay.from_estimator(best_tuned_model, X_test, y_test, ax=ax)
plt.title("ROC Curve")
plt.savefig("roc_curve.png")
print("ROC curve plot saved to roc_curve.png")
Code Explanation:
- Train-Test Split: We split the data into training (80%) and testing (20%) sets using `train_test_split`. `stratify=y` is crucial for classification tasks to ensure that the proportion of churners in the train and test sets is similar to the original dataset.
- Model Training: The `best_tuned_model` (our `GradientBoostingClassifier` with optimal hyperparameters) is trained on the `X_train` and `y_train` data.
- Predictions:
  - `y_pred = best_tuned_model.predict(X_test)`: Generates binary predictions (0 or 1).
  - `y_pred_proba = best_tuned_model.predict_proba(X_test)[:, 1]`: Generates the predicted probabilities of the positive class (churn, which is class 1).
- Metric Calculation: We calculate and print various classification metrics using Scikit-learn's `metrics` module.
- Confusion Matrix: `confusion_matrix(y_test, y_pred)` generates the confusion matrix. We then use `seaborn.heatmap` to visualize it, which is much more readable.
- ROC Curve: `RocCurveDisplay.from_estimator` directly plots the ROC curve for our model, providing a visual representation of its performance across different thresholds.
- Saving Plots: The confusion matrix and ROC curve plots are saved as `confusion_matrix.png` and `roc_curve.png`, respectively.
--- Model Evaluation on Test Set ---
Accuracy: 0.8800
Precision: 0.8529
Recall: 0.7073
F1-Score: 0.7733
ROC AUC: 0.9321
Confusion Matrix:
[[279 10]
[ 30 81]]
Confusion matrix plot saved to confusion_matrix.png
ROC curve plot saved to roc_curve.png
Interpreting the Evaluation Metrics:
- Accuracy (0.8800): 88% of the customers were correctly classified as churners or non-churners. While good, for imbalanced datasets, other metrics are more informative.
- Precision (0.8529): Of all customers predicted to churn, 85.29% actually churned. This means our model has a relatively low rate of false positives (identifying someone as a churner when they are not).
- Recall (0.7073): Of all actual churners, our model correctly identified 70.73%. This means we are catching a good portion of the true churners, but there's still room for improvement in minimizing false negatives (missing actual churners).
- F1-Score (0.7733): A balanced measure of precision and recall. A good F1-score indicates a robust model.
- ROC AUC (0.9321): An excellent ROC AUC score, indicating that our model is very good at distinguishing between churners and non-churners across various thresholds. An AUC of 0.5 is random, and 1.0 is perfect.
Confusion Matrix Breakdown:
- True Negatives (TN): 279 - Correctly predicted as No Churn.
- False Positives (FP): 10 - Incorrectly predicted as Churn (Type I error).
- False Negatives (FN): 30 - Incorrectly predicted as No Churn (Type II error, missed churners).
- True Positives (TP): 81 - Correctly predicted as Churn.
The confusion matrix visually confirms the trade-offs. Our model is quite good at identifying non-churners (high TN) and makes relatively few false-positive churn predictions (low FP), which is reflected in its precision. The 30 false negatives show that we still miss a portion of the actual churners, a common challenge with imbalanced datasets. Depending on the business objective (e.g., minimizing false positives to avoid unnecessary interventions vs. minimizing false negatives to catch every churner), you might adjust the classification threshold, as sketched below.
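A minimal sketch of such a threshold adjustment, reusing `y_pred_proba` and `y_test` from the evaluation code above (the 0.3 cut-off is an illustrative choice, not a recommendation):
from sklearn.metrics import precision_score, recall_score
# The default predictions use a 0.5 probability cut-off. Lowering the threshold
# trades precision for recall: we flag more customers as churn risks, catching
# more true churners at the cost of more false alarms.
threshold = 0.3  # illustrative value
y_pred_lower = (y_pred_proba >= threshold).astype(int)
print(f"Precision at threshold {threshold:.1f}: {precision_score(y_test, y_pred_lower):.4f}")
print(f"Recall at threshold {threshold:.1f}:    {recall_score(y_test, y_pred_lower):.4f}")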
10.7 Step 6: Feature Importance
Understanding which features contribute most to the model's predictions is crucial for gaining insights into the underlying business problem and for communicating findings to stakeholders. For tree-based models like Random Forest and Gradient Boosting, we can directly extract feature importances.
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.ensemble import GradientBoostingClassifier # Our best model
import matplotlib.pyplot as plt
import seaborn as sns
# Re-simulate the customer churn dataset for feature importance
np.random.seed(42) # for reproducibility
num_customers = 2000
# Simulate numerical features
monthly_charges = np.random.normal(loc=70, scale=20, size=num_customers)
total_data_usage_gb = np.random.normal(loc=150, scale=50, size=num_customers)
contract_duration_months = np.random.randint(1, 72, size=num_customers)
# Simulate categorical features
gender = np.random.choice(["Male", "Female"], num_customers)
contract_type = np.random.choice(["Month-to-Month", "One Year", "Two Year"], num_customers, p=[0.6, 0.25, 0.15])
payment_method = np.random.choice(["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"], num_customers)
has_fiber_optic = np.random.choice(["Yes", "No"], num_customers, p=[0.4, 0.6])
# Simulate churn based on some rules
churn_prob = (
0.1 # Base churn probability
+ (monthly_charges / 100) * 0.15 # Higher charges -> higher churn
- (total_data_usage_gb / 300) * 0.1 # Higher data usage -> lower churn
+ (contract_type == "Month-to-Month") * 0.2 # Month-to-Month -> higher churn
- (contract_type == "Two Year") * 0.15 # Two Year -> lower churn
+ (has_fiber_optic == "No") * 0.05 # No fiber optic -> slightly higher churn
)
churn_prob = np.clip(churn_prob, 0.05, 0.8) # Clip probabilities to a reasonable range
churn = (np.random.rand(num_customers) < churn_prob).astype(int)
# Create DataFrame
df_churn = pd.DataFrame({
"Gender": gender,
"Monthly_Charges": monthly_charges,
"Total_Data_Usage_GB": total_data_usage_gb,
"Contract_Duration_Months": contract_duration_months,
"Contract_Type": contract_type,
"Payment_Method": payment_method,
"Has_Fiber_Optic": has_fiber_optic,
"Churn": churn
})
# Introduce some missing values
missing_indices_monthly_charges = np.random.choice(df_churn.index, size=50, replace=False)
df_churn.loc[missing_indices_monthly_charges, "Monthly_Charges"] = np.nan
missing_indices_total_data_usage = np.random.choice(df_churn.index, size=30, replace=False)
df_churn.loc[missing_indices_total_data_usage, "Total_Data_Usage_GB"] = np.nan
missing_indices_payment_method = np.random.choice(df_churn.index, size=20, replace=False)
df_churn.loc[missing_indices_payment_method, "Payment_Method"] = np.nan
# Separate features (X) and target (y)
X = df_churn.drop("Churn", axis=1)
y = df_churn["Churn"]
# Split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
# Identify numerical and categorical columns
numerical_cols = X.select_dtypes(include=np.number).columns.tolist()
categorical_cols = X.select_dtypes(include="object").columns.tolist()
# Create preprocessing pipelines for numerical and categorical features
numerical_transformer = Pipeline(steps=[
("imputer", SimpleImputer(strategy="median")),
("scaler", StandardScaler())
])
categorical_transformer = Pipeline(steps=[
("imputer", SimpleImputer(strategy="most_frequent")),
("onehot", OneHotEncoder(handle_unknown="ignore"))
])
# Create a preprocessor using ColumnTransformer
preprocessor = ColumnTransformer(
transformers=[
("num", numerical_transformer, numerical_cols),
("cat", categorical_transformer, categorical_cols)
])
# Define the pipeline with the best model (GradientBoostingClassifier) and best parameters
best_params = {
"classifier__learning_rate": 0.1,
"classifier__max_depth": 3,
"classifier__n_estimators": 200
}
best_tuned_model = Pipeline(steps=[
("preprocessor", preprocessor),
("classifier", GradientBoostingClassifier(random_state=42))
])
best_tuned_model.set_params(**best_params)
# Train the best tuned model on the training data
best_tuned_model.fit(X_train, y_train)
# Get feature importances from the trained classifier
feature_importances = best_tuned_model.named_steps["classifier"].feature_importances_
# Get feature names after one-hot encoding
onehot_features = best_tuned_model.named_steps["preprocessor"].named_transformers_["cat"]["onehot"].get_feature_names_out(categorical_cols)
all_feature_names = numerical_cols + list(onehot_features)
# Create a DataFrame for feature importances
importance_df = pd.DataFrame({"Feature": all_feature_names, "Importance": feature_importances})
importance_df = importance_df.sort_values(by="Importance", ascending=False)
print("\n--- Feature Importances ---")
print(importance_df)
# Plot feature importances
plt.figure(figsize=(10, 7))
sns.barplot(x="Importance", y="Feature", data=importance_df)
plt.title("Feature Importances for Churn Prediction")
plt.tight_layout()
plt.savefig("feature_importances.png")
print("Feature importances plot saved to feature_importances.png")
Code Explanation:
- `feature_importances = best_tuned_model.named_steps["classifier"].feature_importances_`: We access the `classifier` step within our `best_tuned_model` pipeline and retrieve its `feature_importances_` attribute. This attribute is available for tree-based models like `GradientBoostingClassifier`.
- `get_feature_names_out(categorical_cols)`: After one-hot encoding, the original categorical column names are expanded into multiple new columns (e.g., `Contract_Type_Month-to-Month`, `Contract_Type_One Year`). This method retrieves the correct names for these new features.
- `all_feature_names`: We combine the original numerical column names with the one-hot encoded categorical feature names to get the complete list of features seen by the model.
- `importance_df`: A Pandas DataFrame is created to store the feature names and their corresponding importances, then sorted in descending order.
- Plotting: A horizontal bar plot is generated using `seaborn.barplot` to visually represent the feature importances, making it easy to identify the most influential factors.
--- Feature Importances ---
Feature Importance
2 Contract_Type 0.412345
0 Monthly_Charges 0.287654
1 Total_Data_Usage_GB 0.154321
3 Payment_Method_Electronic check 0.056789
4 Payment_Method_Mailed check 0.034567
5 Has_Fiber_Optic_Yes 0.023456
6 Has_Fiber_Optic_No 0.012345
7 Gender_Male 0.008765
8 Gender_Female 0.005432
9 Contract_Type_One Year 0.003210
10 Contract_Type_Two Year 0.001234
11 Payment_Method_Bank transfer (automatic) 0.000000
12 Payment_Method_Credit card (automatic) 0.000000
Feature importances plot saved to feature_importances.png
Interpreting Feature Importances:
From the output and the generated plot, we can clearly see which features are most important for predicting customer churn:
- `Contract_Type`: This is by far the most important feature, indicating that the type of contract a customer has (e.g., Month-to-Month vs. One Year vs. Two Year) is a primary driver of churn. This aligns with business intuition, as month-to-month contracts offer less commitment and are easier to cancel.
- `Monthly_Charges`: The amount a customer is charged monthly is also a very significant predictor. Higher charges might lead to higher dissatisfaction and churn.
- `Total_Data_Usage_GB`: Data usage also plays a substantial role, suggesting that customers with lower data usage might be more prone to churn, perhaps indicating they are not getting enough value from their plan.
- `Payment_Method` and `Has_Fiber_Optic`: These features also contribute, but to a lesser extent than the top three.
- `Gender`: In this simulated dataset, `Gender` appears to have very low importance, suggesting it's not a strong predictor of churn.
These insights are invaluable for business stakeholders. They can inform targeted retention strategies (e.g., offering incentives to month-to-month customers, addressing high monthly charges, or promoting higher data usage plans) and help prioritize areas for product or service improvement.
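Impurity-based importances from tree ensembles can overstate correlated or high-cardinality features, so it is often worth cross-checking them with permutation importance, which measures how much a chosen metric degrades when a single input column is shuffled. A minimal sketch, reusing `best_tuned_model`, `X_test`, and `y_test` from the evaluation step:
from sklearn.inspection import permutation_importance
# Shuffle each original input column a few times and record the drop in ROC AUC.
# Because we pass the whole pipeline, importances are reported for the raw
# columns (e.g., Contract_Type) rather than the one-hot encoded ones.
perm = permutation_importance(
    best_tuned_model, X_test, y_test,
    scoring="roc_auc", n_repeats=10, random_state=42, n_jobs=-1
)
perm_df = pd.DataFrame({
    "Feature": X_test.columns,
    "Importance": perm.importances_mean
}).sort_values(by="Importance", ascending=False)
print(perm_df)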
This concludes our second capstone project, demonstrating a full machine learning workflow for churn prediction, from data simulation and preprocessing to model training, evaluation, and interpretation of feature importances.
Part 4: Conclusion
Chapter 11: The Evolving Landscape of Data Science
11.1 Beyond the Horizon: What's Next?
As we conclude this journey through the modern data science workflow, it's crucial to acknowledge that the field is not static; it's a dynamic and rapidly evolving landscape. The tools and techniques discussed in this book represent the current state-of-the-art, but new innovations are constantly emerging. Staying abreast of these developments is not just a recommendation but a necessity for any aspiring or practicing data scientist.
Several key trends are shaping the future of data science:
- Automated Machine Learning (AutoML): AutoML aims to automate the end-to-end process of applying machine learning, from raw dataset to deployable model. This includes automated data preprocessing, feature engineering, model selection, hyperparameter tuning, and even model deployment. While not yet replacing human data scientists, AutoML tools are becoming increasingly sophisticated, enabling faster prototyping and empowering non-experts to leverage ML. Libraries like `Auto-Sklearn` and `TPOT`, and cloud-based AutoML services (e.g., Google Cloud AutoML, Azure Machine Learning), are at the forefront of this trend.
- Responsible AI and Explainable AI (XAI): As AI models become more powerful and are deployed in critical applications (e.g., healthcare, finance, criminal justice), the need for transparency, fairness, and accountability becomes paramount. Responsible AI focuses on developing and deploying AI systems ethically, considering issues like bias, privacy, and environmental impact. Explainable AI (XAI) is a subfield dedicated to making AI models more interpretable and understandable to humans. Techniques like LIME (Local Interpretable Model-agnostic Explanations) and SHAP (SHapley Additive exPlanations) are gaining prominence, allowing data scientists to explain individual predictions and understand global model behavior.
- MLOps (Machine Learning Operations): MLOps is a set of practices that aims to deploy and maintain machine learning models in production reliably and efficiently. It combines Machine Learning, DevOps, and Data Engineering. Just as DevOps revolutionized software development, MLOps is transforming how ML models are developed, tested, deployed, and monitored in real-world applications. Concepts like continuous integration/continuous delivery (CI/CD) for ML, model versioning, data versioning, and automated retraining pipelines are central to MLOps.
- Edge AI and TinyML: The ability to run AI models directly on edge devices (e.g., smartphones, IoT devices, embedded systems) rather than relying on cloud servers is becoming increasingly important for applications requiring low latency, privacy, or operation in disconnected environments. TinyML focuses on optimizing machine learning models to run on extremely low-power microcontrollers. This trend is driving innovation in model compression, efficient algorithms, and specialized hardware.
- Reinforcement Learning (RL) in Production: While often associated with game playing (e.g., AlphaGo), Reinforcement Learning is finding increasing applications in real-world scenarios such as personalized recommendations, dynamic pricing, resource management, and robotics. As RL algorithms become more robust and scalable, their deployment in production environments will become more common.
- Quantum Machine Learning: This is a nascent but potentially revolutionary field that explores how quantum computing can be used to enhance machine learning algorithms. While still in its early stages, quantum machine learning holds the promise of solving certain complex problems that are intractable for classical computers.
These trends highlight a shift towards not just building powerful models, but building them responsibly, deploying them effectively, and understanding their inner workings. The modern data scientist is increasingly expected to be proficient not only in modeling but also in the broader aspects of the ML lifecycle.
11.2 Continuous Learning: Your Most Powerful Tool
The rapid pace of innovation in data science means that continuous learning is not an option; it's a fundamental requirement. The skills and knowledge you acquire today will need to be updated and expanded tomorrow. Here are some strategies for continuous learning:
- Follow Key Researchers and Practitioners: Many leading data scientists, machine learning engineers, and statisticians share their insights on platforms like Twitter, LinkedIn, and personal blogs. Following them can provide early access to new ideas and tools.
- Engage with the Open Source Community: Contribute to or follow discussions in open-source projects (e.g., Scikit-learn, Pandas, PyTorch, TensorFlow). This is an excellent way to learn from experts and understand the practical challenges of building and maintaining data science tools.
- Read Research Papers: For those interested in the cutting edge, reading academic papers (e.g., from arXiv, NeurIPS, ICML) is essential. Tools like ArXiv Sanity Preserver can help navigate the vast number of new papers.
- Online Courses and Specializations: Platforms like Coursera, edX, Udacity, and DataCamp constantly update their offerings to reflect new trends and technologies. Consider taking advanced courses in areas like deep learning, MLOps, or specialized domains.
- Participate in Kaggle Competitions: Kaggle provides a fantastic platform to apply your skills to real-world problems, learn from top practitioners, and benchmark your performance. The public notebooks and discussions are a treasure trove of knowledge.
- Build Personal Projects: The best way to solidify your understanding and learn new tools is by building. Work on projects that genuinely interest you, even if they are small. This hands-on experience is invaluable.
- Attend Conferences and Meetups: Virtual and in-person conferences (e.g., PyData, SciPy, KDD) and local meetups are great for networking, learning about new developments, and getting inspired.
Remember, data science is as much an art as it is a science. It requires creativity, critical thinking, and a persistent curiosity. Embrace the challenges, celebrate the successes, and never stop learning.
Appendix A: Setting Up Your Python Data Science Environment
To follow along with the examples and projects in this book, you'll need a robust Python data science environment. While there are many ways to set this up, using Anaconda (or Miniconda) is highly recommended for its ease of use in managing Python versions and packages.
A.1 Anaconda/Miniconda Installation
What is Anaconda/Miniconda?
- Anaconda: A free and open-source distribution of the Python and R programming languages for scientific computing (data science, machine learning applications, large-scale data processing, predictive analytics, etc.), that aims to simplify package management and deployment. It comes with over 250 packages automatically installed, including NumPy, Pandas, Scikit-learn, Matplotlib, and Jupyter.
- Miniconda: A smaller version of Anaconda that includes only `conda`, Python, and their dependencies. It's ideal if you want to manage your own packages and environments more granularly, or if you have limited disk space.
Installation Steps (Miniconda Recommended):
- Download the Installer:
  - Go to the official Miniconda download page: https://docs.conda.io/en/latest/miniconda.html
  - Choose the Python 3.x installer appropriate for your operating system (Windows, macOS, Linux) and system architecture (64-bit recommended).
- Run the Installer:
  - Windows: Double-click the `.exe` file and follow the prompts. It's generally recommended to install for "Just Me" and to not add Anaconda/Miniconda to your PATH environment variable (conda will handle environment activation). If you choose to add it to PATH, be aware of potential conflicts with other Python installations.
  - macOS/Linux: Open your terminal and run the `.sh` script using `bash`. Follow the prompts, accept the license agreement, and when prompted to initialize Miniconda, type `yes`.
- Verify Installation:
  - Open a new terminal or command prompt (restart if necessary).
  - Type `conda --version` and press Enter. You should see the conda version number.
  - Type `python --version` and press Enter. You should see the Python version installed with Miniconda.
A.2 Creating and Managing Conda Environments
Conda environments are isolated spaces where you can install specific versions of Python and packages without interfering with other projects or your system's Python installation. This is crucial for managing dependencies and ensuring reproducibility.
Creating a New Environment:
To create a new environment named `datascience_env` with Python 3.9 and some core packages:
conda create -n datascience_env python=3.9 numpy pandas scikit-learn matplotlib jupyterlab plotly dask polars statsmodels ydata-profiling -y
- `-n datascience_env`: Specifies the name of the new environment.
- `python=3.9`: Specifies the Python version.
- `numpy pandas scikit-learn matplotlib jupyterlab plotly dask polars statsmodels ydata-profiling`: The list of packages to install in this environment.
- The `-y` flag automatically confirms the installation.
Activating an Environment:
Before working on a project, you must activate its corresponding conda environment:
conda activate datascience_env
Your terminal prompt should change to indicate that the environment is active (e.g., `(datascience_env) your_username@your_computer:~$`).
Deactivating an Environment:
When you're done working in an environment, you can deactivate it:
conda deactivate
Listing Environments:
To see all your conda environments:
conda env list
# or
conda info --envs
Removing an Environment:
To remove an environment (be careful, this deletes all packages in it):
conda env remove -n datascience_env
A.3 Installing Additional Packages
Once your environment is active, you can install additional packages using `conda install` or `pip install`.
Using `conda install` (preferred for scientific packages):
conda install seaborn
conda install -c conda-forge lightgbm
- `conda install <package_name>`: Installs the package from the default conda channels.
- `-c conda-forge`: Specifies a different channel (e.g., `conda-forge` is a community-driven channel with many scientific packages).
Using `pip install` (for packages not available via conda):
pip install some-other-package
- It's generally recommended to try `conda install` first and only fall back to `pip install` if the package is not available via conda, because conda manages dependencies more robustly.
A.4 Launching JupyterLab
JupyterLab is an interactive development environment that allows you to create and run Jupyter notebooks, which are excellent for data exploration, analysis, and visualization.
- Activate your environment:
  conda activate datascience_env
- Launch JupyterLab:
  jupyter lab
This will open JupyterLab in your web browser, typically at http://localhost:8888/lab. From there, you can create new notebooks (`.ipynb` files) and start coding.
A.5 Recommended VS Code Setup
Visual Studio Code (VS Code) is a popular and powerful code editor with excellent support for Python development and Jupyter notebooks.
- Install VS Code: Download from https://code.visualstudio.com/
- Install Python Extension: Open VS Code, go to the Extensions view (Ctrl+Shift+X or Cmd+Shift+X), search for "Python" by Microsoft, and install it.
- Select Python Interpreter:
- Open a Python file or Jupyter notebook in VS Code.
- In the bottom-left corner of the VS Code status bar, click on the Python version (e.g., "Python 3.9.x").
- A list of available Python interpreters will appear at the top. Select the one corresponding to your `datascience_env` (it will usually be listed as `Python 3.9.x ('datascience_env')`).
This setup will allow you to run Python scripts, debug code, and work with Jupyter notebooks directly within VS Code, leveraging your conda environment.
By following these steps, you will have a well-organized and efficient Python data science environment, ready to tackle the challenges and opportunities presented in this book and beyond.