# coding: utf-8
"""
Slurm job manager. See https://slurm.schedmd.com/quickstart.html.
"""
from __future__ import annotations
__all__ = ["SlurmJobManager", "SlurmJobFileFactory"]
import os
import time
import re
import stat
import pathlib
import shlex
import subprocess
from law.config import Config
from law.job.base import BaseJobManager, BaseJobFileFactory, JobInputFile
from law.target.file import get_path
from law.util import interruptable_popen, make_list, quote_cmd, parse_duration
from law.logger import get_logger
from law._types import Any, Sequence
logger = get_logger(__name__)
_cfg = Config.instance()
[docs]
class SlurmJobManager(BaseJobManager):
# chunking settings: the chunk sizes for cancel and query calls are read from the "job" config
# section, the one for submission is fixed to 0
# NOTE(review): 0 presumably disables chunked submission - confirm against BaseJobManager
chunk_size_submit = 0
chunk_size_cancel = _cfg.get_expanded_int("job", "slurm_chunk_size_cancel")
chunk_size_query = _cfg.get_expanded_int("job", "slurm_chunk_size_query")
# pattern matched against each output line of the submission command to extract the job id
submission_cre = re.compile(r"^Submitted batch job (\d+)$")
# fields requested from squeue and the pattern used to parse a line of its output into
# (job id, state)
squeue_format = r"JobID,State"
squeue_cre = re.compile(r"^\s*(\d+)\s+([^\s]+)$")
# fields requested from sacct and the pattern used to parse a line of its output into
# (job id, state, first part of the exit code pair, reason)
sacct_format = r"JobID,State,ExitCode,Reason"
sacct_cre = re.compile(r"^\s*(\d+)\s+([^\s]+)\s+(-?\d+):-?\d+\s+(.+)$")
def __init__(self, partition: str | None = None, threads: int = 1) -> None:
    """
    Initializes the base class and stores *partition* and *threads* as instance attributes.
    *partition* is the fallback used by :py:meth:`submit`, :py:meth:`cancel` and
    :py:meth:`query` when no partition is passed to them.
    """
    super().__init__()

    # keep the defaults on the instance
    self.partition, self.threads = partition, threads
[docs]
def cleanup(self, *args, **kwargs) -> None:  # type: ignore[override]
    """
    Cleanup of a single job is not supported. All arguments are ignored and a
    *NotImplementedError* is raised unconditionally.
    """
    message = "SlurmJobManager.cleanup is not implemented"
    raise NotImplementedError(message)
[docs]
def cleanup_batch(self, *args, **kwargs) -> None:  # type: ignore[override]
    """
    Batched cleanup of jobs is not supported. All arguments are ignored and a
    *NotImplementedError* is raised unconditionally.
    """
    message = "SlurmJobManager.cleanup_batch is not implemented"
    raise NotImplementedError(message)
[docs]
def submit(  # type: ignore[override]
    self,
    job_file: str | pathlib.Path,
    partition: str | None = None,
    retries: int = 0,
    retry_delay: float | int = 3,
    silent: bool = False,
    _processes: list | None = None,
) -> int | None:
    """
    Submits the job described by *job_file* with the configured sbatch command and returns the
    job id parsed from its output. *partition* defaults to the instance attribute of the same
    name. On failure, the submission is repeated up to *retries* times with a pause of
    *retry_delay* seconds in between. When all attempts fail, *None* is returned if *silent* is
    *True*, and an exception is raised otherwise.
    """
    # fall back to the partition configured on the instance
    if partition is None:
        partition = self.partition

    # the submission command is run in the same directory as the job file, so split the
    # absolute location into directory and file name
    abs_job_file = os.path.abspath(get_path(job_file))
    job_file_dir = os.path.dirname(abs_job_file)
    job_file_name = os.path.basename(abs_job_file)

    # assemble the command
    sbatch_cmd = shlex.split(_cfg.get_expanded("job", "slurm_cmd_sbatch"))
    if partition:
        sbatch_cmd.extend(["--partition", partition])
    sbatch_cmd.append(job_file_name)
    cmd_str = quote_cmd(sbatch_cmd)

    # every iteration is one submission attempt
    while True:
        logger.debug(f"submit slurm job with command '{cmd_str}'")
        out: str
        err: str
        code, out, err = interruptable_popen(  # type: ignore[assignment]
            cmd_str,
            shell=True,
            executable="/bin/bash",
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=job_file_dir,
            kill_timeout=2,
            processes=_processes,
        )

        if code == 0:
            # the first output line matching the expected pattern carries the job id
            matches = (
                self.submission_cre.match(line.strip())
                for line in out.strip().split("\n")
            )
            first_match = next((m for m in matches if m), None)
            if first_match is not None:
                return int(first_match.group(1))

            # no line matched, treat this like a failed submission
            code = 1
            err = f"cannot parse slurm job id(s) from output:\n{out}"

        logger.debug(f"submission of slurm job '{job_file}' failed with code {code}:\n{err}")

        # no attempts left, either give up quietly or complain
        if retries <= 0:
            if silent:
                return None
            raise Exception(f"submission of slurm job '{job_file}' failed:\n{err}")

        # wait and try again
        retries -= 1
        time.sleep(retry_delay)
