iBrokeTheCode's picture
refactor: Use Pipeline and ColumnTransformer for preprocessing
9995a6a
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.figure import Figure
from pandas import DataFrame, Series
from src.theme import custom_palette
def plot_target_distribution(df: DataFrame) -> tuple[DataFrame, Figure]:
    """
    Plot the distribution of the 'TARGET' column in a DataFrame.

    Args:
        df (DataFrame): The input DataFrame containing the 'TARGET' column.

    Returns:
        DataFrame: A DataFrame indexed by 'TARGET' class with the 'Count' and
            'Percentage' (0-100, rounded to 2 decimals) of each class. Rows
            with a missing TARGET are excluded from both columns.
        Figure: The matplotlib Figure object containing the plot.
    """
    # value_counts() drops NaN by default, so percentages are relative to the
    # non-missing TARGET rows. rename_axis pins the index name: pandas >= 2.0
    # already names it "TARGET", but 1.x leaves it unnamed, which would break
    # the x="TARGET" lookup below.
    target_counts = df["TARGET"].value_counts().rename_axis("TARGET")
    target_percent = (target_counts / target_counts.sum() * 100).round(2)

    # Combine into a DataFrame for clarity
    target_df = target_counts.to_frame(name="Count")
    target_df["Percentage"] = target_percent

    fig, ax = plt.subplots(figsize=(8, 5))
    # Plot from a flat copy so "TARGET" is an ordinary column, and draw on the
    # Axes created above instead of relying on pyplot's "current axes".
    sns.barplot(
        data=target_df.reset_index(),
        x="TARGET",
        y="Count",
        hue="TARGET",
        palette=custom_palette[:2],
        ax=ax,
    )

    # Axis labels and formatting
    ax.set_xlabel("Payment Difficulties (1 = Yes, 0 = No)", fontsize=12)
    ax.set_ylabel("Count", fontsize=12)
    ax.grid(axis="y", linestyle="--", alpha=0.4)
    fig.tight_layout()

    return target_df, fig
