Suggestions from grok about database.py. Updated bulletins to use new database logic.

This commit is contained in:
Michael Woods
2025-12-25 15:26:41 -05:00
parent 65063704e0
commit 60165d658c
2 changed files with 113 additions and 79 deletions

View File

@@ -1,33 +1,65 @@
from .config import Settings from fastapi import Depends
from typing import Annotated, Generator
from os.path import isfile from os.path import isfile
import ZEO import ZEO
import ZODB import ZODB
from fastapi import Depends from ZODB.Connection import Connection
from typing import Annotated, ContextManager import transaction
from .config import Settings # assuming Settings has zeo_file: str
settings = Settings() settings = Settings()
def get_zeo_address(zeo_address_file: str) -> tuple[str,int]: # Global shared DB instance (created once)
_db: ZODB.DB | None = None
def _get_zeo_address(zeo_address_file: str) -> tuple[str, int]:
if not isfile(zeo_address_file): if not isfile(zeo_address_file):
raise FileNotFoundError(f"ZEO address file is not a file: '{zeo_address_file}'") raise FileNotFoundError(f"ZEO address file not found: '{zeo_address_file}'")
contents = open(zeo_address_file, 'r').read().strip().split(":") contents = open(zeo_address_file, 'r').read().strip().split(":")
if len(contents) != 2: if len(contents) != 2:
raise ValueError(f"Invalid ZEO address file: {zeo_address_file}") raise ValueError(f"Invalid ZEO address format in {zeo_address_file}")
host = contents[0] host = contents[0]
try: try:
port = int(contents[1]) port = int(contents[1])
except ValueError: except ValueError:
raise ValueError(f"Invalid ZEO address file: {zeo_address_file}") raise ValueError(f"Invalid port in ZEO address file: {zeo_address_file}")
return host, port return host, port
def init_db() -> ZODB.DB:
"""Call this on app startup to create the shared DB instance."""
global _db
if _db is not None:
return _db
host, port = _get_zeo_address(settings.zeo_file)
storage = ZEO.ClientStorage((host, port))
_db = ZODB.DB(storage)
return _db
def get_db() -> ZODB.DB: def get_db() -> ZODB.DB:
return ZEO.DB(get_zeo_address(settings.zeo_file)) """Dependency for the shared DB instance (e.g., for class methods needing DB)."""
if _db is None:
raise RuntimeError("Database not initialized call init_db() on startup")
return _db
def get_transaction() -> ContextManager: def get_connection() -> Generator[Connection, None, None]:
return ZEO.DB(get_zeo_address(settings.zeo_file)).transaction() """Per-request dependency: yields an open Connection, closes on exit."""
db = get_db()
conn = db.open()
try:
yield conn
finally:
conn.close()
# Optional: per-request transaction (if you want automatic commit/abort)
def get_transaction_manager():
return transaction.manager
# Annotated dependencies for routers
DbDependency = Annotated[ZODB.DB, Depends(get_db)] DbDependency = Annotated[ZODB.DB, Depends(get_db)]
TransactionDependency = Annotated[ContextManager, Depends(get_transaction)] ConnectionDependency = Annotated[Connection, Depends(get_connection)]

View File