[docs]
def cancel(  # type: ignore[override]
    self,
    job_id: int | Sequence[int],
    partition: str | None = None,
    silent: bool = False,
    _processes: list | None = None,
) -> dict[int, None] | None:
    """
    Cancels one or more jobs given by *job_id* with the configured scancel command.
    *partition* defaults to the instance attribute of the same name. A non-zero exit code of
    the command leads to an exception unless *silent* is *True*. When *job_id* is a list or
    tuple, a dictionary mapping each job id to *None* is returned, and *None* otherwise.
    """
    # fall back to the partition configured on the instance
    if partition is None:
        partition = self.partition

    # a list or tuple of ids triggers the dictionary return value
    chunking = isinstance(job_id, (list, tuple))
    job_ids = make_list(job_id)

    # assemble the command
    scancel_cmd = shlex.split(_cfg.get_expanded("job", "slurm_cmd_scancel"))
    if partition:
        scancel_cmd.extend(["--partition", partition])
    scancel_cmd.extend(job_ids)
    cmd_str = quote_cmd(scancel_cmd)

    # execute it
    logger.debug(f"cancel slurm job(s) with command '{cmd_str}'")
    out: str
    err: str
    code, out, err = interruptable_popen(  # type: ignore[assignment]
        cmd_str,
        shell=True,
        executable="/bin/bash",
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        kill_timeout=2,
        processes=_processes,
    )

    # complain about failures unless asked to stay silent
    if code != 0:
        if not silent:
            raise Exception(
                f"cancellation of slurm job(s) '{job_id}' failed with code {code}:\n{err}",
            )

    if not chunking:
        return None
    return dict.fromkeys(job_ids)
