Adding db/env/files support to jobs.

This commit is contained in:
Michael Woods
2025-02-15 19:40:58 -05:00
parent 3ce312e3ff
commit 3ba97dd09f
2 changed files with 41 additions and 9 deletions

View File

@@ -75,11 +75,18 @@ class JobWrapper:
def __repr__(self): def __repr__(self):
return f"<Job {self.id} - {self.owner} - {self.status}>" return f"<Job {self.id} - {self.owner} - {self.status}>"
def send_job(client: Client, bbs_callsign: str, cmd: Union[str, list]) -> int: def send_job(client: Client, bbs_callsign: str, cmd: Union[str, list], db: bool = False, env: dict = None,
files: dict = None) -> int:
"""Send a job using client to bbs_callsign with args cmd. Return remote job_id.""" """Send a job using client to bbs_callsign with args cmd. Return remote job_id."""
req = Request.blank() req = Request.blank()
req.path = "job" req.path = "job"
req.payload = {'cmd': cmd} req.payload = {'cmd': cmd}
if db:
req.payload['db'] = ''
if env is not None:
req.payload['env']= env
if files is not None:
req.payload['files'] = files
req.method = Request.Method.POST req.method = Request.Method.POST
response = client.send_receive_callsign(req, bbs_callsign) response = client.send_receive_callsign(req, bbs_callsign)
if response.status_code != 201: if response.status_code != 201:
@@ -96,7 +103,7 @@ def get_job_id(client: Client, bbs_callsign: str, job_id: int, get_data=True) ->
return JobWrapper(response.payload) return JobWrapper(response.payload)
class JobSession: class JobSession:
def __init__(self, client: Client, bbs_callsign: str, default_timeout: int = 300, stutter: int = 3): def __init__(self, client: Client, bbs_callsign: str, default_timeout: int = 300, stutter: int = 1):
self.client = client self.client = client
self.bbs = bbs_callsign self.bbs = bbs_callsign
self.timeout = default_timeout self.timeout = default_timeout
@@ -105,14 +112,14 @@ class JobSession:
def connect(self) -> PacketServerConnection: def connect(self) -> PacketServerConnection:
return self.client.new_connection(self.bbs) return self.client.new_connection(self.bbs)
def send(self, cmd: Union[str, list]) -> int: def send(self, cmd: Union[str, list], db: bool = False, env: dict = None, files: dict = None) -> int:
return send_job(self.client, self.bbs, cmd) return send_job(self.client, self.bbs, cmd, db=db, env=env, files=files)
def get_id(self, jid: int) -> JobWrapper: def get_id(self, jid: int) -> JobWrapper:
return get_job_id(self.client, self.bbs, jid) return get_job_id(self.client, self.bbs, jid)
def run_job(self, cmd: Union[str, list]) -> JobWrapper: def run_job(self, cmd: Union[str, list], db: bool = False, env: dict = None, files: dict = None) -> JobWrapper:
jid = self.send(cmd) jid = self.send(cmd, db=db, env=env, files=files)
time.sleep(self.stutter) time.sleep(self.stutter)
j = self.get_id(jid) j = self.get_id(jid)
while not j.is_finished: while not j.is_finished:

View File

@@ -9,6 +9,7 @@ from typing import Self,Union,Optional,Tuple
from traceback import format_exc from traceback import format_exc
from packetserver.common import PacketServerConnection, Request, Response, Message, send_response, send_blank_response from packetserver.common import PacketServerConnection, Request, Response, Message, send_response, send_blank_response
from packetserver.common.constants import no_values from packetserver.common.constants import no_values
from packetserver.server.db import get_user_db_json
import ZODB import ZODB
from persistent.list import PersistentList from persistent.list import PersistentList
import logging import logging
@@ -17,7 +18,7 @@ import gzip
import tarfile import tarfile
import json import json
from packetserver.runner.podman import TarFileExtractor, PodmanOrchestrator, PodmanRunner, PodmanOptions from packetserver.runner.podman import TarFileExtractor, PodmanOrchestrator, PodmanRunner, PodmanOptions
from packetserver.runner import Orchestrator, Runner, RunnerStatus from packetserver.runner import Orchestrator, Runner, RunnerStatus, RunnerFile
from enum import Enum from enum import Enum
from io import BytesIO from io import BytesIO
import base64 import base64
@@ -103,11 +104,19 @@ class Job(persistent.Persistent):
def get_next_queued_job(cls, db_root: PersistentMapping) -> Self: def get_next_queued_job(cls, db_root: PersistentMapping) -> Self:
return db_root['job_queue'][0] return db_root['job_queue'][0]
def __init__(self, cmd: Union[list[str], str], owner: Optional[str] = None, timeout: int = 300): def __init__(self, cmd: Union[list[str], str], owner: Optional[str] = None, timeout: int = 300,
env: dict = None, files: list[RunnerFile] = None):
self.owner = None self.owner = None
if owner is not None: if owner is not None:
self.owner = str(owner).upper().strip() self.owner = str(owner).upper().strip()
self.cmd = cmd self.cmd = cmd
self.env = {}
if env is not None:
for key in env:
self.env[key] = env[key]
self.files = []
if files is not None:
self.files = files
self.created_at = datetime.datetime.now(datetime.UTC) self.created_at = datetime.datetime.now(datetime.UTC)
self.started_at = None self.started_at = None
self.finished_at = None self.finished_at = None
@@ -254,7 +263,23 @@ def handle_new_job_post(req: Request, conn: PacketServerConnection, db: ZODB.DB)
if type(req.payload['cmd']) not in [str, list]: if type(req.payload['cmd']) not in [str, list]:
send_blank_response(conn, req, 401, "job post must contain cmd key containing str or list[str]") send_blank_response(conn, req, 401, "job post must contain cmd key containing str or list[str]")
return return
job = Job(req.payload['cmd'], owner=username) files = []
if 'db' in req.payload:
logging.debug(f"Fetching a user db as requested.")
dbf = RunnerFile('user-db.json.gz', data=get_user_db_json(username.lower(), db))
files.append(dbf)
if 'files' in req.payload:
if type(files) is dict:
for key in req.payload['files']:
val = req.payload['files'][key]
if type(val) is bytes:
files.append(RunnerFile(key, data=val))
env = {}
if 'env' in req.payload:
if type(req.payload['env']) is dict:
for key in req.payload['env']:
env[key] = req.payload[key]
job = Job(req.payload['cmd'], owner=username, env=env, files=files)
with db.transaction() as storage: with db.transaction() as storage:
try: try:
new_jid = job.queue(storage.root()) new_jid = job.queue(storage.root())