@@ -7,7 +7,7 @@ import transaction
from persistent.list import PersistentList from persistent.list import PersistentList
from ZODB.Connection import Connection from ZODB.Connection import Connection
from ..database import DbDependency, TransactionDependency from packetserver.http.database import DbDependency, ConnectionDependency, get_db
from ..dependencies import get_current_http_user from ..dependencies import get_current_http_user
from ..auth import HttpUser from ..auth import HttpUser
from ..server import templates from ..server import templates
@@ -48,15 +48,15 @@ async def list_bulletins(connection: Connection, limit: int = 50, since: Optiona
@router.get("/bulletins") @router.get("/bulletins")
async def api_list_bulletins( async def api_list_bulletins(
db: DbDependency,
limit: Optional[int] = Query(50, le=100), limit: Optional[int] = Query(50, le=100),
since: Optional[datetime] = None, since: Optional[datetime] = None,
current_user: HttpUser = Depends(get_current_http_user)
): ):
return await list_bulletins(limit=limit, since=since) with db.transaction() as conn:
return await list_bulletins(conn, limit=limit, since=since)
async def get_one_bulletin(bid: int) -> dict: async def get_one_bulletin(connection: Connection, bid: int) -> dict:
with get_transaction() as conn: root = connection.root()
root = conn.root()
bulletins_list: List[Bulletin] = root.get("bulletins", []) bulletins_list: List[Bulletin] = root.get("bulletins", [])
for b in bulletins_list: for b in bulletins_list:
@@ -73,10 +73,11 @@ async def get_one_bulletin(bid: int) -> dict:
@router.get("/bulletins/{bid}") @router.get("/bulletins/{bid}")
async def api_get_bulletin( async def api_get_bulletin(
db: DbDependency,
bid: int, bid: int,
current_user: HttpUser = Depends(get_current_http_user)
): ):
return await get_one_bulletin(bid) with db.transaction() as conn:
return await get_one_bulletin(conn, bid)
class CreateBulletinRequest(BaseModel): class CreateBulletinRequest(BaseModel):
subject: constr(min_length=1, max_length=100) = Field(..., description="Bulletin subject/title") subject: constr(min_length=1, max_length=100) = Field(..., description="Bulletin subject/title")
@@ -84,11 +85,11 @@ class CreateBulletinRequest(BaseModel):
@router.post("/bulletins", status_code=status.HTTP_201_CREATED) @router.post("/bulletins", status_code=status.HTTP_201_CREATED)
async def create_bulletin( async def create_bulletin(
db: DbDependency,
payload: CreateBulletinRequest, payload: CreateBulletinRequest,
current_user: HttpUser = Depends(get_current_http_user) current_user: HttpUser = Depends(get_current_http_user)
): ):
from packetserver.runners.http_server import get_db_connection with db.transaction() as conn:
conn = get_db_connection()
root = conn.root() root = conn.root()
if 'bulletins' not in root: if 'bulletins' not in root:
@@ -117,11 +118,13 @@ async def create_bulletin(
@html_router.get("/bulletins", response_class=HTMLResponse) @html_router.get("/bulletins", response_class=HTMLResponse)
async def bulletin_list_page( async def bulletin_list_page(
db: DbDependency,
request: Request, request: Request,
limit: Optional[int] = Query(50, le=100), limit: Optional[int] = Query(50, le=100),
current_user: HttpUser = Depends(get_current_http_user) current_user: HttpUser = Depends(get_current_http_user)
): ):
api_resp = await list_bulletins(limit=limit, since=None) with db.transaction() as conn:
api_resp = await list_bulletins(conn, limit=limit, since=None)
bulletins = api_resp["bulletins"] bulletins = api_resp["bulletins"]
return templates.TemplateResponse( return templates.TemplateResponse(
@@ -132,7 +135,6 @@ async def bulletin_list_page(
@html_router.get("/bulletins/new", response_class=HTMLResponse) @html_router.get("/bulletins/new", response_class=HTMLResponse)
async def bulletin_new_form( async def bulletin_new_form(
request: Request, request: Request,
current_user: HttpUser = Depends(get_current_http_user) # require login, consistent with site
): ):
return templates.TemplateResponse( return templates.TemplateResponse(
"bulletin_new.html", "bulletin_new.html",
@@ -141,6 +143,7 @@ async def bulletin_new_form(
@html_router.post("/bulletins/new") @html_router.post("/bulletins/new")
async def bulletin_new_submit( async def bulletin_new_submit(
db: DbDependency,
request: Request, request: Request,
subject: str = Form(...), subject: str = Form(...),
body: str = Form(...), body: str = Form(...),
@@ -152,8 +155,7 @@ async def bulletin_new_submit(
{"request": request, "error": "Subject and body are required."}, {"request": request, "error": "Subject and body are required."},
status_code=400 status_code=400
) )
from packetserver.runners.http_server import get_db_connection with db.transaction() as conn:
conn = get_db_connection()
root = conn.root() root = conn.root()
if 'bulletins' not in root: if 'bulletins' not in root:
@@ -167,17 +169,17 @@ async def bulletin_new_submit(
new_id = new_bulletin.write_new(root) new_id = new_bulletin.write_new(root)
transaction.commit()
return RedirectResponse(url=f"/bulletins/{new_id}", status_code=303) return RedirectResponse(url=f"/bulletins/{new_id}", status_code=303)
@html_router.get("/bulletins/{bid}", response_class=HTMLResponse) @html_router.get("/bulletins/{bid}", response_class=HTMLResponse)
async def bulletin_detail_page( async def bulletin_detail_page(
db: DbDependency,
request: Request, request: Request,
bid: int = Path(...), bid: int = Path(...),
current_user: HttpUser = Depends(get_current_http_user) current_user: HttpUser = Depends(get_current_http_user)
): ):
bulletin = await get_one_bulletin(bid=bid) with db.transaction() as conn:
bulletin = await get_one_bulletin(conn, bid=bid)
return templates.TemplateResponse( return templates.TemplateResponse(
"bulletin_detail.html", "bulletin_detail.html",