"""Source code for titanicprediction.data.analysis."""

from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Optional, Protocol

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from loguru import logger

from titanicprediction.entities.core import Dataset, TrainedModel


class PlotType(Enum):
    """Chart kinds that a plot configuration can request.

    Every member's value is the lowercase spelling of its own name.
    """

    HISTOGRAM = "histogram"
    BAR = "bar"
    SCATTER = "scatter"
    BOX = "box"
    HEATMAP = "heatmap"
    LINE = "line"
    PIE = "pie"
    KDE = "kde"
    VIOLIN = "violin"
    COUNT = "count"
class BackendType(Enum):
    """Names of plotting backends.

    Every member's value is the lowercase spelling of its own name.
    """

    MATPLOTLIB = "matplotlib"
    PLOTLY = "plotly"
    SEABORN = "seaborn"
@dataclass
class PlotConfig:
    """Settings describing what a single plot shows and how it is rendered.

    The first four fields are required; the remaining ones have defaults.
    """

    # Which kind of chart to draw.
    plot_type: PlotType
    # Text shown as the figure title.
    title: str
    # Axis labels. Some visualizers in this module also lower-case these and
    # use them as column names to look up in the dataset.
    x_label: str
    y_label: str
    # Figure size, forwarded to ``plt.subplots(figsize=...)``.
    figsize: tuple[int, int] = (10, 6)
    # Style name, forwarded to ``plt.style.use``.
    style: str = "seaborn-v0_8"
    # Optional palette. ``None`` (the default) means "not specified": each
    # plotting helper then falls back to its own default colour(s).
    colors: Optional[list[str]] = None
    # File extension used when the plot is written to disk.
    save_format: str = "png"
    # Resolution forwarded to ``Figure.savefig``.
    dpi: int = 300
    # Base font size; titles use ``font_size + 2`` and ticks ``font_size - 2``.
    font_size: int = 12
    # Whether plots that have labelled series should draw a legend.
    legend: bool = True
    # Whether the common styling step draws a light grid.
    grid: bool = True
@dataclass
class PlotResult:
    """Bundle returned by a visualizer: the drawn figure plus its context."""

    # The matplotlib figure that was created for the plot.
    figure: plt.Figure
    # The axes inside ``figure`` that the plot was drawn on.
    axes: plt.Axes
    # The configuration the plot was created from.
    config: PlotConfig
    # Free-form details recorded by the visualizer that built the plot.
    metadata: dict[str, Any]
class IVisualizer(Protocol):
    """Structural interface for objects that create and persist plots."""

    def create_plot(self, data: Any, plot_config: PlotConfig) -> PlotResult:
        """Build a plot of ``data`` as described by ``plot_config``."""
        ...

    def save_plot(self, plot: PlotResult, filename: str) -> Path:
        """Persist ``plot`` under ``filename`` and return the resulting path."""
        ...

    def get_supported_formats(self) -> list[str]:
        """Return the output format names this visualizer supports."""
        ...
