Bunch of fixes for new database model.

This commit is contained in:
Michael Woods
2025-12-25 20:02:47 -05:00
parent bc8a649ff4
commit bec626678e
10 changed files with 246 additions and 233 deletions

View File

@@ -8,7 +8,7 @@ import time
from persistent.mapping import PersistentMapping from persistent.mapping import PersistentMapping
from persistent.list import PersistentList from persistent.list import PersistentList
from packetserver.common.util import is_valid_ax25_callsign from packetserver.common.util import is_valid_ax25_callsign
from .database import ConnectionDependency from .database import DbDependency
ph = PasswordHasher() ph = PasswordHasher()
@@ -52,23 +52,24 @@ class HttpUser(Persistent):
# rf enabled checks.. # rf enabled checks..
# #
def is_rf_enabled(self, conn: ConnectionDependency) -> bool: def is_rf_enabled(self, db: DbDependency) -> bool:
""" """
Check if RF gateway is enabled (i.e., callsign NOT in global blacklist). Check if RF gateway is enabled (i.e., callsign NOT in global blacklist).
Requires an open ZODB connection. Requires an open ZODB connection.
""" """
with db.transaction() as conn:
root = conn.root() root = conn.root()
blacklist = root.get('config', {}).get('blacklist', []) blacklist = root.get('config', {}).get('blacklist', [])
return self.username not in blacklist return self.username not in blacklist
def set_rf_enabled(self, conn: ConnectionDependency, allow: bool): def set_rf_enabled(self, db: DbDependency, allow: bool):
""" """
Enable/disable RF gateway by adding/removing from global blacklist. Enable/disable RF gateway by adding/removing from global blacklist.
Requires an open ZODB connection (inside a transaction). Requires an open ZODB connection (inside a transaction).
Only allows enabling if the username is a valid AX.25 callsign. Only allows enabling if the username is a valid AX.25 callsign.
""" """
from packetserver.common.util import is_valid_ax25_callsign # our validator from packetserver.common.util import is_valid_ax25_callsign # our validator
with db.transaction() as conn:
root = conn.root() root = conn.root()
config = root.setdefault('config', PersistentMapping()) config = root.setdefault('config', PersistentMapping())
blacklist = config.setdefault('blacklist', PersistentList()) blacklist = config.setdefault('blacklist', PersistentList())

View File

@@ -46,16 +46,16 @@ def get_db() -> ZODB.DB:
raise RuntimeError("Database not initialized call init_db() on startup") raise RuntimeError("Database not initialized call init_db() on startup")
return _db return _db
def get_connection() -> Generator[Connection, None, None]: #def get_connection() -> Generator[Connection, None, None]:
"""Per-request dependency: yields an open Connection, closes on exit.""" # """Per-request dependency: yields an open Connection, closes on exit."""
db = get_db() # db = get_db()
conn = db.open() # conn = db.open()
try: # try:
yield conn # yield conn
finally: # finally:
#print("not closing connection") # #print("not closing connection")
#conn.close() # #conn.close()
pass # pass
# Optional: per-request transaction (if you want automatic commit/abort) # Optional: per-request transaction (if you want automatic commit/abort)
def get_transaction_manager(): def get_transaction_manager():
@@ -63,4 +63,4 @@ def get_transaction_manager():
# Annotated dependencies for routers # Annotated dependencies for routers
DbDependency = Annotated[ZODB.DB, Depends(get_db)] DbDependency = Annotated[ZODB.DB, Depends(get_db)]
ConnectionDependency = Annotated[Connection, Depends(get_connection)] #ConnectionDependency = Annotated[Connection, Depends(get_connection)]

View File

