# Source code for titanicprediction.data.repositories
import pickle
from pathlib import Path
from typing import Any, Protocol
import pandas as pd
from loguru import logger
from titanicprediction.entities.core import Dataset, TrainedModel
class IDataRepository(Protocol):
    """Structural interface for dataset persistence backends."""

    def load_data(self) -> "Dataset":
        """Load and return the dataset from the underlying storage."""
        ...

    def save_data(self, dataset: "Dataset") -> bool:
        """Persist *dataset*; return True on success, False otherwise."""
        ...
class IModelRepository(Protocol):
    """Structural interface for trained-model persistence backends."""

    def save_model(self, model: "TrainedModel", name: str) -> bool:
        """Persist *model* under *name*; return True on success."""
        ...

    def load_model(self, name: str) -> "TrainedModel | None":
        """Return the model stored under *name*, or None if unavailable."""
        ...

    def list_models(self) -> list[str]:
        """Return the names of all stored models."""
        ...
class CSVDataRepository:
    """CSV-file-backed implementation of the data repository interface.

    Reads/writes a single CSV file where one column (``target_column``)
    holds the prediction target and all remaining columns are features.
    """

    def __init__(self, file_path: str, target_column: str = "Survived"):
        """
        Args:
            file_path: Path of the CSV file used for both reads and writes.
            target_column: Name of the column holding the target variable.
        """
        self.file_path = Path(file_path)
        self.target_column = target_column

    def load_data(self) -> "Dataset":
        """Read the CSV file and split it into features and target.

        Returns:
            A ``Dataset`` whose ``features`` are every column except
            ``target_column`` and whose ``target`` is that column.

        Raises:
            ValueError: If ``target_column`` is not present in the file.
            Exception: Any read/parse error is logged and re-raised.
        """
        try:
            df = pd.read_csv(self.file_path)
            if self.target_column not in df.columns:
                raise ValueError(
                    f"Target column '{self.target_column}' not found")
            features = df.drop(columns=[self.target_column])
            target = df[self.target_column]
            return Dataset(
                features=features,
                target=target,
                feature_names=list(features.columns),
                target_name=self.target_column,
            )
        except Exception as e:
            logger.error(f"Error when loading data from {self.file_path}: {e}")
            raise

    def save_data(self, dataset: "Dataset") -> bool:
        """Write the dataset (features plus optional target column) to CSV.

        Args:
            dataset: Dataset to persist; if ``dataset.target`` is not None
                it is appended as a column named ``dataset.target_name``.

        Returns:
            True on success, False if any error occurred (error is logged).
        """
        try:
            # Copy so we never mutate the caller's feature frame.
            df = dataset.features.copy()
            if dataset.target is not None:
                df[dataset.target_name] = dataset.target
            # Ensure the destination directory exists before writing.
            self.file_path.parent.mkdir(parents=True, exist_ok=True)
            df.to_csv(self.file_path, index=False)
            logger.info(f"Data saved in {self.file_path}")
            return True
        except Exception as e:
            logger.error(f"Error when saving data to {self.file_path}: {e}")
            return False
[docs]
class FileModelRepository:
[docs]
def __init__(self, models_dir: str = "models"):
self.models_dir = Path(models_dir)
self.models_dir.mkdir(exist_ok=True)
[docs]
def save_model(self, model: TrainedModel, name: str) -> bool:
try:
model_file = self.models_dir / f"{name}.pkl"
with open(model_file, "wb") as f:
pickle.dump(model, f)
logger.info(f"Model {name} saved in {model_file}")
return True
except Exception as e:
logger.error(f"Error when saving model {name}: {e}")
return False
[docs]
def load_model(self, name: str) -> TrainedModel | None:
try:
model_file = self.models_dir / f"{name}.pkl"
if not model_file.exists():
logger.warning(f"File of model {model_file} not found")
return None
with open(model_file, "rb") as f:
model = pickle.load(f)
logger.info(f"Model {name} loadded")
return model
except Exception as e:
logger.error(f"Error when loading model {name}: {e}")
return None
[docs]
def list_models(self) -> list[str]:
try:
model_files = list(self.models_dir.glob("*.pkl"))
return [f.stem for f in model_files]
except Exception as e:
logger.error(f"Error when getting list of models: {e}")
return []
[docs]
def get_model_info(self, name: str) -> dict[str, Any] | None:
model = self.load_model(name)
if model is None:
return None
return {
"name": name,
"feature_names": model.feature_names,
"metrics": model.metrics,
"weights_shape": model.weights.shape if hasattr(model, "weights") else None,
}
[docs]
def delete_model(self, name: str) -> bool:
try:
model_file = self.models_dir / f"{name}.pkl"
if model_file.exists():
model_file.unlink()
logger.info(f"Model {name} deleted")
return True
logger.warning(f"Model {name} not found for deleting")
return False
except Exception as e:
logger.error(f"Error when deleting model {name}: {e}")
return False