class DistributionVisualizer:
    """Draw single-variable distribution plots: histogram, KDE or violin."""

    def create_plot(self, data: Any, plot_config: PlotConfig) -> PlotResult:
        """Create a distribution plot of ``data``.

        Args:
            data: Values to plot. The drawing helpers call ``dropna()`` on it,
                so a pandas Series is expected.
            plot_config: Plot settings. ``plot_type`` must be ``HISTOGRAM``,
                ``KDE`` or ``VIOLIN`` for anything to be drawn.

        Returns:
            A ``PlotResult`` holding the new figure and axes, the config, and
            metadata (number of data points, data type name, creation time).
        """
        # NOTE: this changes matplotlib's global style, not just this figure.
        plt.style.use(plot_config.style)
        fig, ax = plt.subplots(figsize=plot_config.figsize)

        renderers = {
            PlotType.HISTOGRAM: self._create_histogram,
            PlotType.KDE: self._create_kde,
            PlotType.VIOLIN: self._create_violin,
        }
        renderer = renderers.get(plot_config.plot_type)
        if renderer is None:
            # Keep returning the (empty) styled figure, but say why it is empty.
            logger.warning(
                f"DistributionVisualizer does not support plot type "
                f"{plot_config.plot_type}; returning an empty plot"
            )
        else:
            renderer(ax, data, plot_config)

        self._apply_common_styling(ax, plot_config)

        metadata = {
            "data_points": len(data) if hasattr(data, "__len__") else None,
            "data_type": type(data).__name__,
            "created_at": pd.Timestamp.now().isoformat(),
        }
        return PlotResult(figure=fig, axes=ax, config=plot_config, metadata=metadata)

    def _create_histogram(
        self, ax: plt.Axes, data: pd.Series, config: PlotConfig
    ) -> None:
        """Draw a 30-bin histogram of the non-missing values of ``data``."""
        color = config.colors[0] if config.colors else "skyblue"
        ax.hist(data.dropna(), bins=30, alpha=0.7, color=color, edgecolor="black")
        ax.set_xlabel(config.x_label, fontsize=config.font_size)
        ax.set_ylabel("Frequency", fontsize=config.font_size)

    def _create_kde(self, ax: plt.Axes, data: pd.Series, config: PlotConfig) -> None:
        """Draw a kernel density estimate of the non-missing values of ``data``."""
        color = config.colors[0] if config.colors else "red"
        data_clean = data.dropna()
        data_clean.plot.kde(ax=ax, color=color, linewidth=2)
        ax.set_xlabel(config.x_label, fontsize=config.font_size)
        ax.set_ylabel("Density", fontsize=config.font_size)

    def _create_violin(self, ax: plt.Axes, data: pd.Series, config: PlotConfig) -> None:
        """Draw a violin plot of the non-missing values of ``data``."""
        color = config.colors[0] if config.colors else "lightgreen"
        parts = ax.violinplot(data.dropna(), showmeans=True, showmedians=True)
        # violinplot takes no colour argument; the fill has to be set on the
        # returned body collections.
        for body in parts["bodies"]:
            body.set_facecolor(color)
        ax.set_ylabel(config.y_label, fontsize=config.font_size)

    def _apply_common_styling(self, ax: plt.Axes, config: PlotConfig) -> None:
        """Apply the title, optional grid and tick label size from ``config``."""
        ax.set_title(config.title, fontsize=config.font_size + 2, fontweight="bold")
        if config.grid:
            ax.grid(True, alpha=0.3)
        ax.tick_params(labelsize=config.font_size - 2)
class CorrelationVisualizer:
    """Render a correlation matrix as an annotated heatmap."""

    def create_plot(self, data: pd.DataFrame, plot_config: PlotConfig) -> PlotResult:
        """Create a plot for the correlation matrix ``data``.

        A heatmap is drawn only when ``plot_config.plot_type`` is ``HEATMAP``;
        for any other type the figure receives the common styling only.

        Returns:
            A ``PlotResult`` whose metadata records the matrix shape, its
            column names and the creation time.
        """
        plt.style.use(plot_config.style)
        figure, axes = plt.subplots(figsize=plot_config.figsize)

        wants_heatmap = plot_config.plot_type == PlotType.HEATMAP
        if wants_heatmap:
            self._create_correlation_heatmap(axes, data, plot_config)
        self._apply_common_styling(axes, plot_config)

        details = {
            "correlation_matrix_shape": data.shape,
            "features": list(data.columns),
            "created_at": pd.Timestamp.now().isoformat(),
        }
        return PlotResult(
            figure=figure, axes=axes, config=plot_config, metadata=details
        )

    def _create_correlation_heatmap(
        self, ax: plt.Axes, corr_matrix: pd.DataFrame, config: PlotConfig
    ) -> None:
        """Draw ``corr_matrix`` on ``ax`` with the upper triangle hidden."""
        # True on and above the diagonal; seaborn leaves masked cells blank.
        upper_triangle = np.triu(np.ones_like(corr_matrix, dtype=bool))
        heatmap_options = {
            "mask": upper_triangle,
            "annot": True,
            "fmt": ".2f",
            "cmap": "coolwarm",
            "center": 0,
            "square": True,
            "ax": ax,
            "cbar_kws": {"shrink": 0.8},
        }
        sns.heatmap(corr_matrix, **heatmap_options)
        title_size = config.font_size + 2
        ax.set_title(config.title, fontsize=title_size, fontweight="bold")

    def _apply_common_styling(self, ax: plt.Axes, config: PlotConfig) -> None:
        """Set the title, the optional grid and the tick label size."""
        base_size = config.font_size
        ax.set_title(config.title, fontsize=base_size + 2, fontweight="bold")
        if config.grid:
            ax.grid(True, alpha=0.3)
        ax.tick_params(labelsize=base_size - 2)
