from abc import ABC, abstractmethod
from typing import Any
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy import func, insert, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from orghandbookapi.database.models.activity import Activity
from orghandbookapi.database.models.base import Base
from orghandbookapi.database.models.building import Building
from orghandbookapi.database.models.organization import (
Organization,
PhoneNumber,
organization_activity,
)
[docs]
async def commit_process_session( # noqa: D103
session: AsyncSession, candidate: Any = None, *, flush: bool = False
):
if flush:
await session.flush()
await session.commit()
if candidate is not None:
await session.refresh(candidate)
[docs]
class CRUDRepository(ABC): # noqa: D101
[docs]
@classmethod
@abstractmethod
async def get(cls, session: AsyncSession, id: int) -> Base | None: # noqa: A002, D102
raise NotImplementedError
[docs]
@classmethod
@abstractmethod
async def get_with_relations(cls, session: AsyncSession, id: int) -> Base | None: # noqa: A002, D102
raise NotImplementedError
[docs]
@classmethod
@abstractmethod
async def delete(cls, session: AsyncSession, id: int): # noqa: A002, D102
raise NotImplementedError
[docs]
@classmethod
@abstractmethod
async def update(cls, session: AsyncSession, model: BaseModel): # noqa: D102
raise NotImplementedError
[docs]
@classmethod
@abstractmethod
async def create(cls, session: AsyncSession, model: BaseModel) -> Base: # noqa: D102
raise NotImplementedError
[docs]
@classmethod
@abstractmethod
async def get_all(cls, session: AsyncSession) -> list[Base]: # noqa: D102
raise NotImplementedError
[docs]
class OrganizationRepository(CRUDRepository): # noqa: D101
[docs]
@classmethod
async def get(cls, session: AsyncSession, id: int) -> Base | None: # noqa: A002, D102
result = await session.execute(
select(Organization).where(Organization.id == id)
)
return result.scalar_one_or_none()
[docs]
@classmethod
async def get_with_relations( # noqa: D102
cls,
session: AsyncSession,
id: int, # noqa: A002
) -> Organization | None:
result = await session.execute(
select(Organization)
.options(
selectinload(Organization.building),
selectinload(Organization.phonenumbers),
selectinload(Organization.activities),
)
.where(Organization.id == id)
)
return result.scalar_one_or_none()
[docs]
@classmethod
async def delete(cls, session: AsyncSession, id: int): # noqa: A002, D102
organization = await cls.get(session, id)
if organization:
await session.delete(organization)
await commit_process_session(session)
[docs]
@classmethod
async def update(cls, session: AsyncSession, model: BaseModel) -> Organization: # noqa: D102
update_data = model.dict(exclude_unset=True)
await session.execute(
update(Organization)
.where(Organization.id == update_data["id"])
.values(**{k: v for k, v in update_data.items() if k != "id"})
)
await commit_process_session(session)
return await cls.get(session, update_data["id"])
[docs]
@classmethod
async def create(cls, session: AsyncSession, model: BaseModel) -> Organization: # noqa: D102
organization_data = model.dict(exclude={"phone_numbers", "activity_ids"})
organization = Organization(**organization_data)
session.add(organization)
await session.flush()
for phone in model.phone_numbers:
phone_obj = PhoneNumber(phone_number=phone, organization_id=organization.id)
session.add(phone_obj)
if model.activity_ids:
stmt = insert(organization_activity).values(
[
{"organization_id": organization.id, "activity_id": activity_id}
for activity_id in model.activity_ids
]
)
await session.execute(stmt)
await commit_process_session(session, organization)
return organization
[docs]
@classmethod
async def get_all(cls, session: AsyncSession) -> list[Organization]: # noqa: D102
result = await session.execute(select(Organization))
return result.scalars().all()
[docs]
@classmethod
async def get_by_building( # noqa: D102
cls, session: AsyncSession, building_id: int
) -> list[Organization]:
result = await session.execute(
select(Organization)
.options(
selectinload(Organization.building),
selectinload(Organization.phonenumbers),
selectinload(Organization.activities),
)
.where(Organization.building_id == building_id)
)
return result.scalars().all()
[docs]
@classmethod
async def get_by_activity( # noqa: D102
cls, session: AsyncSession, activity_id: int
) -> list[Organization]:
activities_cte = (
select(Activity.id, Activity.parent_id, Activity.level)
.where(Activity.id == activity_id)
.cte(name="activities_cte", recursive=True)
)
recursive_select = (
select(Activity.id, Activity.parent_id, Activity.level)
.join(activities_cte, Activity.parent_id == activities_cte.c.id)
.where(Activity.level < 3)
)
activities_cte = activities_cte.union_all(recursive_select)
result = await session.execute(
select(Organization)
.options(
selectinload(Organization.building),
selectinload(Organization.phonenumbers),
selectinload(Organization.activities),
)
.join(Organization.activities)
.where(Activity.id.in_(select(activities_cte.c.id)))
)
return result.scalars().all()
[docs]
@classmethod
async def search_by_name( # noqa: D102
cls, session: AsyncSession, name: str
) -> list[Organization]:
result = await session.execute(
select(Organization)
.options(
selectinload(Organization.building),
selectinload(Organization.phonenumbers),
selectinload(Organization.activities),
)
.where(Organization.name.ilike(f"%{name}%"))
)
return result.scalars().all()
[docs]
@classmethod
async def get_in_radius( # noqa: D102
cls, session: AsyncSession, lat: float, lon: float, radius_km: float
) -> list[Organization]:
distance = (
func.acos(
func.sin(func.radians(lat)) * func.sin(func.radians(Building.latitude))
+ func.cos(func.radians(lat))
* func.cos(func.radians(Building.latitude))
* func.cos(func.radians(Building.longitude) - func.radians(lon))
)
* 6371
)
result = await session.execute(
select(Organization)
.options(
selectinload(Organization.building),
selectinload(Organization.phonenumbers),
selectinload(Organization.activities),
)
.join(Organization.building)
.where(distance <= radius_km)
)
return result.scalars().all()
[docs]
@classmethod
async def get_in_rectangular_area( # noqa: D102
cls,
session: AsyncSession,
min_lat: float,
max_lat: float,
min_lon: float,
max_lon: float,
) -> list[Organization]:
result = await session.execute(
select(Organization)
.options(
selectinload(Organization.building),
selectinload(Organization.phonenumbers),
selectinload(Organization.activities),
)
.join(Organization.building)
.where(
Building.latitude.between(min_lat, max_lat),
Building.longitude.between(min_lon, max_lon),
)
)
return result.scalars().all()
[docs]
class BuildingRepository(CRUDRepository): # noqa: D101
[docs]
@classmethod
async def get(cls, session: AsyncSession, id: int) -> Base | None: # noqa: A002, D102
result = await session.execute(select(Building).where(Building.id == id))
return result.scalar_one_or_none()
[docs]
@classmethod
async def get_with_relations( # noqa: D102
cls,
session: AsyncSession,
id: int, # noqa: A002
) -> Building | None:
result = await session.execute(
select(Building)
.options(selectinload(Building.organizations))
.where(Building.id == id)
)
return result.scalar_one_or_none()
[docs]
@classmethod
async def delete(cls, session: AsyncSession, id: int): # noqa: A002, D102
building = await cls.get(session, id)
if building:
await session.delete(building)
await commit_process_session(session)
[docs]
@classmethod
async def update(cls, session: AsyncSession, model: BaseModel): # noqa: D102
update_data = model.dict(exclude_unset=True)
await session.execute(
update(Building)
.where(Building.id == update_data["id"])
.values(**{k: v for k, v in update_data.items() if k != "id"})
)
await commit_process_session(session)
[docs]
@classmethod
async def create(cls, session: AsyncSession, model: BaseModel) -> Building: # noqa: D102
building_data = model.dict()
building = Building(**building_data)
session.add(building)
await commit_process_session(session, building)
return building
[docs]
@classmethod
async def get_all(cls, session: AsyncSession) -> list[Building]: # noqa: D102
result = await session.execute(select(Building))
return result.scalars().all()
[docs]
class ActivityRepository(CRUDRepository): # noqa: D101
[docs]
@classmethod
async def get(cls, session: AsyncSession, id: int) -> Base | None: # noqa: A002, D102
result = await session.execute(select(Activity).where(Activity.id == id))
return result.scalar_one_or_none()
[docs]
@classmethod
async def get_with_relations( # noqa: D102
cls,
session: AsyncSession,
id: int, # noqa: A002
) -> Activity | None:
result = await session.execute(
select(Activity)
.options(
selectinload(Activity.parent),
selectinload(Activity.children),
selectinload(Activity.organizations),
)
.where(Activity.id == id)
)
return result.scalar_one_or_none()
[docs]
@classmethod
async def delete(cls, session: AsyncSession, id: int): # noqa: A002, D102
activity = await cls.get(session, id)
if activity:
await session.delete(activity)
await commit_process_session(session)
[docs]
@classmethod
async def update(cls, session: AsyncSession, model: BaseModel): # noqa: D102
update_data = model.dict(exclude_unset=True)
await session.execute(
update(Activity)
.where(Activity.id == update_data["id"])
.values(**{k: v for k, v in update_data.items() if k != "id"})
)
await commit_process_session(session)
[docs]
@classmethod
async def create(cls, session: AsyncSession, model: BaseModel) -> Activity: # noqa: D102
activity_data = model.dict()
if activity_data.get("parent_id"):
parent_activity = await cls.get(session, activity_data["parent_id"])
if parent_activity:
if parent_activity.level >= 2:
raise HTTPException(
status_code=400, detail="Maximum nesting level is 3"
)
activity_data["level"] = parent_activity.level + 1
else:
raise HTTPException(status_code=404, detail="Parent activity not found")
else:
activity_data["level"] = 0
activity = Activity(**activity_data)
session.add(activity)
await commit_process_session(session, activity)
return activity
[docs]
@classmethod
async def get_all(cls, session: AsyncSession) -> list[Activity]: # noqa: D102
result = await session.execute(select(Activity))
return result.scalars().all()
[docs]
@classmethod
async def get_tree( # noqa: D102
cls, session: AsyncSession, parent_id: int | None = None
) -> list[Activity]:
if parent_id is None:
result = await session.execute(
select(Activity).where(Activity.parent_id.is_(None))
)
else:
result = await session.execute(
select(Activity).where(Activity.parent_id == parent_id)
)
return result.scalars().all()