Source code for scalable.slurm

import os

from collections.abc import Awaitable
from distributed.core import Status
from distributed.deploy.spec import ProcessInterface

from .common import logger
from .core import Job, JobQueueCluster, cluster_parameters, job_parameters
from .support import *
from .utilities import *

DEFAULT_REQUEST_QUANTITY = 1
RECOVERY_DELAY = 3

class SlurmJob(Job):

    # Override class variables
    cancel_command = "scancel"

    def __init__(
        self,
        scheduler=None,
        name=None,
        queue=None,
        account=None,
        walltime=None,
        container=None,
        comm_port=None,
        tag=None,
        hardware=None,
        logs_location=None,
        shared_lock=None,
        worker_env_vars=None,
        active_job_ids=None,
        **base_class_kwargs
    ):
        super().__init__(
            scheduler=scheduler, name=name, hardware=hardware, comm_port=comm_port, \
            container=container, tag=tag, shared_lock=shared_lock, **base_class_kwargs
        )

        job_name = f"{self.name}-job"

        self.slurm_cmd = salloc_command(account=account, name=job_name, nodes=DEFAULT_REQUEST_QUANTITY, 
                                        partition=queue, time=walltime)
        
        self.job_name = job_name
        self.active_job_ids = active_job_ids
        self.job_id = None
        self.job_node = None
        self.log_file = None
        self.deleted = False

        if logs_location is not None:
            self.log_file = os.path.abspath(os.path.join(logs_location, f"{self.name}-{self.tag}.log"))
        
        apptainer_version = os.getenv("APPTAINER_VERSION", None)

        self.send_command = apptainer_module_command(apptainer_version) + [";"] + \
                            self.container.get_command(worker_env_vars) + self.command_args
    
    async def _srun_command(self, command):
        prefix = ["srun", f"--jobid={self.job_id}"]
        command = prefix + command
        out = await self._run_command(command)
        return out
    
    async def _ssh_command(self, command):
        prefix = ["ssh", self.job_node]
        suffix = []
        if self.log_file:
            suffix = [f">> {self.log_file}", "2>&1"]
        suffix.append("&")
        command = command + suffix
        command = list(map(str, command))
        command_str = " ".join(command)
        command = prefix + [f"\"{command_str}\""]
        out = await self._run_command(command)
        return out
    
    async def _check_valid_job_id(self, job_id):
        out = await self._run_command(jobcheck_command(job_id))
        match = re.search(self.job_id_regexp, out)
        return match

    async def start(self):
        """Start function for the worker.

        The worker sets itself up by requesting or consuming necessary 
        resources and adding itself as an active worker to the cluster. 
        All cases such as there being no active or available nodes are handled 
        by this function. Called by the parent classes when scaling the workers.
        """
        logger.debug("Starting worker: %s", self.name)

        async with self.shared_lock:
            while self.job_id is None:
                self.job_node = self.hardware.get_available_node(self.cpus, self.memory)
                if self.job_node is None:
                    break
                job_id = self.hardware.get_node_jobid(self.job_node)
                match = await self._check_valid_job_id(job_id)
                if match is None:
                    self.hardware.remove_jobid_nodes(job_id)
                else:
                    self.job_id = match.groupdict().get("job_id")
            if self.job_node == None:
                out = await self._run_command(self.slurm_cmd)
                job_id = await self._run_command(jobid_command(self.job_name))
                self.job_id = job_id
                self.active_job_ids.append(job_id)
                nodelist = await self._run_command(nodelist_command(self.job_name))
                nodes = parse_nodelist(nodelist)
                worker_memories = await self._srun_command(memory_command())
                worker_cpus = await self._srun_command(core_command())
                worker_memories = worker_memories.split('\n')
                worker_cpus = worker_cpus.split('\n')
                for index in range(0, len(nodes)):
                    node = nodes[index]
                    alloc_memory = int(worker_memories[index])
                    alloc_cpus = int(worker_cpus[index])
                    if not self.hardware.assign_resources(node=node, cpus=alloc_cpus, memory=alloc_memory, jobid=self.job_id):
                        stored_job_id = self.hardware.get_node_jobid(node)
                        match = await self._check_valid_job_id(stored_job_id)
                        if match is None:
                            self.hardware.remove_jobid_nodes(stored_job_id)
                            assert self.hardware.assign_resources(node=node, cpus=alloc_cpus, memory=alloc_memory, jobid=self.job_id)
                        else:
                            raise ValueError(f"Node {node} is already assigned to job {stored_job_id}")
                self.job_node = self.hardware.get_available_node(self.cpus, self.memory)
            _ = await self._ssh_command(self.send_command)
            asyncio.get_event_loop().create_task(self.check_launched_worker())
            self.hardware.utilize_resources(self.job_node, self.cpus, self.memory, self.job_id)
            self.launched.append((self.name, self.tag))
        
        await ProcessInterface.start(self)

    async def close(self):
        """Close function for the worker.
        
        The worker releases the resources it was utilizing and removes itself."""
        if self.deleted:
            return
        async with self.shared_lock:
            if self.hardware.is_assigned(self.job_id):
                match = await self._check_valid_job_id(self.job_id)
                if match is None:
                    self.hardware.remove_jobid_nodes(self.job_id)
                else:
                    self.hardware.release_resources(self.job_node, self.cpus, self.memory, self.job_id)
                    if not self.hardware.has_active_nodes(self.job_id):
                        self.hardware.remove_jobid_nodes(self.job_id)
                        await SlurmJob._close_job(self.job_id, self.cancel_command, self.comm_port)
            cluster = self._cluster()
            if self.tag in self.removed and self.removed[self.tag] > 0:
                self.removed[self.tag] -= 1
                self.deleted = True
            elif cluster.status not in (Status.closing, Status.closed):
                cluster.loop.call_later(RECOVERY_DELAY, cluster._correct_state)

[docs] class SlurmCluster(JobQueueCluster): __doc__ = """Launch Dask on a SLURM cluster. Inherits the JobQueueCluster class. Parameters ---------- {cluster} *args : tuple Positional arguments to pass to JobQueueCluster. **kwargs : dict Keyword arguments to pass to JobQueueCluster. """.format( cluster=cluster_parameters ) job_cls = SlurmJob def close(self, timeout: float | None = None) -> Awaitable[None] | None: """Close the cluster This closes all running jobs and the scheduler. Pending jobs belonging to the user are also cancelled.""" active_jobs = self.active_job_ids for job_id in active_jobs: cancel_job_command = ["scancel", job_id] result = asyncio.run(self.job_cls._call(cancel_job_command, self.comm_port)) result = None if result == "" else result if result is None: self.hardware.remove_jobid_nodes(job_id) logger.info(f"Cancelled job: {job_id}") else: logger.error(f"Failed to cancel job: {result}") return super().close(timeout) @staticmethod def set_default_request_quantity(nodes): """Set the default number of nodes to request when scaling the cluster. Static Function. Does not require an instance of the class. If set to 1 (the original default), the cluster will request one hardware node at a time when scaling. If set to a higher number, like 5, the cluster will request 5 hardware nodes at a time when scaling. This is helpful when each worker may need almost all the resources of a node and it is more efficient to request multiple nodes at once. Parameters ---------- nodes : int Number of nodes to request when scaling the cluster. """ global DEFAULT_REQUEST_QUANTITY DEFAULT_REQUEST_QUANTITY = nodes