class TrainingVisualizer:
    """Plot the loss recorded during model training."""

    def create_plot(
        self, loss_history: list[float], plot_config: PlotConfig
    ) -> PlotResult:
        """Create a training-curve plot from ``loss_history``.

        Args:
            loss_history: One loss value per epoch, in training order.
            plot_config: Plot settings. ``plot_type`` must be ``LINE`` for the
                curve to be drawn.

        Returns:
            A ``PlotResult`` whose metadata records the number of epochs, the
            final and minimum loss (``None`` when the history is empty) and
            the creation time.
        """
        # NOTE: this changes matplotlib's global style, not just this figure.
        plt.style.use(plot_config.style)
        fig, ax = plt.subplots(figsize=plot_config.figsize)

        if plot_config.plot_type == PlotType.LINE:
            self._create_training_curve(ax, loss_history, plot_config)
        else:
            # Keep returning the (empty) styled figure, but say why it is empty.
            logger.warning(
                f"TrainingVisualizer does not support plot type "
                f"{plot_config.plot_type}; returning an empty plot"
            )

        self._apply_common_styling(ax, plot_config)

        metadata = {
            "training_epochs": len(loss_history),
            "final_loss": loss_history[-1] if loss_history else None,
            "min_loss": min(loss_history) if loss_history else None,
            "created_at": pd.Timestamp.now().isoformat(),
        }
        return PlotResult(figure=fig, axes=ax, config=plot_config, metadata=metadata)

    def _create_training_curve(
        self, ax: plt.Axes, loss_history: list[float], config: PlotConfig
    ) -> None:
        """Draw loss against epoch index, on a log y-axis when that is valid."""
        color = config.colors[0] if config.colors else "blue"
        ax.plot(range(len(loss_history)), loss_history, color=color, linewidth=2)
        ax.set_xlabel("Epoch", fontsize=config.font_size)
        ax.set_ylabel("Loss", fontsize=config.font_size)
        # A log axis cannot represent zero or negative values, so only use it
        # when every recorded loss is strictly positive.
        if loss_history and min(loss_history) > 0:
            ax.set_yscale("log")

    def _apply_common_styling(self, ax: plt.Axes, config: PlotConfig) -> None:
        """Apply the title, optional grid and tick label size from ``config``.

        ``create_plot`` calls this method; without it every call to
        ``create_plot`` fails with ``AttributeError``.
        """
        ax.set_title(config.title, fontsize=config.font_size + 2, fontweight="bold")
        if config.grid:
            ax.grid(True, alpha=0.3)
        ax.tick_params(labelsize=config.font_size - 2)