[docs]
def query(  # type: ignore[override]
    self,
    job_id: int | Sequence[int],
    partition: str | None = None,
    silent: bool = False,
    _processes: list | None = None,
) -> dict[int, dict[str, Any]] | dict[str, Any] | None:
    """
    Queries the status of one or more jobs given by *job_id*. The configured squeue command is
    run first, and job ids that do not show up in its output are looked up in the accounting
    history with the configured sacct command. *partition* defaults to the instance attribute
    of the same name.

    When *job_id* is a list or tuple, a dictionary mapping job ids to status dictionaries is
    returned, in which jobs found by neither command are marked as failed. For a single job id,
    only its status dictionary is returned.

    If one of the commands fails, or if a single queried job cannot be found, *None* is
    returned when *silent* is *True*, and an exception is raised otherwise.
    """
    # default arguments
    if partition is None:
        partition = self.partition

    chunking = isinstance(job_id, (list, tuple))
    job_ids = make_list(job_id)

    # build the squeue command
    # note: field names such as "JobID,State" are only interpreted by the "--Format" option,
    # whereas "--format" expects "%"-style specifications and would echo the names literally
    cmd = shlex.split(_cfg.get_expanded("job", "slurm_cmd_squeue"))
    cmd += ["--Format", self.squeue_format, "--noheader"]
    if partition:
        cmd += ["--partition", partition]
    cmd += ["--jobs", ",".join(map(str, job_ids))]

    # optionally prepend timeout
    query_timeout = _cfg.get_expanded(
        "job",
        _cfg.find_option("job", "slurm_job_query_timeout", "job_query_timeout"),
    )
    if query_timeout:
        query_timeout_sec = parse_duration(query_timeout, input_unit="s")
        cmd = self.prepend_timeout_command(cmd, query_timeout_sec)

    # run it
    # the quoted string is passed rather than the list, as with shell=True only the first
    # item of a list would be interpreted as the command and all options would get lost
    cmd_str = quote_cmd(cmd)
    logger.debug(f"query slurm job(s) with command '{cmd_str}'")
    out: str
    err: str
    code, out, err = interruptable_popen(  # type: ignore[assignment]
        cmd_str,
        shell=True,
        executable="/bin/bash",
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        kill_timeout=2,
        processes=_processes,
    )

    # special case: when the id of a single yet expired job is queried, squeue responds with an
    # error (exit code != 0), so as a workaround, consider these cases as an empty result
    if code != 0 and "invalid job id specified" in err.lower():
        code = 0
        query_data = {}
    else:
        # handle errors
        if code != 0:
            if silent:
                return None
            raise Exception(
                f"queue query of slurm job(s) '{job_id}' failed with code {code}:\n{err}",
            )

        # parse the output and extract the status per job
        query_data = self.parse_squeue_output(out)

    # some jobs might already be in the accounting history, so query for missing job ids
    missing_ids = [_job_id for _job_id in job_ids if _job_id not in query_data]
    if missing_ids:
        # build the sacct command
        cmd = shlex.split(_cfg.get_expanded("job", "slurm_cmd_sacct"))
        cmd += ["--format", self.sacct_format, "--noheader"]
        if partition:
            cmd += ["--partition", partition]
        cmd += ["--jobs", ",".join(map(str, missing_ids))]

        # run it, again passing the quoted command string
        cmd_str = quote_cmd(cmd)
        logger.debug(f"query slurm accounting history with command '{cmd_str}'")
        code, out, err = interruptable_popen(  # type: ignore[assignment]
            cmd_str,
            shell=True,
            executable="/bin/bash",
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            kill_timeout=2,
            processes=_processes,
        )

        # handle errors
        if code != 0:
            if silent:
                return None
            raise Exception(
                f"accounting query of slurm job(s) '{job_id}' failed with code {code}:\n{err}",
            )

        # parse the output and update query data
        query_data.update(self.parse_sacct_output(out))

    # compare to the requested job ids and perform some checks
    for _job_id in job_ids:
        if _job_id in query_data:
            continue

        # a single missing job is an error, in chunking mode it is reported as failed instead
        if not chunking:
            if silent:
                return None
            raise Exception(f"slurm job(s) '{job_id}' not found in query response")

        query_data[_job_id] = self.job_status_dict(
            job_id=_job_id,
            status=self.FAILED,
            error="job not found in query response",
        )

    return query_data if chunking else query_data[job_id]  # type: ignore[index]
@classmethod
def parse_squeue_output(cls, out: str) -> dict[int, dict[str, Any]]:
    """
    Parses the output *out* of a squeue call line by line and returns a dictionary that maps
    job ids to status dictionaries. Lines that do not match :py:attr:`squeue_cre` are skipped.
    """
    parsed = {}

    stripped_lines = (raw_line.strip() for raw_line in out.strip().split("\n"))
    for match in filter(None, map(cls.squeue_cre.match, stripped_lines)):
        raw_id, raw_state = match.groups()

        # the first group is the job id, the second one the slurm job state
        job_id = int(raw_id)
        parsed[job_id] = cls.job_status_dict(job_id=job_id, status=cls.map_status(raw_state))

    return parsed
