From e3d5f953b13e95a5301023c972cbd149e1e797e6 Mon Sep 17 00:00:00 2001 From: Michael Woods Date: Thu, 25 Dec 2025 15:35:25 -0500 Subject: [PATCH] Reapply "Suggestions from grok about database.py. Updated bulletins to use new database logic." This reverts commit 2051cda1b4d2f9137ff18c876e9b8e0ce9f5e21d. --- packetserver/http/database.py | 58 ++++++++--- packetserver/http/routers/bulletins.py | 134 +++++++++++++------------ 2 files changed, 113 insertions(+), 79 deletions(-) diff --git a/packetserver/http/database.py b/packetserver/http/database.py index 186c880..6b0bf9b 100644 --- a/packetserver/http/database.py +++ b/packetserver/http/database.py @@ -1,33 +1,65 @@ -from .config import Settings +from fastapi import Depends +from typing import Annotated, Generator from os.path import isfile + import ZEO import ZODB -from fastapi import Depends -from typing import Annotated, ContextManager +from ZODB.Connection import Connection +import transaction + +from .config import Settings # assuming Settings has zeo_file: str settings = Settings() -def get_zeo_address(zeo_address_file: str) -> tuple[str,int]: +# Global shared DB instance (created once) +_db: ZODB.DB | None = None + +def _get_zeo_address(zeo_address_file: str) -> tuple[str, int]: if not isfile(zeo_address_file): - raise FileNotFoundError(f"ZEO address file is not a file: '{zeo_address_file}'") + raise FileNotFoundError(f"ZEO address file not found: '{zeo_address_file}'") contents = open(zeo_address_file, 'r').read().strip().split(":") - if len(contents) != 2: - raise ValueError(f"Invalid ZEO address file: {zeo_address_file}") + raise ValueError(f"Invalid ZEO address format in {zeo_address_file}") host = contents[0] try: port = int(contents[1]) except ValueError: - raise ValueError(f"Invalid ZEO address file: {zeo_address_file}") - return host,port + raise ValueError(f"Invalid port in ZEO address file: {zeo_address_file}") + + return host, port + +def init_db() -> ZODB.DB: + """Call this on app startup to create the shared DB instance.""" + global _db + if _db is not None: + return _db + + host, port = _get_zeo_address(settings.zeo_file) + storage = ZEO.ClientStorage((host, port)) + _db = ZODB.DB(storage) + return _db def get_db() -> ZODB.DB: - return ZEO.DB(get_zeo_address(settings.zeo_file)) + """Dependency for the shared DB instance (e.g., for class methods needing DB).""" + if _db is None: + raise RuntimeError("Database not initialized – call init_db() on startup") + return _db -def get_transaction() -> ContextManager: - return ZEO.DB(get_zeo_address(settings.zeo_file)).transaction() +def get_connection() -> Generator[Connection, None, None]: + """Per-request dependency: yields an open Connection, closes on exit.""" + db = get_db() + conn = db.open() + try: + yield conn + finally: + conn.close() +# Optional: per-request transaction (if you want automatic commit/abort) +def get_transaction_manager(): + return transaction.manager + +# Annotated dependencies for routers DbDependency = Annotated[ZODB.DB, Depends(get_db)] -TransactionDependency = Annotated[ContextManager, Depends(get_transaction)] +ConnectionDependency = Annotated[Connection, Depends(get_connection)] \ No newline at end of file diff --git a/packetserver/http/routers/bulletins.py b/packetserver/http/routers/bulletins.py index 660c96f..f371568 100644 --- a/packetserver/http/routers/bulletins.py +++ b/packetserver/http/routers/bulletins.py @@ -7,7 +7,7 @@ import transaction from persistent.list import PersistentList from ZODB.Connection import Connection -from ..database import DbDependency, TransactionDependency +from packetserver.http.database import DbDependency, ConnectionDependency, get_db from ..dependencies import get_current_http_user from ..auth import HttpUser from ..server import templates @@ -48,35 +48,36 @@ async def list_bulletins(connection: Connection, limit: int = 50, since: Optiona @router.get("/bulletins") async def api_list_bulletins( + db: DbDependency, limit: Optional[int] = Query(50, le=100), since: Optional[datetime] = None, - current_user: HttpUser = Depends(get_current_http_user) ): - return await list_bulletins(limit=limit, since=since) + with db.transaction() as conn: + return await list_bulletins(conn, limit=limit, since=since) -async def get_one_bulletin(bid: int) -> dict: - with get_transaction() as conn: - root = conn.root() - bulletins_list: List[Bulletin] = root.get("bulletins", []) +async def get_one_bulletin(connection: Connection, bid: int) -> dict: + root = connection.root() + bulletins_list: List[Bulletin] = root.get("bulletins", []) - for b in bulletins_list: - if b.id == bid: - return { - "id": b.id, - "author": b.author, - "subject": b.subject, - "body": b.body, - "created_at": b.created_at.isoformat() + "Z", - "updated_at": b.updated_at.isoformat() + "Z", - } - raise HTTPException(status_code=404, detail="Bulletin not found") + for b in bulletins_list: + if b.id == bid: + return { + "id": b.id, + "author": b.author, + "subject": b.subject, + "body": b.body, + "created_at": b.created_at.isoformat() + "Z", + "updated_at": b.updated_at.isoformat() + "Z", + } + raise HTTPException(status_code=404, detail="Bulletin not found") @router.get("/bulletins/{bid}") async def api_get_bulletin( + db: DbDependency, bid: int, - current_user: HttpUser = Depends(get_current_http_user) ): - return await get_one_bulletin(bid) + with db.transaction() as conn: + return await get_one_bulletin(conn, bid) class CreateBulletinRequest(BaseModel): subject: constr(min_length=1, max_length=100) = Field(..., description="Bulletin subject/title") @@ -84,55 +85,56 @@ class CreateBulletinRequest(BaseModel): @router.post("/bulletins", status_code=status.HTTP_201_CREATED) async def create_bulletin( + db: DbDependency, payload: CreateBulletinRequest, current_user: HttpUser = Depends(get_current_http_user) ): - from packetserver.runners.http_server import get_db_connection - conn = get_db_connection() - root = conn.root() + with db.transaction() as conn: + root = conn.root() - if 'bulletins' not in root: - root['bulletins'] = PersistentList() + if 'bulletins' not in root: + root['bulletins'] = PersistentList() - new_bulletin = Bulletin( - author=current_user.username, - subject=payload.subject.strip(), - text=payload.body.strip() - ) + new_bulletin = Bulletin( + author=current_user.username, + subject=payload.subject.strip(), + text=payload.body.strip() + ) - new_id = new_bulletin.write_new(root) + new_id = new_bulletin.write_new(root) - transaction.commit() + transaction.commit() - return { - "id": new_id, - "author": new_bulletin.author, - "subject": new_bulletin.subject, - "body": new_bulletin.body, - "created_at": new_bulletin.created_at.isoformat() + "Z", - "updated_at": new_bulletin.updated_at.isoformat() + "Z", - } + return { + "id": new_id, + "author": new_bulletin.author, + "subject": new_bulletin.subject, + "body": new_bulletin.body, + "created_at": new_bulletin.created_at.isoformat() + "Z", + "updated_at": new_bulletin.updated_at.isoformat() + "Z", + } # --- HTML Pages (require login) --- @html_router.get("/bulletins", response_class=HTMLResponse) async def bulletin_list_page( + db: DbDependency, request: Request, limit: Optional[int] = Query(50, le=100), current_user: HttpUser = Depends(get_current_http_user) ): - api_resp = await list_bulletins(limit=limit, since=None) - bulletins = api_resp["bulletins"] + with db.transaction() as conn: + api_resp = await list_bulletins(conn, limit=limit, since=None) + bulletins = api_resp["bulletins"] - return templates.TemplateResponse( - "bulletin_list.html", - {"request": request, "bulletins": bulletins, "current_user": current_user.username} - ) + return templates.TemplateResponse( + "bulletin_list.html", + {"request": request, "bulletins": bulletins, "current_user": current_user.username} + ) @html_router.get("/bulletins/new", response_class=HTMLResponse) async def bulletin_new_form( request: Request, - current_user: HttpUser = Depends(get_current_http_user) # require login, consistent with site ): return templates.TemplateResponse( "bulletin_new.html", @@ -141,6 +143,7 @@ async def bulletin_new_form( @html_router.post("/bulletins/new") async def bulletin_new_submit( + db: DbDependency, request: Request, subject: str = Form(...), body: str = Form(...), @@ -152,34 +155,33 @@ async def bulletin_new_submit( {"request": request, "error": "Subject and body are required."}, status_code=400 ) - from packetserver.runners.http_server import get_db_connection - conn = get_db_connection() - root = conn.root() + with db.transaction() as conn: + root = conn.root() - if 'bulletins' not in root: - root['bulletins'] = PersistentList() + if 'bulletins' not in root: + root['bulletins'] = PersistentList() - new_bulletin = Bulletin( - author=current_user.username, - subject=subject.strip(), - text=body.strip() - ) + new_bulletin = Bulletin( + author=current_user.username, + subject=subject.strip(), + text=body.strip() + ) - new_id = new_bulletin.write_new(root) + new_id = new_bulletin.write_new(root) - transaction.commit() - - return RedirectResponse(url=f"/bulletins/{new_id}", status_code=303) + return RedirectResponse(url=f"/bulletins/{new_id}", status_code=303) @html_router.get("/bulletins/{bid}", response_class=HTMLResponse) async def bulletin_detail_page( + db: DbDependency, request: Request, bid: int = Path(...), current_user: HttpUser = Depends(get_current_http_user) ): - bulletin = await get_one_bulletin(bid=bid) + with db.transaction() as conn: + bulletin = await get_one_bulletin(conn, bid=bid) - return templates.TemplateResponse( - "bulletin_detail.html", - {"request": request, "bulletin": bulletin, "current_user": current_user.username} - ) \ No newline at end of file + return templates.TemplateResponse( + "bulletin_detail.html", + {"request": request, "bulletin": bulletin, "current_user": current_user.username} + ) \ No newline at end of file