class FeatureAnalysisVisualizer:
    """Draw per-feature plots (bar, box, scatter, pie, count) for a dataset.

    The drawing helpers lower-case ``x_label`` / ``y_label`` from the plot
    configuration and use the result as column names in ``dataset.features``.
    When a required column is missing, nothing is drawn and a warning is
    logged.
    """

    def create_plot(self, dataset: Dataset, plot_config: PlotConfig) -> PlotResult:
        """Create the plot requested by ``plot_config`` for ``dataset``.

        For ``BAR`` plots a survival-rate chart is drawn when the dataset has
        a target and a plain value-count chart otherwise.

        Returns:
            A ``PlotResult`` whose metadata records the feature frame shape,
            the first five feature names and the creation time.
        """
        # NOTE: this changes matplotlib's global style, not just this figure.
        plt.style.use(plot_config.style)
        fig, ax = plt.subplots(figsize=plot_config.figsize)

        plot_type = plot_config.plot_type
        if plot_type == PlotType.BAR:
            # Decide on the data, not on the class name: the previous check
            # looked for "target" in str(type(dataset)), which is never true
            # for a class called Dataset.
            if dataset.target is not None:
                self._create_survival_by_feature(ax, dataset, plot_config)
            else:
                self._create_feature_distribution(ax, dataset, plot_config)
        elif plot_type == PlotType.BOX:
            self._create_box_plot(ax, dataset, plot_config)
        elif plot_type == PlotType.SCATTER:
            self._create_scatter_plot(ax, dataset, plot_config)
        elif plot_type == PlotType.PIE:
            self._create_pie_chart(ax, dataset, plot_config)
        elif plot_type == PlotType.COUNT:
            self._create_count_plot(ax, dataset, plot_config)
        else:
            logger.warning(
                f"FeatureAnalysisVisualizer does not support plot type "
                f"{plot_type}; returning an empty plot"
            )

        self._apply_common_styling(ax, plot_config)

        metadata = {
            "dataset_shape": dataset.features.shape,
            "feature_names": dataset.feature_names[:5],
            "created_at": pd.Timestamp.now().isoformat(),
        }
        return PlotResult(figure=fig, axes=ax, config=plot_config, metadata=metadata)

    def _has_columns(self, dataset: Dataset, *names: str) -> bool:
        """Return True if every name is a feature column; warn otherwise."""
        missing = [name for name in names if name not in dataset.features.columns]
        if missing:
            logger.warning(
                f"Skipping plot: columns {missing} not found in dataset features"
            )
            return False
        return True

    def _create_survival_by_feature(
        self, ax: plt.Axes, dataset: Dataset, config: PlotConfig
    ) -> None:
        """Draw the mean target value (as a percentage) per feature category."""
        feature_name = config.x_label.lower()
        if not self._has_columns(dataset, feature_name):
            return

        # Prefer the target column stored with the features; otherwise fall
        # back to the separate target.
        # NOTE(review): the fallback assumes dataset.target is a pandas Series
        # sharing the features' index - confirm against the Dataset entity.
        if dataset.target_name in dataset.features.columns:
            outcome = dataset.features[dataset.target_name]
        else:
            outcome = dataset.target
        survival_by_feature = outcome.groupby(dataset.features[feature_name]).agg(
            ["mean", "count"]
        )
        rates = survival_by_feature["mean"] * 100

        colors = config.colors if config.colors else ["#ff9999", "#66b3ff"]
        x_pos = np.arange(len(survival_by_feature))
        ax.bar(x_pos, rates, color=colors[0], alpha=0.7, edgecolor="black")
        ax.set_xlabel(config.x_label, fontsize=config.font_size)
        ax.set_ylabel("Survival Rate (%)", fontsize=config.font_size)
        ax.set_xticks(x_pos)
        ax.set_xticklabels([str(x) for x in survival_by_feature.index])

        # Write each bar's percentage just above it.
        for i, rate in enumerate(rates):
            ax.text(
                i,
                rate + 1,
                f"{rate:.1f}%",
                ha="center",
                va="bottom",
                fontsize=config.font_size - 2,
            )

    def _create_feature_distribution(
        self, ax: plt.Axes, dataset: Dataset, config: PlotConfig
    ) -> None:
        """Draw how often each distinct value of the feature occurs."""
        feature_name = config.x_label.lower()
        if not self._has_columns(dataset, feature_name):
            return

        feature_data = dataset.features[feature_name].value_counts().sort_index()
        colors = (
            config.colors
            if config.colors
            else plt.cm.Set3(np.linspace(0, 1, len(feature_data)))
        )
        positions = range(len(feature_data))
        ax.bar(
            positions,
            feature_data.values,
            color=colors,
            alpha=0.7,
            edgecolor="black",
        )
        ax.set_xlabel(config.x_label, fontsize=config.font_size)
        ax.set_ylabel("Count", fontsize=config.font_size)
        ax.set_xticks(positions)
        ax.set_xticklabels([str(x) for x in feature_data.index], rotation=45)

    def _create_box_plot(
        self, ax: plt.Axes, dataset: Dataset, config: PlotConfig
    ) -> None:
        """Draw one box of the x-label feature per value of the y-label column."""
        feature_name = config.x_label.lower()
        target_name = config.y_label.lower()
        if not self._has_columns(dataset, feature_name, target_name):
            return

        features = dataset.features
        # Sort once and use the same order for both the boxes and their
        # labels; previously the boxes were sorted but the labels were not.
        groups = sorted(features[target_name].dropna().unique())
        if not groups:
            return
        data_to_plot = [
            features[features[target_name] == group][feature_name].dropna()
            for group in groups
        ]

        ax.boxplot(data_to_plot)
        # Label the ticks separately instead of passing ``labels=`` to
        # boxplot, since that keyword was renamed in newer matplotlib.
        ax.set_xticklabels([str(group) for group in groups])
        ax.set_xlabel(target_name, fontsize=config.font_size)
        ax.set_ylabel(feature_name, fontsize=config.font_size)

    def _create_scatter_plot(
        self, ax: plt.Axes, dataset: Dataset, config: PlotConfig
    ) -> None:
        """Scatter two features, coloured by target value when one exists."""
        x_feature = config.x_label.lower()
        y_feature = config.y_label.lower()
        if not self._has_columns(dataset, x_feature, y_feature):
            return

        colors = config.colors if config.colors else ["red", "blue"]
        if dataset.target is not None:
            # The target is treated as binary with values 0 and 1.
            for target_val in [0, 1]:
                subset = dataset.features[dataset.target == target_val]
                ax.scatter(
                    subset[x_feature],
                    subset[y_feature],
                    # Wrap around so a one-colour palette does not raise.
                    c=colors[target_val % len(colors)],
                    alpha=0.6,
                    label=f"Survived={target_val}",
                )
            if config.legend:
                ax.legend()
        else:
            ax.scatter(
                dataset.features[x_feature], dataset.features[y_feature], alpha=0.6
            )

        ax.set_xlabel(x_feature, fontsize=config.font_size)
        ax.set_ylabel(y_feature, fontsize=config.font_size)

    def _create_pie_chart(
        self, ax: plt.Axes, dataset: Dataset, config: PlotConfig
    ) -> None:
        """Draw the share of each distinct value of the feature as a pie."""
        feature_name = config.x_label.lower()
        if not self._has_columns(dataset, feature_name):
            return

        value_counts = dataset.features[feature_name].value_counts()
        colors = (
            config.colors
            if config.colors
            else plt.cm.Pastel1(range(len(value_counts)))
        )
        _wedges, _texts, autotexts = ax.pie(
            value_counts.values,
            labels=value_counts.index,
            autopct="%1.1f%%",
            colors=colors,
            startangle=90,
        )
        for autotext in autotexts:
            autotext.set_color("black")
            autotext.set_fontsize(config.font_size - 2)

    def _create_count_plot(
        self, ax: plt.Axes, dataset: Dataset, config: PlotConfig
    ) -> None:
        """Draw grouped bars counting feature values per target class."""
        feature_name = config.x_label.lower()
        if dataset.target is None or not self._has_columns(dataset, feature_name):
            return

        cross_tab = pd.crosstab(dataset.features[feature_name], dataset.target)
        target_values = list(cross_tab.columns)
        if not target_values:
            return

        colors = config.colors if config.colors else ["#ff6b6b", "#4ecdc4"]
        # Iterate over the classes actually present rather than assuming both
        # 0 and 1 exist. For two classes the width is 0.35.
        bar_width = 0.7 / len(target_values)
        x_pos = np.arange(len(cross_tab))

        for i, target_val in enumerate(target_values):
            ax.bar(
                x_pos + i * bar_width,
                cross_tab[target_val],
                bar_width,
                label=f"Survived={target_val}",
                color=colors[i % len(colors)],
                alpha=0.7,
            )

        ax.set_xlabel(feature_name, fontsize=config.font_size)
        ax.set_ylabel("Count", fontsize=config.font_size)
        # Centre each tick under its group of bars.
        ax.set_xticks(x_pos + bar_width * (len(target_values) - 1) / 2)
        ax.set_xticklabels([str(x) for x in cross_tab.index])
        if config.legend:
            ax.legend()

    def _apply_common_styling(self, ax: plt.Axes, config: PlotConfig) -> None:
        """Apply the title, optional grid and tick label size from ``config``."""
        ax.set_title(config.title, fontsize=config.font_size + 2, fontweight="bold")
        if config.grid:
            ax.grid(True, alpha=0.3)
        ax.tick_params(labelsize=config.font_size - 2)