@classmethod
def parse_sacct_output(cls, out: str) -> dict[int, dict[str, Any]]:
    """
    Parses the output *out* of a sacct call line by line and returns a dictionary that maps
    job ids to status dictionaries containing status, exit code and error message. Lines that
    do not match :py:attr:`sacct_cre` are skipped. A non-zero exit code always results in a
    failed status, and failed jobs without an error message get one assigned.
    """
    parsed = {}

    for raw_line in out.strip().split("\n"):
        match = cls.sacct_cre.match(raw_line.strip())
        if match is None:
            continue

        raw_id, raw_state, raw_code, raw_reason = match.groups()
        job_id = int(raw_id)
        status = cls.map_status(raw_state)
        code = int(raw_code)

        # a literal "None" means that there is no error message
        reason = raw_reason.strip()
        error = None if reason == "None" else reason

        # a non-zero exit code overrules any other status
        if status != cls.FAILED and code != 0:
            status = cls.FAILED
            if not error:
                error = f"job status set to '{cls.FAILED}' due to non-zero exit code {code}"

        # failed jobs without a message report the raw slurm state instead
        if status == cls.FAILED and not error:
            error = raw_state

        parsed[job_id] = cls.job_status_dict(
            job_id=job_id,
            status=status,
            code=code,
            error=error,
        )

    return parsed
@classmethod
def map_status(cls, status: str | None) -> str:
    """
    Translates a slurm job state *status* into one of the status flags of this class. Leading
    and trailing "+" characters are removed from strings before the lookup. Unknown states are
    logged and mapped to the failed status.
    """
    # see https://slurm.schedmd.com/squeue.html#lbAG
    if isinstance(status, str):
        status = status.strip("+")

    # slurm states grouped by the status flag they translate to
    state_groups = (
        (
            ("CONFIGURING", "PENDING", "REQUEUED", "REQUEUE_HOLD", "REQUEUE_FED"),
            cls.PENDING,
        ),
        (
            ("RUNNING", "COMPLETING", "STAGE_OUT"),
            cls.RUNNING,
        ),
        (
            ("COMPLETED",),
            cls.FINISHED,
        ),
        (
            (
                "BOOT_FAIL", "CANCELLED", "DEADLINE", "FAILED", "NODE_FAIL", "OUT_OF_MEMORY",
                "PREEMPTED", "REVOKED", "SPECIAL_EXIT", "STOPPED", "SUSPENDED", "TIMEOUT",
            ),
            cls.FAILED,
        ),
    )
    for slurm_states, mapped in state_groups:
        if status in slurm_states:
            return mapped

    # anything else counts as a failure
    logger.debug(f"unknown slurm job state '{status}'")
    return cls.FAILED
[docs]
class SlurmJobFileFactory(BaseJobFileFactory):
# names of the attributes handled by this factory's config, extending those of the base factory
# with the slurm specific ones assigned in the constructor
# NOTE(review): presumably these are the attributes that get_config() merges with keyword
# arguments passed to create() - confirm against BaseJobFileFactory
config_attrs = BaseJobFileFactory.config_attrs + [
"file_name", "command", "executable", "arguments", "shell", "input_files", "job_name",
"partition", "stdout", "stderr", "postfix_output_files", "custom_content", "absolute_paths",
]
def __init__(
    self,
    *,
    file_name: str = "slurm_job.sh",
    command: str | Sequence[str] | None = None,
    executable: str | None = None,
    arguments: str | Sequence[str] | None = None,
    shell: str = "bash",
    input_files: dict[str, str | pathlib.Path | JobInputFile] | None = None,
    job_name: str | None = None,
    partition: str | None = None,
    stdout: str = "stdout.txt",
    stderr: str = "stderr.txt",
    postfix_output_files: bool = True,
    custom_content: str | Sequence[str] | None = None,
    absolute_paths: bool = False,
    **kwargs,
) -> None:
    """
    Stores all slurm specific settings as instance attributes. The keyword arguments *dir*,
    *mkdtemp* and *cleanup* are forwarded to the base class and, when missing or *None*, taken
    from the "job" section of the config, where slurm specific options are looked up before the
    generic ones.
    """
    def fill_from_config(name, getter_name, *options, **getter_kwargs):
        # the config is only consulted for values that were not passed explicitly
        if kwargs.get(name) is None:
            kwargs[name] = getattr(_cfg, getter_name)(
                "job",
                _cfg.find_option("job", *options),
                **getter_kwargs,
            )

    fill_from_config("dir", "get_expanded", "slurm_job_file_dir", "job_file_dir")
    fill_from_config(
        "mkdtemp",
        "get_expanded_bool",
        "slurm_job_file_dir_mkdtemp",
        "job_file_dir_mkdtemp",
        force_type=False,
    )
    fill_from_config(
        "cleanup",
        "get_expanded_bool",
        "slurm_job_file_dir_cleanup",
        "job_file_dir_cleanup",
    )

    super().__init__(**kwargs)

    # what to run
    self.file_name = file_name
    self.command = command
    self.executable = executable
    self.arguments = arguments
    self.shell = shell
    self.input_files = input_files if input_files else {}

    # slurm directives
    self.job_name = job_name
    self.partition = partition
    self.stdout = stdout
    self.stderr = stderr

    # job file creation details
    self.postfix_output_files = postfix_output_files
    self.custom_content = custom_content
    self.absolute_paths = absolute_paths