@@ -3,17 +3,17 @@ from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
from .auth import HttpUser from .auth import HttpUser
from .database import ConnectionDependency from .database import DbDependency
security = HTTPBasic() security = HTTPBasic()
async def get_current_http_user(conn: ConnectionDependency, credentials: HTTPBasicCredentials = Depends(security)): async def get_current_http_user(db: DbDependency, credentials: HTTPBasicCredentials = Depends(security)):
""" """
Authenticate via Basic Auth using HttpUser from ZODB. Authenticate via Basic Auth using HttpUser from ZODB.
Injected by the standalone runner (get_db_connection available). Injected by the standalone runner (get_db_connection available).
""" """
with db.transaction() as conn:
root = conn.root() root = conn.root()
http_users = root.get("httpUsers") http_users = root.get("httpUsers")

View File

@@ -7,7 +7,7 @@ import transaction
from persistent.list import PersistentList from persistent.list import PersistentList
from ZODB.Connection import Connection from ZODB.Connection import Connection
from packetserver.http.database import DbDependency, ConnectionDependency, get_db from packetserver.http.database import DbDependency
from ..dependencies import get_current_http_user from ..dependencies import get_current_http_user
from ..auth import HttpUser from ..auth import HttpUser
from ..server import templates from ..server import templates

View File