class EDAVisualizer:
    """Build the project's exploratory plots and write them to a directory.

    Every ``create_*`` method returns a dict mapping a plot key to the path
    of the file that was written.
    """

    # Display names for the two target values used in the survival pie chart.
    _SURVIVAL_LABELS = {0: "Did Not Survive", 1: "Survived"}

    def __init__(self, output_dir: str = "data/plots"):
        """Create the output directory (if needed) and the helper visualizers.

        Args:
            output_dir: Directory that saved plots are written to.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.distribution_visualizer = DistributionVisualizer()
        self.correlation_visualizer = CorrelationVisualizer()
        self.training_visualizer = TrainingVisualizer()
        self.feature_visualizer = FeatureAnalysisVisualizer()

    def create_survival_analysis_plots(self, dataset: Dataset) -> dict[str, Path]:
        """Save the survival pie chart and the survival-rate-by-class chart.

        The pie chart needs ``dataset.target``; the class chart needs a
        ``Pclass`` feature column and a target. Plots whose inputs are missing
        are skipped.
        """
        plots: dict[str, Path] = {}
        plot_df = dataset.features.copy()

        if dataset.target is not None:
            plot_df["Survived"] = dataset.target
            plots["survival_distribution"] = self._save_survival_pie(dataset.target)

        if "Pclass" in plot_df.columns and "Survived" in plot_df.columns:
            config = PlotConfig(
                plot_type=PlotType.BAR,
                title="Survival Rate by Passenger Class",
                x_label="Pclass",
                y_label="Survival Rate (%)",
                colors=["#ff9999", "#66b3ff", "#99ff99"],
            )
            survival_by_class = plot_df.groupby("Pclass")["Survived"].mean() * 100
            plots["survival_by_class"] = self._save_rate_bar_chart(
                survival_by_class,
                config,
                axis_label="Passenger Class",
                tick_labels=[f"Class {cls}" for cls in survival_by_class.index],
                filename="survival_by_class",
            )

        return plots

    def create_demographic_plots(self, dataset: Dataset) -> dict[str, Path]:
        """Save the survival-rate-by-gender chart and the age histogram.

        The gender chart needs a ``Sex`` column and a target; the histogram
        needs an ``Age`` column. Plots whose inputs are missing are skipped.
        """
        plots: dict[str, Path] = {}
        plot_df = dataset.features.copy()
        if dataset.target is not None:
            plot_df["Survived"] = dataset.target

        if "Sex" in plot_df.columns and "Survived" in plot_df.columns:
            config = PlotConfig(
                plot_type=PlotType.BAR,
                title="Survival Rate by Gender",
                x_label="Sex",
                y_label="Survival Rate (%)",
                colors=["#ff6b6b", "#4ecdc4"],
            )
            survival_by_sex = plot_df.groupby("Sex")["Survived"].mean() * 100
            plots["survival_by_gender"] = self._save_rate_bar_chart(
                survival_by_sex,
                config,
                axis_label="Gender",
                # str() keeps this working when the column is numerically
                # encoded rather than holding strings.
                tick_labels=[str(sex).capitalize() for sex in survival_by_sex.index],
                filename="survival_by_gender",
            )

        if "Age" in plot_df.columns:
            config = PlotConfig(
                plot_type=PlotType.HISTOGRAM,
                title="Age Distribution of Passengers",
                x_label="Age",
                y_label="Frequency",
                colors=["skyblue"],
            )
            plot_result = self.distribution_visualizer.create_plot(
                plot_df["Age"].dropna(), config
            )
            plots["age_distribution"] = self._save_plot(plot_result, "age_distribution")

        return plots

    def create_correlation_analysis(self, dataset: Dataset) -> dict[str, Path]:
        """Save a heatmap of the correlations between numeric features.

        Returns an empty dict (and logs a warning) when the dataset has no
        numeric feature columns.
        """
        plots: dict[str, Path] = {}

        numeric_features = dataset.features.select_dtypes(include=[np.number])
        if numeric_features.shape[1] == 0:
            # There is nothing to correlate, so skip the heatmap entirely.
            logger.warning("No numeric features found; skipping correlation plot")
            return plots
        correlation_matrix = numeric_features.corr()

        config = PlotConfig(
            plot_type=PlotType.HEATMAP,
            title="Correlation Matrix of Numerical Features",
            x_label="Features",
            y_label="Features",
            figsize=(10, 8),
        )
        plot_result = self.correlation_visualizer.create_plot(
            correlation_matrix, config
        )
        plots["correlation_matrix"] = self._save_plot(plot_result, "correlation_matrix")

        return plots

    def create_training_plots(self, loss_history: list[float]) -> dict[str, Path]:
        """Save the training loss curve for ``loss_history``."""
        plots: dict[str, Path] = {}

        config = PlotConfig(
            plot_type=PlotType.LINE,
            title="Model Training Loss Over Epochs",
            x_label="Epoch",
            y_label="Loss",
            colors=["blue"],
        )
        plot_result = self.training_visualizer.create_plot(loss_history, config)
        plots["training_curve"] = self._save_plot(plot_result, "training_curve")

        return plots

    def create_feature_importance_plot(self, model: TrainedModel) -> dict[str, Path]:
        """Save a horizontal bar chart of absolute model weight per feature.

        Importance is ``abs(weight)``. Bars are sorted so the most important
        feature is at the top.
        """
        plots: dict[str, Path] = {}

        importance_df = pd.DataFrame(
            {
                "feature": model.feature_names,
                # ravel() leaves a 1-D weight vector unchanged and flattens a
                # column vector.
                # NOTE(review): assumes one weight per feature name - confirm.
                "importance": np.abs(np.ravel(model.weights)),
            }
        ).sort_values("importance", ascending=True)

        config = PlotConfig(
            plot_type=PlotType.BAR,
            title="Feature Importance",
            x_label="Importance",
            y_label="Features",
            figsize=(10, 8),
            colors=["lightsteelblue"],
        )

        # Draw the bars here: the generic BAR path of the feature visualizer
        # plots value counts, which would show how often each importance
        # value occurs rather than the importance of each feature.
        plt.style.use(config.style)
        fig, ax = plt.subplots(figsize=config.figsize)
        positions = np.arange(len(importance_df))
        ax.barh(
            positions,
            importance_df["importance"].to_numpy(),
            color=config.colors[0],
            alpha=0.7,
            edgecolor="black",
        )
        ax.set_yticks(positions)
        ax.set_yticklabels([str(name) for name in importance_df["feature"]])
        ax.set_xlabel(config.x_label, fontsize=config.font_size)
        ax.set_ylabel(config.y_label, fontsize=config.font_size)
        ax.set_title(config.title, fontsize=config.font_size + 2, fontweight="bold")
        if config.grid:
            ax.grid(True, alpha=0.3)
        ax.tick_params(labelsize=config.font_size - 2)

        plot_result = PlotResult(
            figure=fig,
            axes=ax,
            config=config,
            metadata={"n_features": len(importance_df)},
        )
        plots["feature_importance"] = self._save_plot(plot_result, "feature_importance")

        return plots

    def _save_survival_pie(self, target: pd.Series) -> Path:
        """Draw and save the pie chart of target class shares."""
        config = PlotConfig(
            plot_type=PlotType.PIE,
            title="Survival Distribution on Titanic",
            x_label="Survived",
            y_label="Percentage",
            figsize=(8, 8),
            colors=["#ff6b6b", "#4ecdc4"],
        )
        fig, ax = plt.subplots(figsize=config.figsize)

        # value_counts() orders by frequency, so derive each wedge's label and
        # colour from its class value instead of from its position.
        survival_counts = target.value_counts().sort_index()
        color_by_class = dict(zip(self._SURVIVAL_LABELS, config.colors))
        labels = [
            self._SURVIVAL_LABELS.get(value, str(value))
            for value in survival_counts.index
        ]
        colors = [
            color_by_class.get(value, "lightgray") for value in survival_counts.index
        ]

        ax.pie(
            survival_counts.values,
            labels=labels,
            autopct="%1.1f%%",
            colors=colors,
            startangle=90,
        )
        ax.set_title(config.title, fontsize=config.font_size + 2, fontweight="bold")

        plot_result = PlotResult(figure=fig, axes=ax, config=config, metadata={})
        return self._save_plot(plot_result, "survival_distribution")

    def _save_rate_bar_chart(
        self,
        rates: pd.Series,
        config: PlotConfig,
        axis_label: str,
        tick_labels: list[str],
        filename: str,
    ) -> Path:
        """Draw one percentage bar per entry of ``rates`` and save the figure."""
        fig, ax = plt.subplots(figsize=config.figsize)
        positions = range(len(rates))

        ax.bar(
            positions,
            rates.values,
            color=config.colors[: len(rates)],
            alpha=0.7,
            edgecolor="black",
        )
        ax.set_xlabel(axis_label, fontsize=config.font_size)
        ax.set_ylabel(config.y_label, fontsize=config.font_size)
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels)
        ax.set_title(config.title, fontsize=config.font_size + 2, fontweight="bold")

        # Write each bar's percentage just above it.
        for position, value in enumerate(rates.values):
            ax.text(
                position,
                value + 1,
                f"{value:.1f}%",
                ha="center",
                va="bottom",
                fontsize=config.font_size - 2,
            )

        plot_result = PlotResult(figure=fig, axes=ax, config=config, metadata={})
        return self._save_plot(plot_result, filename)

    def _save_plot(self, plot_result: PlotResult, filename: str) -> Path:
        """Write the figure to the output directory and close it.

        Args:
            plot_result: Plot to save; its config supplies format and dpi.
            filename: File name without extension.

        Returns:
            Path of the written file.
        """
        filepath = self.output_dir / f"{filename}.{plot_result.config.save_format}"
        try:
            plot_result.figure.savefig(
                filepath,
                dpi=plot_result.config.dpi,
                bbox_inches="tight",
                facecolor="white",
            )
        finally:
            # Release the figure even when saving fails, so failed saves do
            # not accumulate open figures.
            plt.close(plot_result.figure)

        logger.info(f"Plot saved: {filepath}")
        return filepath

    def get_supported_formats(self) -> list[str]:
        """Return the file format names this visualizer reports as supported."""
        return ["png", "jpg", "jpeg", "svg", "pdf"]