[docs]
def create(
    self,
    postfix: str | None = None,
    **kwargs,
) -> tuple[str, SlurmJobFileFactory.Config]:
    """
    Creates a slurm job file and returns a 2-tuple with its path and the config object that
    was used to create it. The config is built by merging *kwargs* with the instance
    attributes. When *postfix* is set, it is forwarded to the helpers that postfix output
    files, copied input files and the job file itself.

    The job file consists of the shebang line for the configured shell, "#SBATCH" directives
    for job name, partition, output and error files, optional custom content, and finally the
    command and / or the executable, each followed by the arguments.

    A *ValueError* is raised when the file name or the shell is empty, or when neither a
    command nor an executable is given.
    """
    # merge kwargs and instance attributes
    c = self.get_config(**kwargs)

    # some sanity checks
    if not c.file_name:
        raise ValueError("file_name must not be empty")
    if not c.command and not c.executable:
        raise ValueError("either command or executable must not be empty")
    if not c.shell:
        raise ValueError("shell must not be empty")

    # postfix certain output files, except for those pointing into /dev/
    if c.postfix_output_files:
        skip_postfix_cre = re.compile(r"^(/dev/).*$")
        skip_postfix = lambda s: bool(skip_postfix_cre.match(s))
        for attr in ["stdout", "stderr", "custom_log_file"]:
            if c[attr] and not skip_postfix(c[attr]):
                c[attr] = self.postfix_output_file(c[attr], postfix)

    # ensure that all input files are JobInputFile objects
    c.input_files = {
        key: JobInputFile(f)
        for key, f in c.input_files.items()
    }

    # ensure that the executable is an input file, remember the key to access it
    if c.executable:
        executable_keys = [
            k
            for k, v in c.input_files.items()
            if get_path(v) == get_path(c.executable)
        ]
        if executable_keys:
            executable_key = executable_keys[0]
        else:
            executable_key = "executable_file"
            c.input_files[executable_key] = JobInputFile(c.executable)

    # prepare input files
    def prepare_input(f):
        # when not copied or forwarded, just return the absolute, original path
        abs_path = os.path.abspath(f.path)
        if not f.copy or f.forward:
            return abs_path
        # copy the file
        abs_path = self.provide_input(
            src=abs_path,
            postfix=postfix if f.postfix and not f.share else None,
            dir=c.dir,
            skip_existing=f.share,
        )
        return abs_path

    # absolute input paths
    for f in c.input_files.values():
        f.path_sub_abs = prepare_input(f)

    # input paths relative to the submission or initial dir
    # forwarded files are skipped as they are not treated as normal inputs
    for f in c.input_files.values():
        if f.forward:
            continue
        f.path_sub_rel = (
            os.path.basename(f.path_sub_abs)
            if f.copy and not c.absolute_paths else
            f.path_sub_abs
        )

    # input paths as seen by the job, before and after potential rendering
    for f in c.input_files.values():
        f.path_job_pre_render = (
            f.path_sub_abs
            if f.forward else
            f.path_sub_rel
        )
        f.path_job_post_render = (
            f.path_sub_abs
            if f.forward and not f.render_job else
            os.path.basename(f.path_sub_abs)
        )

    # update files in render variables with version after potential rendering
    c.render_variables.update({
        key: f.path_job_post_render
        for key, f in c.input_files.items()
    })

    # add space separated input files before potential rendering to render variables
    c.render_variables["input_files"] = " ".join(
        f.path_job_pre_render
        for f in c.input_files.values()
    )

    # add space separated list of input files for rendering
    c.render_variables["input_files_render"] = " ".join(
        f.path_job_pre_render
        for f in c.input_files.values()
        if f.render_job
    )

    # add the custom log file to render variables
    if c.custom_log_file:
        c.render_variables["log_file"] = c.custom_log_file

    # add the file postfix to render variables
    if postfix and "file_postfix" not in c.render_variables:
        c.render_variables["file_postfix"] = postfix

    # linearize render variables
    render_variables = self.linearize_render_variables(c.render_variables)

    # prepare the job description file
    job_file = self.postfix_input_file(os.path.join(c.dir, str(c.file_name)), postfix)

    # render copied input files in place
    for f in c.input_files.values():
        if not f.copy or f.forward or not f.render_local:
            continue
        self.render_file(
            f.path_sub_abs,
            f.path_sub_abs,
            render_variables,
            postfix=postfix if f.postfix else None,
        )

    # prepare the executable when given
    if c.executable:
        c.executable = get_path(c.input_files[executable_key].path_sub_rel)
        # make the file executable for the user and group
        path = os.path.join(c.dir, os.path.basename(c.executable))
        if os.path.exists(path):
            os.chmod(path, os.stat(path).st_mode | stat.S_IXUSR | stat.S_IXGRP)

    # job file content, strings are written as they are whereas tuples become directives
    content: list[str | tuple[str, Any]] = []
    content.append(f"#!/usr/bin/env {c.shell}")
    content.append("")

    if c.job_name:
        content.append(("job-name", c.job_name))
    if c.partition:
        content.append(("partition", c.partition))
    content.append(("output", c.stdout or "NONE"))
    content.append(("error", c.stderr or "NONE"))

    # add custom content
    if c.custom_content:
        # a plain string is a single line, extending the list with it directly would add one
        # entry per character
        custom_content = c.custom_content
        if isinstance(custom_content, str):
            custom_content = [custom_content]
        content.extend(custom_content)

    # write the job file
    with open(job_file, "w") as job_fh:
        for obj in content:
            line = self.create_line(obj)
            job_fh.write(f"{line}\n")

        # prepare arguments
        args = c.arguments or ""
        if args:
            args = " " + (quote_cmd(args) if isinstance(args, (list, tuple)) else args)

        # add the command
        if c.command:
            cmd = quote_cmd(c.command) if isinstance(c.command, (list, tuple)) else c.command
            job_fh.write(f"\n{cmd.strip()}{args}\n")

        # add the executable
        if c.executable:
            cmd = c.executable
            job_fh.write(f"\n{cmd}{args}\n")

    # make it executable
    os.chmod(job_file, os.stat(job_file).st_mode | stat.S_IXUSR | stat.S_IXGRP)

    logger.debug(f"created slurm job file at '{job_file}'")

    return job_file, c
@classmethod
def create_line(cls, args):
    """
    Creates a single line of the job file from *args*. Values that are neither a list nor a
    tuple are returned with surrounding whitespace removed. A sequence with one item becomes
    the directive "#SBATCH --<item>", one with two items "#SBATCH --<key>=<value>", with
    every item being converted to a string and stripped. An exception is raised for
    sequences of any other length.
    """
    # plain lines are written as they are
    if not isinstance(args, (list, tuple)):
        return args.strip()

    # only flags and key-value pairs can be expressed as directives
    if len(args) not in (1, 2):
        raise Exception(f"cannot create job file line from '{args}'")

    parts = [str(part).strip() for part in args]
    return "#SBATCH --" + "=".join(parts)