@@ -5,7 +5,7 @@ from fastapi.responses import HTMLResponse
from packetserver.http.dependencies import get_current_http_user from packetserver.http.dependencies import get_current_http_user
from packetserver.http.auth import HttpUser from packetserver.http.auth import HttpUser
from packetserver.http.server import templates from packetserver.http.server import templates
from packetserver.http.database import ConnectionDependency from packetserver.http.database import DbDependency
router = APIRouter(tags=["dashboard"]) router = APIRouter(tags=["dashboard"])
@@ -16,18 +16,20 @@ from .bulletins import list_bulletins
@router.get("/dashboard", response_class=HTMLResponse) @router.get("/dashboard", response_class=HTMLResponse)
async def dashboard( async def dashboard(
conn: ConnectionDependency, db: DbDependency,
request: Request, request: Request,
current_user: HttpUser = Depends(get_current_http_user) current_user: HttpUser = Depends(get_current_http_user)
): ):
# Internal call pass explicit defaults to avoid Query object injection
messages_resp = await api_get_messages( messages_resp = await api_get_messages(
conn, db,
current_user=current_user, current_user=current_user,
type="all", type="all",
limit=100, limit=100,
since=None # prevents Query wrapper since=None # prevents Query wrapper
) )
with db.transaction() as conn:
# Internal call pass explicit defaults to avoid Query object injection
messages = messages_resp["messages"] messages = messages_resp["messages"]
bulletins_resp = await list_bulletins(conn, limit=10, since=None) bulletins_resp = await list_bulletins(conn, limit=10, since=None)
@@ -45,12 +47,12 @@ async def dashboard(
@router.get("/dashboard/profile", response_class=HTMLResponse) @router.get("/dashboard/profile", response_class=HTMLResponse)
async def profile_page( async def profile_page(
conn: ConnectionDependency, db: DbDependency,
request: Request, request: Request,
current_user: HttpUser = Depends(get_current_http_user) current_user: HttpUser = Depends(get_current_http_user)
): ):
from packetserver.http.routers.profile import profile as api_profile from packetserver.http.routers.profile import profile as api_profile
profile_data = await api_profile(conn, current_user=current_user) profile_data = await api_profile(db, current_user=current_user)
return templates.TemplateResponse( return templates.TemplateResponse(
"profile.html", "profile.html",

View File

@@ -4,11 +4,13 @@ from fastapi.responses import HTMLResponse
from packetserver.http.dependencies import get_current_http_user from packetserver.http.dependencies import get_current_http_user
from packetserver.http.auth import HttpUser from packetserver.http.auth import HttpUser
from packetserver.http.server import templates from packetserver.http.server import templates
from packetserver.http.database import DbDependency
router = APIRouter(tags=["message-detail"]) router = APIRouter(tags=["message-detail"])
@router.get("/dashboard/message/{msg_id}", response_class=HTMLResponse) @router.get("/dashboard/message/{msg_id}", response_class=HTMLResponse)
async def message_detail_page( async def message_detail_page(
db: DbDependency,
request: Request, request: Request,
msg_id: str = Path(..., description="Message UUID as string"), msg_id: str = Path(..., description="Message UUID as string"),
current_user: HttpUser = Depends(get_current_http_user) current_user: HttpUser = Depends(get_current_http_user)
@@ -18,6 +20,7 @@ async def message_detail_page(
# Call with mark_retrieved=True to auto-mark as read on view (optional—remove if you prefer manual) # Call with mark_retrieved=True to auto-mark as read on view (optional—remove if you prefer manual)
message_data = await api_get_message( message_data = await api_get_message(
db,
msg_id=msg_id, msg_id=msg_id,
mark_retrieved=True, mark_retrieved=True,
current_user=current_user current_user=current_user

View File

@@ -10,7 +10,7 @@ from pydantic import BaseModel, Field, validator
from packetserver.http.dependencies import get_current_http_user from packetserver.http.dependencies import get_current_http_user
from packetserver.http.auth import HttpUser from packetserver.http.auth import HttpUser
from packetserver.http.database import ConnectionDependency from packetserver.http.database import DbDependency
html_router = APIRouter(tags=["messages-html"]) html_router = APIRouter(tags=["messages-html"])
@@ -29,7 +29,7 @@ class MarkRetrievedRequest(BaseModel):
@router.get("/messages") @router.get("/messages")
async def get_messages( async def get_messages(
conn: ConnectionDependency, db: DbDependency,
current_user: HttpUser = Depends(get_current_http_user), current_user: HttpUser = Depends(get_current_http_user),
type: str = Query("received", description="received, sent, or all"), type: str = Query("received", description="received, sent, or all"),
limit: Optional[int] = Query(20, le=100, description="Max messages to return (default 20, max 100)"), limit: Optional[int] = Query(20, le=100, description="Max messages to return (default 20, max 100)"),
@@ -40,6 +40,7 @@ async def get_messages(
limit = 20 limit = 20
username = current_user.username username = current_user.username
with db.transaction() as conn:
root = conn.root() root = conn.root()
if 'messages' not in root: if 'messages' not in root:
@@ -81,11 +82,12 @@ async def get_messages(
@router.get("/messages/{msg_id}") @router.get("/messages/{msg_id}")
async def get_message( async def get_message(
conn: ConnectionDependency, db: DbDependency,
msg_id: str = Path(..., description="UUID of the message (as string)"), msg_id: str = Path(..., description="UUID of the message (as string)"),
mark_retrieved: bool = Query(False, description="If true, mark message as retrieved/read"), mark_retrieved: bool = Query(False, description="If true, mark message as retrieved/read"),
current_user: HttpUser = Depends(get_current_http_user) current_user: HttpUser = Depends(get_current_http_user)
): ):
with db.transaction() as conn:
root = conn.root() root = conn.root()
username = current_user.username username = current_user.username
@@ -126,11 +128,12 @@ async def get_message(
@router.patch("/messages/{msg_id}") @router.patch("/messages/{msg_id}")
async def mark_message_retrieved( async def mark_message_retrieved(
conn: ConnectionDependency, db: DbDependency,
msg_id: str = Path(..., description="Message UUID as string"), msg_id: str = Path(..., description="Message UUID as string"),
payload: MarkRetrievedRequest = None, payload: MarkRetrievedRequest = None,
current_user: HttpUser = Depends(get_current_http_user) current_user: HttpUser = Depends(get_current_http_user)
): ):
with db.transaction() as conn:
root = conn.root() root = conn.root()
username = current_user.username username = current_user.username
@@ -161,7 +164,7 @@ async def mark_message_retrieved(
@html_router.get("/messages", response_class=HTMLResponse) @html_router.get("/messages", response_class=HTMLResponse)
async def message_list_page( async def message_list_page(
conn: ConnectionDependency, db: DbDependency,
request: Request, request: Request,
type: str = Query("received", alias="msg_type"), # matches your filter links type: str = Query("received", alias="msg_type"), # matches your filter links
limit: Optional[int] = Query(50, le=100), limit: Optional[int] = Query(50, le=100),
@@ -169,7 +172,7 @@ async def message_list_page(
): ):
from packetserver.http.server import templates from packetserver.http.server import templates
# Directly call the existing API endpoint function # Directly call the existing API endpoint function
api_resp = await get_messages(conn, current_user=current_user, type=type, limit=limit, since=None) api_resp = await get_messages(db, current_user=current_user, type=type, limit=limit, since=None)
messages = api_resp["messages"] messages = api_resp["messages"]
return templates.TemplateResponse( return templates.TemplateResponse(

View File

@@ -6,7 +6,7 @@ import mimetypes
from packetserver.http.dependencies import get_current_http_user from packetserver.http.dependencies import get_current_http_user
from packetserver.http.auth import HttpUser from packetserver.http.auth import HttpUser
from packetserver.http.database import DbDependency, ConnectionDependency from packetserver.http.database import DbDependency
from packetserver.server.objects import Object from packetserver.server.objects import Object
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -3,21 +3,23 @@ from fastapi import APIRouter, Depends
from packetserver.http.dependencies import get_current_http_user from packetserver.http.dependencies import get_current_http_user
from packetserver.http.auth import HttpUser from packetserver.http.auth import HttpUser
from packetserver.http.database import ConnectionDependency from packetserver.http.database import DbDependency
router = APIRouter(prefix="/api/v1", tags=["auth"]) router = APIRouter(prefix="/api/v1", tags=["auth"])
@router.get("/profile") @router.get("/profile")
async def profile(conn: ConnectionDependency,current_user: HttpUser = Depends(get_current_http_user)): async def profile(db: DbDependency, current_user: HttpUser = Depends(get_current_http_user)):
username = current_user.username username = current_user.username
root = conn.root() rf_enabled = current_user.is_rf_enabled(db)
# Get main BBS User and safe dict # Get main BBS User and safe dict
with db.transaction() as conn:
root = conn.root()
main_users = root.get('users', {}) main_users = root.get('users', {})
bbs_user = main_users.get(username) bbs_user = main_users.get(username)
safe_profile = bbs_user.to_safe_dict() if bbs_user else {} safe_profile = bbs_user.to_safe_dict() if bbs_user else {}
rf_enabled = current_user.is_rf_enabled(conn)
return { return {
**safe_profile, **safe_profile,

View File

@@ -11,7 +11,7 @@ from packetserver.http.dependencies import get_current_http_user
from packetserver.http.auth import HttpUser from packetserver.http.auth import HttpUser
from packetserver.server.messages import Message from packetserver.server.messages import Message
from packetserver.common.util import is_valid_ax25_callsign from packetserver.common.util import is_valid_ax25_callsign
from packetserver.http.database import ConnectionDependency from packetserver.http.database import DbDependency
router = APIRouter(prefix="/api/v1", tags=["messages"]) router = APIRouter(prefix="/api/v1", tags=["messages"])
@@ -40,13 +40,15 @@ class SendMessageRequest(BaseModel):
@router.post("/messages") @router.post("/messages")
async def send_message( async def send_message(
conn: ConnectionDependency, db: DbDependency,
payload: SendMessageRequest, payload: SendMessageRequest,
current_user: HttpUser = Depends(get_current_http_user) current_user: HttpUser = Depends(get_current_http_user)
): ):
is_rf_enabled = current_user.is_rf_enabled(db)
with db.transaction() as conn:
root = conn.root() root = conn.root()
if not current_user.is_rf_enabled(conn): if not is_rf_enabled:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="RF gateway access required to send messages" detail="RF gateway access required to send messages"