diff --git a/src/packetserver/client/jobs.py b/src/packetserver/client/jobs.py index 45ce421..44508ce 100644 --- a/src/packetserver/client/jobs.py +++ b/src/packetserver/client/jobs.py @@ -75,11 +75,18 @@ class JobWrapper: def __repr__(self): return f"" -def send_job(client: Client, bbs_callsign: str, cmd: Union[str, list]) -> int: +def send_job(client: Client, bbs_callsign: str, cmd: Union[str, list], db: bool = False, env: dict = None, + files: dict = None) -> int: """Send a job using client to bbs_callsign with args cmd. Return remote job_id.""" req = Request.blank() req.path = "job" req.payload = {'cmd': cmd} + if db: + req.payload['db'] = '' + if env is not None: + req.payload['env']= env + if files is not None: + req.payload['files'] = files req.method = Request.Method.POST response = client.send_receive_callsign(req, bbs_callsign) if response.status_code != 201: @@ -96,7 +103,7 @@ def get_job_id(client: Client, bbs_callsign: str, job_id: int, get_data=True) -> return JobWrapper(response.payload) class JobSession: - def __init__(self, client: Client, bbs_callsign: str, default_timeout: int = 300, stutter: int = 3): + def __init__(self, client: Client, bbs_callsign: str, default_timeout: int = 300, stutter: int = 1): self.client = client self.bbs = bbs_callsign self.timeout = default_timeout @@ -105,14 +112,14 @@ class JobSession: def connect(self) -> PacketServerConnection: return self.client.new_connection(self.bbs) - def send(self, cmd: Union[str, list]) -> int: - return send_job(self.client, self.bbs, cmd) + def send(self, cmd: Union[str, list], db: bool = False, env: dict = None, files: dict = None) -> int: + return send_job(self.client, self.bbs, cmd, db=db, env=env, files=files) def get_id(self, jid: int) -> JobWrapper: return get_job_id(self.client, self.bbs, jid) - def run_job(self, cmd: Union[str, list]) -> JobWrapper: - jid = self.send(cmd) + def run_job(self, cmd: Union[str, list], db: bool = False, env: dict = None, files: dict = None) -> JobWrapper: + jid = self.send(cmd, db=db, env=env, files=files) time.sleep(self.stutter) j = self.get_id(jid) while not j.is_finished: diff --git a/src/packetserver/server/jobs.py b/src/packetserver/server/jobs.py index 839e1bd..891a9f8 100644 --- a/src/packetserver/server/jobs.py +++ b/src/packetserver/server/jobs.py @@ -9,6 +9,7 @@ from typing import Self,Union,Optional,Tuple from traceback import format_exc from packetserver.common import PacketServerConnection, Request, Response, Message, send_response, send_blank_response from packetserver.common.constants import no_values +from packetserver.server.db import get_user_db_json import ZODB from persistent.list import PersistentList import logging @@ -17,7 +18,7 @@ import gzip import tarfile import json from packetserver.runner.podman import TarFileExtractor, PodmanOrchestrator, PodmanRunner, PodmanOptions -from packetserver.runner import Orchestrator, Runner, RunnerStatus +from packetserver.runner import Orchestrator, Runner, RunnerStatus, RunnerFile from enum import Enum from io import BytesIO import base64 @@ -103,11 +104,19 @@ class Job(persistent.Persistent): def get_next_queued_job(cls, db_root: PersistentMapping) -> Self: return db_root['job_queue'][0] - def __init__(self, cmd: Union[list[str], str], owner: Optional[str] = None, timeout: int = 300): + def __init__(self, cmd: Union[list[str], str], owner: Optional[str] = None, timeout: int = 300, + env: dict = None, files: list[RunnerFile] = None): self.owner = None if owner is not None: self.owner = str(owner).upper().strip() self.cmd = cmd + self.env = {} + if env is not None: + for key in env: + self.env[key] = env[key] + self.files = [] + if files is not None: + self.files = files self.created_at = datetime.datetime.now(datetime.UTC) self.started_at = None self.finished_at = None @@ -254,7 +263,23 @@ def handle_new_job_post(req: Request, conn: PacketServerConnection, db: ZODB.DB) if type(req.payload['cmd']) not in [str, list]: send_blank_response(conn, req, 401, "job post must contain cmd key containing str or list[str]") return - job = Job(req.payload['cmd'], owner=username) + files = [] + if 'db' in req.payload: + logging.debug(f"Fetching a user db as requested.") + dbf = RunnerFile('user-db.json.gz', data=get_user_db_json(username.lower(), db)) + files.append(dbf) + if 'files' in req.payload: + if type(files) is dict: + for key in req.payload['files']: + val = req.payload['files'][key] + if type(val) is bytes: + files.append(RunnerFile(key, data=val)) + env = {} + if 'env' in req.payload: + if type(req.payload['env']) is dict: + for key in req.payload['env']: + env[key] = req.payload[key] + job = Job(req.payload['cmd'], owner=username, env=env, files=files) with db.transaction() as storage: try: new_jid = job.queue(storage.root())