From 2051cda1b4d2f9137ff18c876e9b8e0ce9f5e21d Mon Sep 17 00:00:00 2001 From: Michael Woods Date: Thu, 25 Dec 2025 15:32:21 -0500 Subject: [PATCH] Revert "Suggestions from grok about database.py. Updated bulletins to use new database logic." This reverts commit 60165d658c9386aad688957c12e873500bcc3aaf. --- packetserver/http/database.py | 58 +++-------- packetserver/http/routers/bulletins.py | 134 ++++++++++++------------- 2 files changed, 79 insertions(+), 113 deletions(-) diff --git a/packetserver/http/database.py b/packetserver/http/database.py index 6b0bf9b..186c880 100644 --- a/packetserver/http/database.py +++ b/packetserver/http/database.py @@ -1,65 +1,33 @@ -from fastapi import Depends -from typing import Annotated, Generator +from .config import Settings from os.path import isfile - import ZEO import ZODB -from ZODB.Connection import Connection -import transaction - -from .config import Settings # assuming Settings has zeo_file: str +from fastapi import Depends +from typing import Annotated, ContextManager settings = Settings() -# Global shared DB instance (created once) -_db: ZODB.DB | None = None - -def _get_zeo_address(zeo_address_file: str) -> tuple[str, int]: +def get_zeo_address(zeo_address_file: str) -> tuple[str,int]: if not isfile(zeo_address_file): - raise FileNotFoundError(f"ZEO address file not found: '{zeo_address_file}'") + raise FileNotFoundError(f"ZEO address file is not a file: '{zeo_address_file}'") contents = open(zeo_address_file, 'r').read().strip().split(":") + if len(contents) != 2: - raise ValueError(f"Invalid ZEO address format in {zeo_address_file}") + raise ValueError(f"Invalid ZEO address file: {zeo_address_file}") host = contents[0] try: port = int(contents[1]) except ValueError: - raise ValueError(f"Invalid port in ZEO address file: {zeo_address_file}") - - return host, port - -def init_db() -> ZODB.DB: - """Call this on app startup to create the shared DB instance.""" - global _db - if _db is not None: - return _db - - host, port = _get_zeo_address(settings.zeo_file) - storage = ZEO.ClientStorage((host, port)) - _db = ZODB.DB(storage) - return _db + raise ValueError(f"Invalid ZEO address file: {zeo_address_file}") + return host,port def get_db() -> ZODB.DB: - """Dependency for the shared DB instance (e.g., for class methods needing DB).""" - if _db is None: - raise RuntimeError("Database not initialized – call init_db() on startup") - return _db + return ZEO.DB(get_zeo_address(settings.zeo_file)) -def get_connection() -> Generator[Connection, None, None]: - """Per-request dependency: yields an open Connection, closes on exit.""" - db = get_db() - conn = db.open() - try: - yield conn - finally: - conn.close() +def get_transaction() -> ContextManager: + return ZEO.DB(get_zeo_address(settings.zeo_file)).transaction() -# Optional: per-request transaction (if you want automatic commit/abort) -def get_transaction_manager(): - return transaction.manager - -# Annotated dependencies for routers DbDependency = Annotated[ZODB.DB, Depends(get_db)] -ConnectionDependency = Annotated[Connection, Depends(get_connection)] \ No newline at end of file +TransactionDependency = Annotated[ContextManager, Depends(get_transaction)] diff --git a/packetserver/http/routers/bulletins.py b/packetserver/http/routers/bulletins.py index f371568..660c96f 100644 --- a/packetserver/http/routers/bulletins.py +++ b/packetserver/http/routers/bulletins.py @@ -7,7 +7,7 @@ import transaction from persistent.list import PersistentList from ZODB.Connection import Connection -from packetserver.http.database import DbDependency, ConnectionDependency, get_db +from ..database import DbDependency, TransactionDependency from ..dependencies import get_current_http_user from ..auth import HttpUser from ..server import templates @@ -48,36 +48,35 @@ async def list_bulletins(connection: Connection, limit: int = 50, since: Optiona @router.get("/bulletins") async def api_list_bulletins( - db: DbDependency, limit: Optional[int] = Query(50, le=100), since: Optional[datetime] = None, + current_user: HttpUser = Depends(get_current_http_user) ): - with db.transaction() as conn: - return await list_bulletins(conn, limit=limit, since=since) + return await list_bulletins(limit=limit, since=since) -async def get_one_bulletin(connection: Connection, bid: int) -> dict: - root = connection.root() - bulletins_list: List[Bulletin] = root.get("bulletins", []) +async def get_one_bulletin(bid: int) -> dict: + with get_transaction() as conn: + root = conn.root() + bulletins_list: List[Bulletin] = root.get("bulletins", []) - for b in bulletins_list: - if b.id == bid: - return { - "id": b.id, - "author": b.author, - "subject": b.subject, - "body": b.body, - "created_at": b.created_at.isoformat() + "Z", - "updated_at": b.updated_at.isoformat() + "Z", - } - raise HTTPException(status_code=404, detail="Bulletin not found") + for b in bulletins_list: + if b.id == bid: + return { + "id": b.id, + "author": b.author, + "subject": b.subject, + "body": b.body, + "created_at": b.created_at.isoformat() + "Z", + "updated_at": b.updated_at.isoformat() + "Z", + } + raise HTTPException(status_code=404, detail="Bulletin not found") @router.get("/bulletins/{bid}") async def api_get_bulletin( - db: DbDependency, bid: int, + current_user: HttpUser = Depends(get_current_http_user) ): - with db.transaction() as conn: - return await get_one_bulletin(conn, bid) + return await get_one_bulletin(bid) class CreateBulletinRequest(BaseModel): subject: constr(min_length=1, max_length=100) = Field(..., description="Bulletin subject/title") @@ -85,56 +84,55 @@ class CreateBulletinRequest(BaseModel): @router.post("/bulletins", status_code=status.HTTP_201_CREATED) async def create_bulletin( - db: DbDependency, payload: CreateBulletinRequest, current_user: HttpUser = Depends(get_current_http_user) ): - with db.transaction() as conn: - root = conn.root() + from packetserver.runners.http_server import get_db_connection + conn = get_db_connection() + root = conn.root() - if 'bulletins' not in root: - root['bulletins'] = PersistentList() + if 'bulletins' not in root: + root['bulletins'] = PersistentList() - new_bulletin = Bulletin( - author=current_user.username, - subject=payload.subject.strip(), - text=payload.body.strip() - ) + new_bulletin = Bulletin( + author=current_user.username, + subject=payload.subject.strip(), + text=payload.body.strip() + ) - new_id = new_bulletin.write_new(root) + new_id = new_bulletin.write_new(root) - transaction.commit() + transaction.commit() - return { - "id": new_id, - "author": new_bulletin.author, - "subject": new_bulletin.subject, - "body": new_bulletin.body, - "created_at": new_bulletin.created_at.isoformat() + "Z", - "updated_at": new_bulletin.updated_at.isoformat() + "Z", - } + return { + "id": new_id, + "author": new_bulletin.author, + "subject": new_bulletin.subject, + "body": new_bulletin.body, + "created_at": new_bulletin.created_at.isoformat() + "Z", + "updated_at": new_bulletin.updated_at.isoformat() + "Z", + } # --- HTML Pages (require login) --- @html_router.get("/bulletins", response_class=HTMLResponse) async def bulletin_list_page( - db: DbDependency, request: Request, limit: Optional[int] = Query(50, le=100), current_user: HttpUser = Depends(get_current_http_user) ): - with db.transaction() as conn: - api_resp = await list_bulletins(conn, limit=limit, since=None) - bulletins = api_resp["bulletins"] + api_resp = await list_bulletins(limit=limit, since=None) + bulletins = api_resp["bulletins"] - return templates.TemplateResponse( - "bulletin_list.html", - {"request": request, "bulletins": bulletins, "current_user": current_user.username} - ) + return templates.TemplateResponse( + "bulletin_list.html", + {"request": request, "bulletins": bulletins, "current_user": current_user.username} + ) @html_router.get("/bulletins/new", response_class=HTMLResponse) async def bulletin_new_form( request: Request, + current_user: HttpUser = Depends(get_current_http_user) # require login, consistent with site ): return templates.TemplateResponse( "bulletin_new.html", @@ -143,7 +141,6 @@ async def bulletin_new_form( @html_router.post("/bulletins/new") async def bulletin_new_submit( - db: DbDependency, request: Request, subject: str = Form(...), body: str = Form(...), @@ -155,33 +152,34 @@ async def bulletin_new_submit( {"request": request, "error": "Subject and body are required."}, status_code=400 ) - with db.transaction() as conn: - root = conn.root() + from packetserver.runners.http_server import get_db_connection + conn = get_db_connection() + root = conn.root() - if 'bulletins' not in root: - root['bulletins'] = PersistentList() + if 'bulletins' not in root: + root['bulletins'] = PersistentList() - new_bulletin = Bulletin( - author=current_user.username, - subject=subject.strip(), - text=body.strip() - ) + new_bulletin = Bulletin( + author=current_user.username, + subject=subject.strip(), + text=body.strip() + ) - new_id = new_bulletin.write_new(root) + new_id = new_bulletin.write_new(root) - return RedirectResponse(url=f"/bulletins/{new_id}", status_code=303) + transaction.commit() + + return RedirectResponse(url=f"/bulletins/{new_id}", status_code=303) @html_router.get("/bulletins/{bid}", response_class=HTMLResponse) async def bulletin_detail_page( - db: DbDependency, request: Request, bid: int = Path(...), current_user: HttpUser = Depends(get_current_http_user) ): - with db.transaction() as conn: - bulletin = await get_one_bulletin(conn, bid=bid) + bulletin = await get_one_bulletin(bid=bid) - return templates.TemplateResponse( - "bulletin_detail.html", - {"request": request, "bulletin": bulletin, "current_user": current_user.username} - ) \ No newline at end of file + return templates.TemplateResponse( + "bulletin_detail.html", + {"request": request, "bulletin": bulletin, "current_user": current_user.username} + ) \ No newline at end of file