Suggestions from grok about database.py. Updated bulletins to use new database logic.
This commit is contained in:
@@ -1,33 +1,65 @@
|
|||||||
from .config import Settings
|
from fastapi import Depends
|
||||||
|
from typing import Annotated, Generator
|
||||||
from os.path import isfile
|
from os.path import isfile
|
||||||
|
|
||||||
import ZEO
|
import ZEO
|
||||||
import ZODB
|
import ZODB
|
||||||
from fastapi import Depends
|
from ZODB.Connection import Connection
|
||||||
from typing import Annotated, ContextManager
|
import transaction
|
||||||
|
|
||||||
|
from .config import Settings # assuming Settings has zeo_file: str
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
def get_zeo_address(zeo_address_file: str) -> tuple[str,int]:
|
# Global shared DB instance (created once)
|
||||||
|
_db: ZODB.DB | None = None
|
||||||
|
|
||||||
|
def _get_zeo_address(zeo_address_file: str) -> tuple[str, int]:
|
||||||
if not isfile(zeo_address_file):
|
if not isfile(zeo_address_file):
|
||||||
raise FileNotFoundError(f"ZEO address file is not a file: '{zeo_address_file}'")
|
raise FileNotFoundError(f"ZEO address file not found: '{zeo_address_file}'")
|
||||||
|
|
||||||
contents = open(zeo_address_file, 'r').read().strip().split(":")
|
contents = open(zeo_address_file, 'r').read().strip().split(":")
|
||||||
|
|
||||||
if len(contents) != 2:
|
if len(contents) != 2:
|
||||||
raise ValueError(f"Invalid ZEO address file: {zeo_address_file}")
|
raise ValueError(f"Invalid ZEO address format in {zeo_address_file}")
|
||||||
|
|
||||||
host = contents[0]
|
host = contents[0]
|
||||||
try:
|
try:
|
||||||
port = int(contents[1])
|
port = int(contents[1])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(f"Invalid ZEO address file: {zeo_address_file}")
|
raise ValueError(f"Invalid port in ZEO address file: {zeo_address_file}")
|
||||||
return host,port
|
|
||||||
|
return host, port
|
||||||
|
|
||||||
|
def init_db() -> ZODB.DB:
|
||||||
|
"""Call this on app startup to create the shared DB instance."""
|
||||||
|
global _db
|
||||||
|
if _db is not None:
|
||||||
|
return _db
|
||||||
|
|
||||||
|
host, port = _get_zeo_address(settings.zeo_file)
|
||||||
|
storage = ZEO.ClientStorage((host, port))
|
||||||
|
_db = ZODB.DB(storage)
|
||||||
|
return _db
|
||||||
|
|
||||||
def get_db() -> ZODB.DB:
|
def get_db() -> ZODB.DB:
|
||||||
return ZEO.DB(get_zeo_address(settings.zeo_file))
|
"""Dependency for the shared DB instance (e.g., for class methods needing DB)."""
|
||||||
|
if _db is None:
|
||||||
|
raise RuntimeError("Database not initialized – call init_db() on startup")
|
||||||
|
return _db
|
||||||
|
|
||||||
def get_transaction() -> ContextManager:
|
def get_connection() -> Generator[Connection, None, None]:
|
||||||
return ZEO.DB(get_zeo_address(settings.zeo_file)).transaction()
|
"""Per-request dependency: yields an open Connection, closes on exit."""
|
||||||
|
db = get_db()
|
||||||
|
conn = db.open()
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# Optional: per-request transaction (if you want automatic commit/abort)
|
||||||
|
def get_transaction_manager():
|
||||||
|
return transaction.manager
|
||||||
|
|
||||||
|
# Annotated dependencies for routers
|
||||||
DbDependency = Annotated[ZODB.DB, Depends(get_db)]
|
DbDependency = Annotated[ZODB.DB, Depends(get_db)]
|
||||||
TransactionDependency = Annotated[ContextManager, Depends(get_transaction)]
|
ConnectionDependency = Annotated[Connection, Depends(get_connection)]
|
||||||
@@ -7,7 +7,7 @@ import transaction
|
|||||||
from persistent.list import PersistentList
|
from persistent.list import PersistentList
|
||||||
from ZODB.Connection import Connection
|
from ZODB.Connection import Connection
|
||||||
|
|
||||||
from ..database import DbDependency, TransactionDependency
|
from packetserver.http.database import DbDependency, ConnectionDependency, get_db
|
||||||
from ..dependencies import get_current_http_user
|
from ..dependencies import get_current_http_user
|
||||||
from ..auth import HttpUser
|
from ..auth import HttpUser
|
||||||
from ..server import templates
|
from ..server import templates
|
||||||
@@ -48,35 +48,36 @@ async def list_bulletins(connection: Connection, limit: int = 50, since: Optiona
|
|||||||
|
|
||||||
@router.get("/bulletins")
|
@router.get("/bulletins")
|
||||||
async def api_list_bulletins(
|
async def api_list_bulletins(
|
||||||
|
db: DbDependency,
|
||||||
limit: Optional[int] = Query(50, le=100),
|
limit: Optional[int] = Query(50, le=100),
|
||||||
since: Optional[datetime] = None,
|
since: Optional[datetime] = None,
|
||||||
current_user: HttpUser = Depends(get_current_http_user)
|
|
||||||
):
|
):
|
||||||
return await list_bulletins(limit=limit, since=since)
|
with db.transaction() as conn:
|
||||||
|
return await list_bulletins(conn, limit=limit, since=since)
|
||||||
|
|
||||||
async def get_one_bulletin(bid: int) -> dict:
|
async def get_one_bulletin(connection: Connection, bid: int) -> dict:
|
||||||
with get_transaction() as conn:
|
root = connection.root()
|
||||||
root = conn.root()
|
bulletins_list: List[Bulletin] = root.get("bulletins", [])
|
||||||
bulletins_list: List[Bulletin] = root.get("bulletins", [])
|
|
||||||
|
|
||||||
for b in bulletins_list:
|
for b in bulletins_list:
|
||||||
if b.id == bid:
|
if b.id == bid:
|
||||||
return {
|
return {
|
||||||
"id": b.id,
|
"id": b.id,
|
||||||
"author": b.author,
|
"author": b.author,
|
||||||
"subject": b.subject,
|
"subject": b.subject,
|
||||||
"body": b.body,
|
"body": b.body,
|
||||||
"created_at": b.created_at.isoformat() + "Z",
|
"created_at": b.created_at.isoformat() + "Z",
|
||||||
"updated_at": b.updated_at.isoformat() + "Z",
|
"updated_at": b.updated_at.isoformat() + "Z",
|
||||||
}
|
}
|
||||||
raise HTTPException(status_code=404, detail="Bulletin not found")
|
raise HTTPException(status_code=404, detail="Bulletin not found")
|
||||||
|
|
||||||
@router.get("/bulletins/{bid}")
|
@router.get("/bulletins/{bid}")
|
||||||
async def api_get_bulletin(
|
async def api_get_bulletin(
|
||||||
|
db: DbDependency,
|
||||||
bid: int,
|
bid: int,
|
||||||
current_user: HttpUser = Depends(get_current_http_user)
|
|
||||||
):
|
):
|
||||||
return await get_one_bulletin(bid)
|
with db.transaction() as conn:
|
||||||
|
return await get_one_bulletin(conn, bid)
|
||||||
|
|
||||||
class CreateBulletinRequest(BaseModel):
|
class CreateBulletinRequest(BaseModel):
|
||||||
subject: constr(min_length=1, max_length=100) = Field(..., description="Bulletin subject/title")
|
subject: constr(min_length=1, max_length=100) = Field(..., description="Bulletin subject/title")
|
||||||
@@ -84,55 +85,56 @@ class CreateBulletinRequest(BaseModel):
|
|||||||
|
|
||||||
@router.post("/bulletins", status_code=status.HTTP_201_CREATED)
|
@router.post("/bulletins", status_code=status.HTTP_201_CREATED)
|
||||||
async def create_bulletin(
|
async def create_bulletin(
|
||||||
|
db: DbDependency,
|
||||||
payload: CreateBulletinRequest,
|
payload: CreateBulletinRequest,
|
||||||
current_user: HttpUser = Depends(get_current_http_user)
|
current_user: HttpUser = Depends(get_current_http_user)
|
||||||
):
|
):
|
||||||
from packetserver.runners.http_server import get_db_connection
|
with db.transaction() as conn:
|
||||||
conn = get_db_connection()
|
root = conn.root()
|
||||||
root = conn.root()
|
|
||||||
|
|
||||||
if 'bulletins' not in root:
|
if 'bulletins' not in root:
|
||||||
root['bulletins'] = PersistentList()
|
root['bulletins'] = PersistentList()
|
||||||
|
|
||||||
new_bulletin = Bulletin(
|
new_bulletin = Bulletin(
|
||||||
author=current_user.username,
|
author=current_user.username,
|
||||||
subject=payload.subject.strip(),
|
subject=payload.subject.strip(),
|
||||||
text=payload.body.strip()
|
text=payload.body.strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
new_id = new_bulletin.write_new(root)
|
new_id = new_bulletin.write_new(root)
|
||||||
|
|
||||||
transaction.commit()
|
transaction.commit()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": new_id,
|
"id": new_id,
|
||||||
"author": new_bulletin.author,
|
"author": new_bulletin.author,
|
||||||
"subject": new_bulletin.subject,
|
"subject": new_bulletin.subject,
|
||||||
"body": new_bulletin.body,
|
"body": new_bulletin.body,
|
||||||
"created_at": new_bulletin.created_at.isoformat() + "Z",
|
"created_at": new_bulletin.created_at.isoformat() + "Z",
|
||||||
"updated_at": new_bulletin.updated_at.isoformat() + "Z",
|
"updated_at": new_bulletin.updated_at.isoformat() + "Z",
|
||||||
}
|
}
|
||||||
|
|
||||||
# --- HTML Pages (require login) ---
|
# --- HTML Pages (require login) ---
|
||||||
|
|
||||||
@html_router.get("/bulletins", response_class=HTMLResponse)
|
@html_router.get("/bulletins", response_class=HTMLResponse)
|
||||||
async def bulletin_list_page(
|
async def bulletin_list_page(
|
||||||
|
db: DbDependency,
|
||||||
request: Request,
|
request: Request,
|
||||||
limit: Optional[int] = Query(50, le=100),
|
limit: Optional[int] = Query(50, le=100),
|
||||||
current_user: HttpUser = Depends(get_current_http_user)
|
current_user: HttpUser = Depends(get_current_http_user)
|
||||||
):
|
):
|
||||||
api_resp = await list_bulletins(limit=limit, since=None)
|
with db.transaction() as conn:
|
||||||
bulletins = api_resp["bulletins"]
|
api_resp = await list_bulletins(conn, limit=limit, since=None)
|
||||||
|
bulletins = api_resp["bulletins"]
|
||||||
|
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
"bulletin_list.html",
|
"bulletin_list.html",
|
||||||
{"request": request, "bulletins": bulletins, "current_user": current_user.username}
|
{"request": request, "bulletins": bulletins, "current_user": current_user.username}
|
||||||
)
|
)
|
||||||
|
|
||||||
@html_router.get("/bulletins/new", response_class=HTMLResponse)
|
@html_router.get("/bulletins/new", response_class=HTMLResponse)
|
||||||
async def bulletin_new_form(
|
async def bulletin_new_form(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: HttpUser = Depends(get_current_http_user) # require login, consistent with site
|
|
||||||
):
|
):
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
"bulletin_new.html",
|
"bulletin_new.html",
|
||||||
@@ -141,6 +143,7 @@ async def bulletin_new_form(
|
|||||||
|
|
||||||
@html_router.post("/bulletins/new")
|
@html_router.post("/bulletins/new")
|
||||||
async def bulletin_new_submit(
|
async def bulletin_new_submit(
|
||||||
|
db: DbDependency,
|
||||||
request: Request,
|
request: Request,
|
||||||
subject: str = Form(...),
|
subject: str = Form(...),
|
||||||
body: str = Form(...),
|
body: str = Form(...),
|
||||||
@@ -152,34 +155,33 @@ async def bulletin_new_submit(
|
|||||||
{"request": request, "error": "Subject and body are required."},
|
{"request": request, "error": "Subject and body are required."},
|
||||||
status_code=400
|
status_code=400
|
||||||
)
|
)
|
||||||
from packetserver.runners.http_server import get_db_connection
|
with db.transaction() as conn:
|
||||||
conn = get_db_connection()
|
root = conn.root()
|
||||||
root = conn.root()
|
|
||||||
|
|
||||||
if 'bulletins' not in root:
|
if 'bulletins' not in root:
|
||||||
root['bulletins'] = PersistentList()
|
root['bulletins'] = PersistentList()
|
||||||
|
|
||||||
new_bulletin = Bulletin(
|
new_bulletin = Bulletin(
|
||||||
author=current_user.username,
|
author=current_user.username,
|
||||||
subject=subject.strip(),
|
subject=subject.strip(),
|
||||||
text=body.strip()
|
text=body.strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
new_id = new_bulletin.write_new(root)
|
new_id = new_bulletin.write_new(root)
|
||||||
|
|
||||||
transaction.commit()
|
return RedirectResponse(url=f"/bulletins/{new_id}", status_code=303)
|
||||||
|
|
||||||
return RedirectResponse(url=f"/bulletins/{new_id}", status_code=303)
|
|
||||||
|
|
||||||
@html_router.get("/bulletins/{bid}", response_class=HTMLResponse)
|
@html_router.get("/bulletins/{bid}", response_class=HTMLResponse)
|
||||||
async def bulletin_detail_page(
|
async def bulletin_detail_page(
|
||||||
|
db: DbDependency,
|
||||||
request: Request,
|
request: Request,
|
||||||
bid: int = Path(...),
|
bid: int = Path(...),
|
||||||
current_user: HttpUser = Depends(get_current_http_user)
|
current_user: HttpUser = Depends(get_current_http_user)
|
||||||
):
|
):
|
||||||
bulletin = await get_one_bulletin(bid=bid)
|
with db.transaction() as conn:
|
||||||
|
bulletin = await get_one_bulletin(conn, bid=bid)
|
||||||
|
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
"bulletin_detail.html",
|
"bulletin_detail.html",
|
||||||
{"request": request, "bulletin": bulletin, "current_user": current_user.username}
|
{"request": request, "bulletin": bulletin, "current_user": current_user.username}
|
||||||
)
|
)
|
||||||
Reference in New Issue
Block a user