diff --git a/packetserver/http/database.py b/packetserver/http/database.py index 6b0bf9b..a6a7590 100644 --- a/packetserver/http/database.py +++ b/packetserver/http/database.py @@ -37,8 +37,7 @@ def init_db() -> ZODB.DB: return _db host, port = _get_zeo_address(settings.zeo_file) - storage = ZEO.ClientStorage((host, port)) - _db = ZODB.DB(storage) + _db = ZEO.DB((host, port)) return _db def get_db() -> ZODB.DB: @@ -54,7 +53,9 @@ def get_connection() -> Generator[Connection, None, None]: try: yield conn finally: - conn.close() + #print("not closing connection") + #conn.close() + pass # Optional: per-request transaction (if you want automatic commit/abort) def get_transaction_manager(): diff --git a/packetserver/http/dependencies.py b/packetserver/http/dependencies.py index d5be69d..d3670ab 100644 --- a/packetserver/http/dependencies.py +++ b/packetserver/http/dependencies.py @@ -3,50 +3,49 @@ from fastapi import Depends, HTTPException, status from fastapi.security import HTTPBasic, HTTPBasicCredentials from .auth import HttpUser -from .database import get_transaction +from .database import ConnectionDependency security = HTTPBasic() -async def get_current_http_user(credentials: HTTPBasicCredentials = Depends(security)): +async def get_current_http_user(conn: ConnectionDependency, credentials: HTTPBasicCredentials = Depends(security)): """ Authenticate via Basic Auth using HttpUser from ZODB. Injected by the standalone runner (get_db_connection available). """ - with get_transaction() as conn: - root = conn.root() + root = conn.root() - http_users = root.get("httpUsers") - if http_users is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid username or password", - headers={"WWW-Authenticate": "Basic"}, - ) + http_users = root.get("httpUsers") + if http_users is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid username or password", + headers={"WWW-Authenticate": "Basic"}, + ) - user: HttpUser | None = http_users.get(credentials.username.upper()) + user: HttpUser | None = http_users.get(credentials.username.upper()) - if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid username or password", - headers={"WWW-Authenticate": "Basic"}, - ) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid username or password", + headers={"WWW-Authenticate": "Basic"}, + ) - if not user.http_enabled: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="HTTP access disabled for this user", - ) + if not user.http_enabled: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="HTTP access disabled for this user", + ) - if not user.verify_password(credentials.password): - user.record_login_failure() - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid username or password", - headers={"WWW-Authenticate": "Basic"}, - ) + if not user.verify_password(credentials.password): + user.record_login_failure() + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid username or password", + headers={"WWW-Authenticate": "Basic"}, + ) - user.record_login_success() - return user \ No newline at end of file + user.record_login_success() + return user \ No newline at end of file diff --git a/packetserver/http/routers/dashboard.py b/packetserver/http/routers/dashboard.py index bc60540..e351f88 100644 --- a/packetserver/http/routers/dashboard.py +++ b/packetserver/http/routers/dashboard.py @@ -5,6 +5,7 @@ from fastapi.responses import HTMLResponse from packetserver.http.dependencies import get_current_http_user from packetserver.http.auth import HttpUser from packetserver.http.server import templates +from packetserver.http.database import ConnectionDependency router = APIRouter(tags=["dashboard"]) @@ -15,11 +16,13 @@ from .bulletins import list_bulletins @router.get("/dashboard", response_class=HTMLResponse) async def dashboard( + conn: ConnectionDependency, request: Request, current_user: HttpUser = Depends(get_current_http_user) ): # Internal call – pass explicit defaults to avoid Query object injection messages_resp = await api_get_messages( + conn, current_user=current_user, type="all", limit=100, @@ -27,7 +30,7 @@ async def dashboard( ) messages = messages_resp["messages"] - bulletins_resp = await list_bulletins(limit=10, since=None) + bulletins_resp = await list_bulletins(conn, limit=10, since=None) recent_bulletins = bulletins_resp["bulletins"] return templates.TemplateResponse( @@ -42,11 +45,12 @@ async def dashboard( @router.get("/dashboard/profile", response_class=HTMLResponse) async def profile_page( + conn: ConnectionDependency, request: Request, current_user: HttpUser = Depends(get_current_http_user) ): from packetserver.http.routers.profile import profile as api_profile - profile_data = await api_profile(current_user=current_user) + profile_data = await api_profile(conn, current_user=current_user) return templates.TemplateResponse( "profile.html", diff --git a/packetserver/http/routers/messages.py b/packetserver/http/routers/messages.py index e1cdd2e..6a855bc 100644 --- a/packetserver/http/routers/messages.py +++ b/packetserver/http/routers/messages.py @@ -161,6 +161,7 @@ async def mark_message_retrieved( @html_router.get("/messages", response_class=HTMLResponse) async def message_list_page( + conn: ConnectionDependency, request: Request, type: str = Query("received", alias="msg_type"), # matches your filter links limit: Optional[int] = Query(50, le=100), @@ -168,7 +169,7 @@ async def message_list_page( ): from packetserver.http.server import templates # Directly call the existing API endpoint function - api_resp = await get_messages(current_user=current_user, type=type, limit=limit, since=None) + api_resp = await get_messages(conn, current_user=current_user, type=type, limit=limit, since=None) messages = api_resp["messages"] return templates.TemplateResponse( diff --git a/packetserver/http/routers/objects.py b/packetserver/http/routers/objects.py index c17c3bb..8174e77 100644 --- a/packetserver/http/routers/objects.py +++ b/packetserver/http/routers/objects.py @@ -6,6 +6,7 @@ import mimetypes from packetserver.http.dependencies import get_current_http_user from packetserver.http.auth import HttpUser +from packetserver.http.database import DbDependency, ConnectionDependency from packetserver.server.objects import Object from pydantic import BaseModel @@ -22,15 +23,11 @@ class ObjectSummary(BaseModel): modified_at: datetime @router.get("/objects", response_model=List[ObjectSummary]) -async def list_my_objects(current_user: HttpUser = Depends(get_current_http_user)): - from packetserver.runners.http_server import get_db_connection - - conn = get_db_connection() - root = conn.root() +async def list_my_objects(db: DbDependency, current_user: HttpUser = Depends(get_current_http_user)): username = current_user.username # uppercase callsign - core_objects = Object.get_objects_by_username(username, root) + core_objects = Object.get_objects_by_username(username, db) # Sort newest first by created_at core_objects.sort(key=lambda o: o.created_at, reverse=True) diff --git a/packetserver/http/server.py b/packetserver/http/server.py index 1fc8687..2eecf5b 100644 --- a/packetserver/http/server.py +++ b/packetserver/http/server.py @@ -4,6 +4,7 @@ from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from pathlib import Path +from .database import init_db from .routers import public, profile, messages, send BASE_DIR = Path(__file__).parent.resolve() @@ -40,6 +41,9 @@ from .routers.message_detail import router as message_detail_router from .routers.messages import html_router from .routers.objects import router as objects_router +# initialize database +init_db() + # Include routers app.include_router(public.router) app.include_router(profile.router) diff --git a/packetserver/runners/http_server.py b/packetserver/runners/http_server.py index 6cc2ccb..15e0388 100644 --- a/packetserver/runners/http_server.py +++ b/packetserver/runners/http_server.py @@ -13,14 +13,10 @@ import argparse import sys import uvicorn -import ZODB.FileStorage -import ZODB.DB -import logging from packetserver.http.server import app def main(): parser = argparse.ArgumentParser(description="Run the PacketServer HTTP API server") - parser.add_argument("--db", required=True, help="DB path (local /path/to/Data.fs) or ZEO (host:port)") parser.add_argument("--host", default="0.0.0.0", help="Bind host (default: 0.0.0.0)") parser.add_argument("--port", type=int, default=8080, help="Port to listen on (default: 8080)") parser.add_argument("--reload", action="store_true", help="Enable auto-reload during development")