diff --git a/src/packetserver/common/util.py b/src/packetserver/common/util.py index 80d4d6b..14009c3 100644 --- a/src/packetserver/common/util.py +++ b/src/packetserver/common/util.py @@ -1,5 +1,12 @@ import re import datetime +import tempfile +import tarfile +from typing import Union, Iterable, Tuple, Optional +import os.path +from io import BytesIO +import random +import string def email_valid(email: str) -> bool: """Taken from https://www.geeksforgeeks.org/check-if-email-address-valid-or-not-in-python/""" @@ -39,4 +46,50 @@ def from_date_digits(index: str) -> datetime: if len(ind) >= 14: second = int(ind[12:14]) - return datetime.datetime(year, month, day ,hour, minute, second) \ No newline at end of file + return datetime.datetime(year, month, day ,hour, minute, second) + +def tar_bytes(file: Union[str, Iterable]) -> bytes: + """Creates a tar archive in a temporary file with the specified files at root level. + Returns the bytes of the archive.""" + files = [] + if type(file) is str: + files.append(file) + else: + for i in file: + files.append(str(i)) + + with tempfile.TemporaryFile() as temp: + tar_obj = tarfile.TarFile(fileobj=temp, mode="w") + for i in files: + tar_obj.add(i, arcname=os.path.basename(i)) + tar_obj.close() + temp.seek(0) + return temp.read() + +def bytes_to_tar_bytes(name: str, data: bytes) -> bytes: + """Creates a tar archive with a single file of name with bytes as the contents""" + with tempfile.TemporaryFile() as temp: + tar_obj = tarfile.TarFile(fileobj=temp, mode="w") + bio = BytesIO(data) + tar_info = tarfile.TarInfo(name=name) + tar_info.size = len(data) + tar_obj.addfile(tar_info, bio) + tar_obj.close() + temp.seek(0) + return temp.read() + +def extract_tar_bytes(tarfile_bytes: bytes) -> Tuple[str, bytes]: + """Takes the bytes of a tarfile, and returns the name and bytes of the first file in the archive.""" + out_bytes = b'' + bio = BytesIO(tarfile_bytes) + tar_obj = tarfile.TarFile(fileobj=bio, mode="r") + members = tar_obj.getmembers() + for i in range(0, len(members)): + if members[i].isfile(): + return members[i].name, tar_obj.extractfile(members[i]).read() + raise FileNotFoundError("No files found to extract from archive") + +def random_string(length=8) -> str: + rand_str = ''.join(random.choices(string.ascii_letters + string.digits, k=length)) + return rand_str + diff --git a/src/packetserver/runner/podman.py b/src/packetserver/runner/podman.py index 6440d32..6a269d0 100644 --- a/src/packetserver/runner/podman.py +++ b/src/packetserver/runner/podman.py @@ -1,14 +1,19 @@ """Uses podman to run jobs in containers.""" +import time + from . import Runner, Orchestrator, RunnerStatus from collections import namedtuple from typing import Optional, Iterable import subprocess import podman +import podman.errors import os import os.path import logging import ZODB import datetime +from os.path import basename, dirname +from packetserver.common.util import bytes_to_tar_bytes, random_string PodmanOptions = namedtuple("PodmanOptions", ["default_timeout", "max_timeout", "image_name", "max_active_jobs", "container_keepalive", "name_prefix"]) @@ -49,19 +54,67 @@ class PodmanOrchestrator(Orchestrator): def client(self): return podman.PodmanClient(base_url=self.uri) - def refresh_user_db(self, username: str, db: ZODB.DB): + def add_file_to_user_container(self, username: str, data: bytes, path: str): + pass + + def get_file_from_user_container(self, username: str, path: str) -> bytes : pass def podman_start_user_container(self, username: str): - pass + con = self.client.containers.create(self.opts.image_name, name=self.get_container_name(username), + command=["tail", "-f", "/dev/null"]) + con.start() + started_at = datetime.datetime.now() + while con.inspect()['State']['Status'] not in ['exited', 'running']: + now = datetime.datetime.now() + if (now - started_at).total_seconds() > 300: + con.stop() + con.remove() + raise RuntimeError(f"Couldn't start container for user {username}") + time.sleep(.1) + time.sleep(.5) + if con.inspect()['State']['Status'] != 'running': + con.stop() + con.remove() + raise RuntimeError(f"Couldn't start container for user {username}") - def podman_stop_user_container + def podman_remove_container_name(self, container_name: str): + cli = self.client + logging.debug(f"Attempting to remove container named {container_name}") + try: + con = cli.containers.get(container_name) + except podman.errors.exceptions.NotFound as e: + return + try: + con.rename(f"{container_name}_{random_string()}") + except: + pass + if con.inspect()['State']['Status'] != 'exited': + try: + con.stop() + except: + pass + try: + con.remove() + except: + pass + return - def podman_container_exists(self, container_name: str) -> bool: - return False + def podman_stop_user_container(self, username: str): + self.podman_remove_container_name(self.get_container_name(username)) + + def podman_user_container_exists(self, username: str) -> bool: + try: + self.client.containers.get(self.get_container_name(username)) + return True + except podman.errors.exceptions.NotFound: + return False def clean_orphaned_containers(self): - pass + cli = self.client + for i in cli.containers.list(all=True): + if self.opts.name_prefix in str(i.name): + self.podman_remove_container_name(str(i.name)) def get_container_name(self, username: str) -> str: return self.opts.name_prefix + username.lower().strip() @@ -78,7 +131,7 @@ class PodmanOrchestrator(Orchestrator): """Checks running containers and stops them if they have been running too long.""" for c in self.user_containers: if (datetime.datetime.now() - self.user_containers[c]) > self.opts.container_keepalive: - # stop the container TODO + self.podman_remove_container_name(c) del self.user_containers[c] def runners_in_process(self) -> int: @@ -99,6 +152,8 @@ class PodmanOrchestrator(Orchestrator): def new_runner(self, username: str, args: Iterable[str], job_id: int, environment: Optional[dict] = None, timeout_secs: str = 300, refresh_db: bool = True, labels: Optional[list] = None) -> Optional[Runner]: + if not self.started: + return None with self.runner_lock: if not self.runners_available(): return None diff --git a/src/packetserver/server/db.py b/src/packetserver/server/db.py index f3ffbba..d61257f 100644 --- a/src/packetserver/server/db.py +++ b/src/packetserver/server/db.py @@ -2,6 +2,7 @@ import ZODB import json import gzip import base64 +from io import BytesIO def get_user_db(username: str, db: ZODB.DB) -> dict: udb = {