Reapply "Suggestions from grok about database.py. Updated bulletins to use new database logic."

This reverts commit 2051cda1b4.
This commit is contained in:
Michael Woods
2025-12-25 15:35:25 -05:00
parent 2051cda1b4
commit e3d5f953b1
2 changed files with 113 additions and 79 deletions

View File

@@ -1,33 +1,65 @@
from .config import Settings
from fastapi import Depends
from typing import Annotated, Generator
from os.path import isfile
import ZEO
import ZODB
from fastapi import Depends
from typing import Annotated, ContextManager
from ZODB.Connection import Connection
import transaction
from .config import Settings # assuming Settings has zeo_file: str
settings = Settings()
def get_zeo_address(zeo_address_file: str) -> tuple[str,int]:
# Global shared DB instance (created once)
_db: ZODB.DB | None = None
def _get_zeo_address(zeo_address_file: str) -> tuple[str, int]:
if not isfile(zeo_address_file):
raise FileNotFoundError(f"ZEO address file is not a file: '{zeo_address_file}'")
raise FileNotFoundError(f"ZEO address file not found: '{zeo_address_file}'")
contents = open(zeo_address_file, 'r').read().strip().split(":")
if len(contents) != 2:
raise ValueError(f"Invalid ZEO address file: {zeo_address_file}")
raise ValueError(f"Invalid ZEO address format in {zeo_address_file}")
host = contents[0]
try:
port = int(contents[1])
except ValueError:
raise ValueError(f"Invalid ZEO address file: {zeo_address_file}")
return host,port
raise ValueError(f"Invalid port in ZEO address file: {zeo_address_file}")
return host, port
def init_db() -> ZODB.DB:
"""Call this on app startup to create the shared DB instance."""
global _db
if _db is not None:
return _db
host, port = _get_zeo_address(settings.zeo_file)
storage = ZEO.ClientStorage((host, port))
_db = ZODB.DB(storage)
return _db
def get_db() -> ZODB.DB:
return ZEO.DB(get_zeo_address(settings.zeo_file))
"""Dependency for the shared DB instance (e.g., for class methods needing DB)."""
if _db is None:
raise RuntimeError("Database not initialized call init_db() on startup")
return _db
def get_transaction() -> ContextManager:
return ZEO.DB(get_zeo_address(settings.zeo_file)).transaction()
def get_connection() -> Generator[Connection, None, None]:
"""Per-request dependency: yields an open Connection, closes on exit."""
db = get_db()
conn = db.open()
try:
yield conn
finally:
conn.close()
# Optional: per-request transaction (if you want automatic commit/abort)
def get_transaction_manager():
return transaction.manager
# Annotated dependencies for routers
DbDependency = Annotated[ZODB.DB, Depends(get_db)]
TransactionDependency = Annotated[ContextManager, Depends(get_transaction)]
ConnectionDependency = Annotated[Connection, Depends(get_connection)]

View File

@@ -7,7 +7,7 @@ import transaction
from persistent.list import PersistentList
from ZODB.Connection import Connection
from ..database import DbDependency, TransactionDependency
from packetserver.http.database import DbDependency, ConnectionDependency, get_db
from ..dependencies import get_current_http_user
from ..auth import HttpUser
from ..server import templates
@@ -48,15 +48,15 @@ async def list_bulletins(connection: Connection, limit: int = 50, since: Optiona
@router.get("/bulletins")
async def api_list_bulletins(
db: DbDependency,
limit: Optional[int] = Query(50, le=100),
since: Optional[datetime] = None,
current_user: HttpUser = Depends(get_current_http_user)
):
return await list_bulletins(limit=limit, since=since)
with db.transaction() as conn:
return await list_bulletins(conn, limit=limit, since=since)
async def get_one_bulletin(bid: int) -> dict:
with get_transaction() as conn:
root = conn.root()
async def get_one_bulletin(connection: Connection, bid: int) -> dict:
root = connection.root()
bulletins_list: List[Bulletin] = root.get("bulletins", [])
for b in bulletins_list:
@@ -73,10 +73,11 @@ async def get_one_bulletin(bid: int) -> dict:
@router.get("/bulletins/{bid}")
async def api_get_bulletin(
db: DbDependency,
bid: int,
current_user: HttpUser = Depends(get_current_http_user)
):
return await get_one_bulletin(bid)
with db.transaction() as conn:
return await get_one_bulletin(conn, bid)
class CreateBulletinRequest(BaseModel):
subject: constr(min_length=1, max_length=100) = Field(..., description="Bulletin subject/title")
@@ -84,11 +85,11 @@ class CreateBulletinRequest(BaseModel):
@router.post("/bulletins", status_code=status.HTTP_201_CREATED)
async def create_bulletin(
db: DbDependency,
payload: CreateBulletinRequest,
current_user: HttpUser = Depends(get_current_http_user)
):
from packetserver.runners.http_server import get_db_connection
conn = get_db_connection()
with db.transaction() as conn:
root = conn.root()
if 'bulletins' not in root:
@@ -117,11 +118,13 @@ async def create_bulletin(
@html_router.get("/bulletins", response_class=HTMLResponse)
async def bulletin_list_page(
db: DbDependency,
request: Request,
limit: Optional[int] = Query(50, le=100),
current_user: HttpUser = Depends(get_current_http_user)
):
api_resp = await list_bulletins(limit=limit, since=None)
with db.transaction() as conn:
api_resp = await list_bulletins(conn, limit=limit, since=None)
bulletins = api_resp["bulletins"]
return templates.TemplateResponse(
@@ -132,7 +135,6 @@ async def bulletin_list_page(
@html_router.get("/bulletins/new", response_class=HTMLResponse)
async def bulletin_new_form(
request: Request,
current_user: HttpUser = Depends(get_current_http_user) # require login, consistent with site
):
return templates.TemplateResponse(
"bulletin_new.html",
@@ -141,6 +143,7 @@ async def bulletin_new_form(
@html_router.post("/bulletins/new")
async def bulletin_new_submit(
db: DbDependency,
request: Request,
subject: str = Form(...),
body: str = Form(...),
@@ -152,8 +155,7 @@ async def bulletin_new_submit(
{"request": request, "error": "Subject and body are required."},
status_code=400
)
from packetserver.runners.http_server import get_db_connection
conn = get_db_connection()
with db.transaction() as conn:
root = conn.root()
if 'bulletins' not in root:
@@ -167,17 +169,17 @@ async def bulletin_new_submit(
new_id = new_bulletin.write_new(root)
transaction.commit()
return RedirectResponse(url=f"/bulletins/{new_id}", status_code=303)
@html_router.get("/bulletins/{bid}", response_class=HTMLResponse)
async def bulletin_detail_page(
db: DbDependency,
request: Request,
bid: int = Path(...),
current_user: HttpUser = Depends(get_current_http_user)
):
bulletin = await get_one_bulletin(bid=bid)
with db.transaction() as conn:
bulletin = await get_one_bulletin(conn, bid=bid)
return templates.TemplateResponse(
"bulletin_detail.html",