Redoing all the database logic entirely. I didn't understand FastAPI going into this and let grok do weird stuff. We did weird stuff together, really.
This commit is contained in:
@@ -7,6 +7,7 @@ import time
|
||||
from persistent.mapping import PersistentMapping
|
||||
from persistent.list import PersistentList
|
||||
from packetserver.common.util import is_valid_ax25_callsign
|
||||
from .database import get_db, get_transaction
|
||||
|
||||
ph = PasswordHasher()
|
||||
|
||||
@@ -50,14 +51,15 @@ class HttpUser(Persistent):
|
||||
# rf enabled checks..
|
||||
#
|
||||
|
||||
def is_rf_enabled(self, connection) -> bool:
|
||||
def is_rf_enabled(self) -> bool:
|
||||
"""
|
||||
Check if RF gateway is enabled (i.e., callsign NOT in global blacklist).
|
||||
Requires an open ZODB connection.
|
||||
"""
|
||||
root = connection.root()
|
||||
blacklist = root.get('config', {}).get('blacklist', [])
|
||||
return self.username not in blacklist
|
||||
with get_transaction() as storage:
|
||||
root = storage.root()
|
||||
blacklist = root.get('config', {}).get('blacklist', [])
|
||||
return self.username not in blacklist
|
||||
|
||||
def set_rf_enabled(self, connection, allow: bool):
|
||||
"""
|
||||
@@ -67,26 +69,26 @@ class HttpUser(Persistent):
|
||||
"""
|
||||
from packetserver.common.util import is_valid_ax25_callsign # our validator
|
||||
|
||||
root = connection.root()
|
||||
config = root.setdefault('config', PersistentMapping())
|
||||
blacklist = config.setdefault('blacklist', PersistentList())
|
||||
with get_transaction() as storage:
|
||||
root = storage.root()
|
||||
config = root.setdefault('config', PersistentMapping())
|
||||
blacklist = config.setdefault('blacklist', PersistentList())
|
||||
|
||||
upper_name = self.username
|
||||
upper_name = self.username
|
||||
|
||||
if allow:
|
||||
if not is_valid_ax25_callsign(upper_name):
|
||||
raise ValueError(f"{upper_name} is not a valid AX.25 callsign – cannot enable RF access")
|
||||
if upper_name in blacklist:
|
||||
blacklist.remove(upper_name)
|
||||
blacklist._p_changed = True
|
||||
else:
|
||||
if upper_name not in blacklist:
|
||||
blacklist.append(upper_name)
|
||||
blacklist._p_changed = True
|
||||
if allow:
|
||||
if not is_valid_ax25_callsign(upper_name):
|
||||
raise ValueError(f"{upper_name} is not a valid AX.25 callsign – cannot enable RF access")
|
||||
if upper_name in blacklist:
|
||||
blacklist.remove(upper_name)
|
||||
blacklist._p_changed = True
|
||||
else:
|
||||
if upper_name not in blacklist:
|
||||
blacklist.append(upper_name)
|
||||
blacklist._p_changed = True
|
||||
|
||||
config._p_changed = True
|
||||
root._p_changed = True
|
||||
# Caller should commit the transaction
|
||||
config._p_changed = True
|
||||
root._p_changed = True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Password handling (unchanged)
|
||||
|
||||
18
packetserver/http/config.py
Normal file
18
packetserver/http/config.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""
|
||||
Application settings loaded from environment variables and .env files.
|
||||
"""
|
||||
# Define your settings fields with type hints and optional default values
|
||||
name: str = "PacketServer"
|
||||
zeo_file: str
|
||||
operator: str | None = None
|
||||
debug_mode: bool = False
|
||||
log_level: str = "info"
|
||||
|
||||
# Configure how settings are loaded
|
||||
model_config = SettingsConfigDict(
|
||||
case_sensitive=False, # Make environment variable names case-sensitive
|
||||
env_prefix="PS_APP_" # Use a prefix for environment variables (e.g., MY_APP_DATABASE_URL)
|
||||
)
|
||||
@@ -0,0 +1,33 @@
|
||||
from .config import Settings
|
||||
from os.path import isfile
|
||||
import ZEO
|
||||
import ZODB
|
||||
from fastapi import Depends
|
||||
from typing import Annotated, ContextManager
|
||||
|
||||
settings = Settings()
|
||||
|
||||
def get_zeo_address(zeo_address_file: str) -> tuple[str,int]:
|
||||
if not isfile(zeo_address_file):
|
||||
raise FileNotFoundError(f"ZEO address file is not a file: '{zeo_address_file}'")
|
||||
|
||||
contents = open(zeo_address_file, 'r').read().strip().split(":")
|
||||
|
||||
if len(contents) != 2:
|
||||
raise ValueError(f"Invalid ZEO address file: {zeo_address_file}")
|
||||
|
||||
host = contents[0]
|
||||
try:
|
||||
port = int(contents[1])
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid ZEO address file: {zeo_address_file}")
|
||||
return host,port
|
||||
|
||||
def get_db() -> ZODB.DB:
|
||||
return ZEO.DB(get_zeo_address(settings.zeo_file))
|
||||
|
||||
def get_transaction() -> ContextManager:
|
||||
return ZEO.DB(get_zeo_address(settings.zeo_file)).transaction()
|
||||
|
||||
DbDependency = Annotated[ZODB.DB, Depends(get_db)]
|
||||
TransactionDependency = Annotated[ContextManager, Depends(get_transaction)]
|
||||
|
||||
@@ -3,7 +3,7 @@ from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
|
||||
from .auth import HttpUser
|
||||
|
||||
from .database import get_transaction
|
||||
|
||||
security = HTTPBasic()
|
||||
|
||||
@@ -13,41 +13,40 @@ async def get_current_http_user(credentials: HTTPBasicCredentials = Depends(secu
|
||||
Authenticate via Basic Auth using HttpUser from ZODB.
|
||||
Injected by the standalone runner (get_db_connection available).
|
||||
"""
|
||||
from packetserver.runners.http_server import get_db_connection # provided by runner
|
||||
|
||||
conn = get_db_connection()
|
||||
root = conn.root()
|
||||
with get_transaction() as conn:
|
||||
root = conn.root()
|
||||
|
||||
http_users = root.get("httpUsers")
|
||||
if http_users is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
http_users = root.get("httpUsers")
|
||||
if http_users is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
|
||||
user: HttpUser | None = http_users.get(credentials.username.upper())
|
||||
user: HttpUser | None = http_users.get(credentials.username.upper())
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
|
||||
if not user.http_enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="HTTP access disabled for this user",
|
||||
)
|
||||
if not user.http_enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="HTTP access disabled for this user",
|
||||
)
|
||||
|
||||
if not user.verify_password(credentials.password):
|
||||
user.record_login_failure()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
if not user.verify_password(credentials.password):
|
||||
user.record_login_failure()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
|
||||
user.record_login_success()
|
||||
return user
|
||||
user.record_login_success()
|
||||
return user
|
||||
@@ -6,6 +6,7 @@ from datetime import datetime
|
||||
import transaction
|
||||
from persistent.list import PersistentList
|
||||
|
||||
from ..database import DbDependency, TransactionDependency
|
||||
from ..dependencies import get_current_http_user
|
||||
from ..auth import HttpUser
|
||||
from ..server import templates
|
||||
@@ -20,31 +21,30 @@ html_router = APIRouter(tags=["bulletins-html"])
|
||||
|
||||
# --- API Endpoints ---
|
||||
|
||||
async def list_bulletins(limit: int = 50, since: Optional[datetime] = None) -> dict:
|
||||
from packetserver.runners.http_server import get_db_connection
|
||||
conn = get_db_connection()
|
||||
root = conn.root()
|
||||
bulletins_list: List[Bulletin] = root.get("bulletins", [])
|
||||
async def list_bulletins(, limit: int = 50, since: Optional[datetime] = None) -> dict:
|
||||
with trans as conn:
|
||||
root = conn.root()
|
||||
bulletins_list: List[Bulletin] = root.get("bulletins", [])
|
||||
|
||||
# Newest first
|
||||
bulletins_list = sorted(bulletins_list, key=lambda b: b.created_at, reverse=True)
|
||||
# Newest first
|
||||
bulletins_list = sorted(bulletins_list, key=lambda b: b.created_at, reverse=True)
|
||||
|
||||
if since:
|
||||
bulletins_list = [b for b in bulletins_list if b.created_at > since]
|
||||
if since:
|
||||
bulletins_list = [b for b in bulletins_list if b.created_at > since]
|
||||
|
||||
bulletins = [
|
||||
{
|
||||
"id": b.id,
|
||||
"author": b.author,
|
||||
"subject": b.subject,
|
||||
"body": b.body,
|
||||
"created_at": b.created_at.isoformat() + "Z",
|
||||
"updated_at": b.updated_at.isoformat() + "Z",
|
||||
}
|
||||
for b in bulletins_list[:limit]
|
||||
]
|
||||
bulletins = [
|
||||
{
|
||||
"id": b.id,
|
||||
"author": b.author,
|
||||
"subject": b.subject,
|
||||
"body": b.body,
|
||||
"created_at": b.created_at.isoformat() + "Z",
|
||||
"updated_at": b.updated_at.isoformat() + "Z",
|
||||
}
|
||||
for b in bulletins_list[:limit]
|
||||
]
|
||||
|
||||
return {"bulletins": bulletins}
|
||||
return {"bulletins": bulletins}
|
||||
|
||||
@router.get("/bulletins")
|
||||
async def api_list_bulletins(
|
||||
@@ -55,22 +55,21 @@ async def api_list_bulletins(
|
||||
return await list_bulletins(limit=limit, since=since)
|
||||
|
||||
async def get_one_bulletin(bid: int) -> dict:
|
||||
from packetserver.runners.http_server import get_db_connection
|
||||
conn = get_db_connection()
|
||||
root = conn.root()
|
||||
bulletins_list: List[Bulletin] = root.get("bulletins", [])
|
||||
with get_transaction() as conn:
|
||||
root = conn.root()
|
||||
bulletins_list: List[Bulletin] = root.get("bulletins", [])
|
||||
|
||||
for b in bulletins_list:
|
||||
if b.id == bid:
|
||||
return {
|
||||
"id": b.id,
|
||||
"author": b.author,
|
||||
"subject": b.subject,
|
||||
"body": b.body,
|
||||
"created_at": b.created_at.isoformat() + "Z",
|
||||
"updated_at": b.updated_at.isoformat() + "Z",
|
||||
}
|
||||
raise HTTPException(status_code=404, detail="Bulletin not found")
|
||||
for b in bulletins_list:
|
||||
if b.id == bid:
|
||||
return {
|
||||
"id": b.id,
|
||||
"author": b.author,
|
||||
"subject": b.subject,
|
||||
"body": b.body,
|
||||
"created_at": b.created_at.isoformat() + "Z",
|
||||
"updated_at": b.updated_at.isoformat() + "Z",
|
||||
}
|
||||
raise HTTPException(status_code=404, detail="Bulletin not found")
|
||||
|
||||
@router.get("/bulletins/{bid}")
|
||||
async def api_get_bulletin(
|
||||
|
||||
@@ -18,50 +18,6 @@ import ZODB.DB
|
||||
import logging
|
||||
from packetserver.http.server import app
|
||||
|
||||
# Global DB and connection for reuse in the FastAPI dependency
|
||||
_db = None
|
||||
_connection = None
|
||||
|
||||
|
||||
def open_database(db_arg: str) -> ZODB.DB:
|
||||
"""
|
||||
Open a ZODB database from either a local FileStorage path or ZEO address.
|
||||
"""
|
||||
if ":" in db_arg:
|
||||
parts = db_arg.split(":")
|
||||
if len(parts) == 2 and parts[1].isdigit():
|
||||
import ZEO
|
||||
host = parts[0]
|
||||
port = int(parts[1])
|
||||
storage = ZEO.client((host, port)) # correct modern ZEO client function
|
||||
return ZODB.DB(storage)
|
||||
|
||||
# Local FileStorage fallback
|
||||
storage = ZODB.FileStorage.FileStorage(db_arg)
|
||||
return ZODB.DB(storage)
|
||||
|
||||
|
||||
def get_db_connection():
|
||||
"""Helper used in http/server.py dependency (get_current_http_user)"""
|
||||
global _connection
|
||||
if _connection is None or getattr(_connection, "opened", None) is None:
|
||||
if _db is None:
|
||||
raise RuntimeError("Database not opened – run the runner properly")
|
||||
_connection = _db.open()
|
||||
return _connection
|
||||
|
||||
def get_db():
|
||||
"""Helper used in http/server.py dependency (get_current_http_user)"""
|
||||
if _db is None:
|
||||
raise RuntimeError("Database not opened – run the runner properly")
|
||||
return _db
|
||||
|
||||
|
||||
# Monkey-patch the dependency helper so server.py can use it without changes
|
||||
from packetserver.http import server
|
||||
server.get_db_connection = get_db_connection # replaces any previous definition
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run the PacketServer HTTP API server")
|
||||
parser.add_argument("--db", required=True, help="DB path (local /path/to/Data.fs) or ZEO (host:port)")
|
||||
@@ -71,30 +27,12 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
global _db
|
||||
try:
|
||||
_db = open_database(args.db)
|
||||
print(f"Opened database: {args.db}")
|
||||
except Exception as e:
|
||||
print(f"Failed to open database {args.db}: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Open initial connection (will be reused/closed on shutdown)
|
||||
get_db_connection()
|
||||
|
||||
try:
|
||||
uvicorn.run(
|
||||
"packetserver.http.server:app",
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
reload=args.reload,
|
||||
)
|
||||
finally:
|
||||
if _connection and not _connection.closed:
|
||||
_connection.close()
|
||||
if _db:
|
||||
_db.close()
|
||||
|
||||
uvicorn.run(
|
||||
"packetserver.http.server:app",
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
reload=args.reload,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -14,3 +14,4 @@ jinja2
|
||||
python-multipart
|
||||
argon2-cffi
|
||||
pydantic
|
||||
pydantic_settings
|
||||
Reference in New Issue
Block a user