def plot_credit_amounts(df: DataFrame) -> Figure:
    """
    Plot a histogram of credit amounts.

    The 'AMT_CREDIT' column is drawn as a 100-bin histogram with a KDE curve
    overlaid.

    Args:
        df (DataFrame): The DataFrame containing the 'AMT_CREDIT' column.

    Returns:
        Figure: The matplotlib figure object containing the plot.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    # Pass ax explicitly so the histogram lands on this figure regardless of
    # pyplot's global "current axes" state.
    sns.histplot(
        data=df,
        x="AMT_CREDIT",
        bins=100,
        kde=True,
        color=custom_palette[0],
        ax=ax,
    )
    ax.grid(axis="y", linestyle="--", alpha=0.5)
    fig.tight_layout()
    return fig
def plot_education_levels(df: DataFrame) -> tuple[DataFrame, Figure]:
    """
    Plot a bar chart of education levels.

    Args:
        df (DataFrame): The DataFrame containing the 'NAME_EDUCATION_TYPE'
            column.

    Returns:
        DataFrame: Education level counts ('Count') and percentages
            ('Percentage', 0-100, rounded to 2 decimals), sorted by
            descending count.
        Figure: The matplotlib figure object containing the plot.
    """
    # value_counts() drops NaN by default; sort explicitly by descending count.
    education_count = (
        df["NAME_EDUCATION_TYPE"].value_counts().sort_values(ascending=False)
    )
    # NOTE: the denominator is the total row count, including rows whose
    # education level is missing, so percentages sum to < 100 if NaNs exist.
    education_percentage = (education_count / df.shape[0] * 100).round(2)

    education_df = education_count.to_frame(name="Count")
    education_df["Percentage"] = education_percentage

    # Draw bars in the same descending-count order as the returned table;
    # without an explicit order, countplot uses first-appearance (or
    # categorical-dtype) order, so plot and table would disagree.
    level_order = list(education_count.index)

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.countplot(
        data=df,
        y="NAME_EDUCATION_TYPE",
        hue="NAME_EDUCATION_TYPE",
        order=level_order,
        hue_order=level_order,
        palette=custom_palette[:5],
        ax=ax,
    )
    ax.set_xlabel("Count")
    ax.set_ylabel("Education Level")
    ax.grid(axis="x", linestyle="--", alpha=0.5)
    fig.tight_layout()

    return education_df, fig
def plot_occupation(df: DataFrame) -> tuple[Series, Figure]:
    """
    Plot the distribution of occupations in the dataset.

    Missing occupations are counted (``dropna=False``) and appear in the plot
    as a bar labelled "Missing".

    Args:
        df (DataFrame): The DataFrame containing the 'OCCUPATION_TYPE' column.

    Returns:
        Series: Count of each occupation, sorted by descending count. Missing
            values are kept under a NaN index label.
        Figure: A Matplotlib Figure object containing the plot.
    """
    occupation_df = df["OCCUPATION_TYPE"].value_counts(dropna=False, ascending=False)

    # seaborn drops NaN categories, which would silently hide the missing-value
    # count that dropna=False deliberately kept. Give it a visible label for
    # plotting only; the returned Series keeps its NaN label. astype(object)
    # first so fillna also works when the column has a categorical dtype.
    # NOTE(review): collides if a real occupation is literally named "Missing".
    labels = occupation_df.index.astype(object).fillna("Missing")

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(
        x=occupation_df.values,
        y=labels,
        hue=labels,
        legend=False,
        ax=ax,
    )
    ax.set_xlabel("Number of Applicants")
    ax.set_ylabel("Occupation")
    ax.grid(axis="x", linestyle="--", alpha=0.5)
    fig.tight_layout()

    return occupation_df, fig
def plot_family_status(df: DataFrame) -> tuple[Series, Figure]:
    """
    Plot the distribution of family statuses in the dataset.

    Missing statuses are counted (``dropna=False``) and appear in the plot as
    a bar labelled "Missing".

    Args:
        df (DataFrame): The DataFrame containing the 'NAME_FAMILY_STATUS'
            column.

    Returns:
        Series: Count of each family status, sorted by descending count.
            Missing values are kept under a NaN index label.
        Figure: A Matplotlib Figure object containing the plot.
    """
    family_status_df = df["NAME_FAMILY_STATUS"].value_counts(
        dropna=False, ascending=False
    )

    # seaborn drops NaN categories, which would silently hide the missing-value
    # count that dropna=False deliberately kept. Give it a visible label for
    # plotting only; the returned Series keeps its NaN label. astype(object)
    # first so fillna also works when the column has a categorical dtype.
    # NOTE(review): collides if a real status is literally named "Missing".
    labels = family_status_df.index.astype(object).fillna("Missing")

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(
        x=family_status_df.values,
        y=labels,
        hue=labels,
        palette=custom_palette[:6],
        legend=False,
        ax=ax,
    )
    ax.set_xlabel("Number of Applicants")
    ax.set_ylabel("Family Status")
    ax.grid(axis="x", linestyle="--", alpha=0.5)
    fig.tight_layout()

    return family_status_df, fig
def plot_income_type(df: DataFrame) -> Figure:
    """
    Plot the count of income types for each target group.

    Args:
        df (DataFrame): The DataFrame containing the 'NAME_INCOME_TYPE' and
            'TARGET' columns.

    Returns:
        Figure: A Matplotlib Figure object containing the plot.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    # One horizontal bar group per income type, split by TARGET class; drawn
    # on the Axes created above rather than pyplot's implicit current axes.
    sns.countplot(
        data=df,
        y="NAME_INCOME_TYPE",
        hue="TARGET",
        palette=custom_palette[:2],
        ax=ax,
    )
    ax.set_xlabel("Number of Applicants")
    ax.set_ylabel("Income Type")
    ax.grid(axis="x", linestyle="--", alpha=0.5)
    fig.tight_layout()
    return fig