from dask.typing import no_default
from distributed import Client
from distributed.diagnostics.plugin import SchedulerPlugin
from .common import logger
from .slurm import SlurmCluster
class SlurmSchedulerPlugin(SchedulerPlugin):
    """Scheduler plugin that keeps a reference to a cluster object.

    Parameters
    ----------
    cluster : object
        The cluster this plugin is associated with. It is stored unchanged
        on the instance as ``self.cluster``.
    """

    def __init__(self, cluster):
        # Store the reference first, then let the base plugin initialise.
        self.cluster = cluster
        super().__init__()
# [docs]
class ScalableClient(Client):
    """Client for submitting tasks to a Dask cluster.

    Inherits from the dask ``Client`` object and adds ``tag``/``n``
    shortcuts for requesting worker resources.

    Parameters
    ----------
    cluster : Cluster
        The cluster object to connect to for submitting tasks.
    *args : tuple
        Additional positional arguments forwarded to the dask client
        constructor.
    **kwargs : dict
        Additional keyword arguments forwarded to the dask client
        constructor.
    """

    def __init__(self, cluster, *args, **kwargs):
        """Connect to ``cluster`` and register the Slurm scheduler plugin
        when the cluster is a ``SlurmCluster``.
        """
        # The cluster is passed positionally: combining ``address=cluster``
        # with extra positional arguments would bind ``address`` twice and
        # raise a TypeError.
        super().__init__(cluster, *args, **kwargs)
        if isinstance(cluster, SlurmCluster):
            # NOTE(review): the plugin is deliberately built with ``None``
            # instead of ``cluster``; presumably the cluster object cannot
            # be sent to the scheduler -- confirm before changing.
            self.register_scheduler_plugin(SlurmSchedulerPlugin(None))

    def submit(self, func, *args, tag=None, n=1, **kwargs):
        """Submit a function to be run by workers in the cluster.

        Parameters
        ----------
        func : function
            Function to be scheduled for execution.
        *args : tuple
            Optional positional arguments to pass to the function.
        tag : str (optional)
            User-defined tag for the container that can run func. If not
            provided, func is assigned to be run on a random container.
        n : int (default 1)
            Number of workers needed to run this task. Meant to be used with
            tag. Multiple workers can be useful for application level
            distributed computing.
        **kwargs : dict (optional)
            Optional key-value pairs to be passed to the function.

        Returns
        -------
        Future
            Returns the future object that runs the function.

        Raises
        ------
        TypeError
            If 'func' is not callable, a TypeError is raised.
        ValueError
            If 'allow_other_workers' is True and 'workers' is None, a
            ValueError is raised.

        Examples
        --------
        >>> c = client.submit(add, a, b)
        """
        if tag is None:
            # No tag: leave ``resources`` untouched so a caller-supplied
            # ``resources`` keyword (or the parent's default) is honoured.
            return super().submit(func, *args, **kwargs)
        # A tag translates into a resource request of ``n`` units of ``tag``.
        return super().submit(func, *args, resources={tag: n}, **kwargs)

    def cancel(self, futures, *args, **kwargs):
        """Cancel running futures.

        This stops future tasks from being scheduled if they have not yet run
        and deletes them if they have already run. After calling, this result
        and all dependent results will no longer be accessible.

        Parameters
        ----------
        futures : future | future, list
            One or more futures to cancel (as a list).
        *args : tuple
            Positional arguments to pass to dask client's cancel method.
        **kwargs : dict
            Keyword arguments to pass to dask client's cancel method.
        """
        return super().cancel(futures, *args, **kwargs)

    def close(self, timeout=no_default):
        """Close this client.

        Clients will also close automatically when your Python session ends.

        Parameters
        ----------
        timeout : number
            Time in seconds after which to raise a
            ``dask.distributed.TimeoutError``
        """
        return super().close(timeout)

    def map(self, func, *parameters, tag=None, n=1, **kwargs):
        """Map a function on multiple sets of arguments to run the function
        multiple times with different inputs.

        Parameters
        ----------
        func : function
            Function to be scheduled for execution.
        *parameters : lists
            Lists of parameters to be passed to the function. The first list
            should have the first parameter values, the second list should
            have the second parameter values, and so on. The lists should be
            of the same length.
        tag : str (optional)
            User-defined tag for the container that can run func. If not
            provided, func is assigned to be run on a random container.
        n : int (default 1)
            Number of workers needed to run this task. Meant to be used with
            tag. Multiple workers can be useful for application level
            distributed computing.
        **kwargs : dict
            Keyword arguments to pass to dask client's map method.

        Returns
        -------
        List of futures
            Returns a list of future objects, each for a separate run of the
            function with the given parameters.

        Examples
        --------
        >>> def add(a, b): ...
        >>> L = client.map(add, [1, 2, 3], [4, 5, 6])
        """
        if tag is None:
            # No tag: leave ``resources`` untouched so a caller-supplied
            # ``resources`` keyword (or the parent's default) is honoured.
            return super().map(func, *parameters, **kwargs)
        # A tag translates into a resource request of ``n`` units of ``tag``.
        return super().map(func, *parameters, resources={tag: n}, **kwargs)

    def get_versions(self, check=False, packages=None):
        """Return version info for the scheduler, all workers and myself.

        Parameters
        ----------
        check : bool
            Raise ValueError if all required & optional packages do not match.
            Default is False.
        packages : list
            Extra package names to check.

        Examples
        --------
        >>> c.get_versions()
        >>> c.get_versions(packages=['sklearn', 'geopandas'])
        """
        return super().get_versions(check, packages)