Source code for rhg_compute_tools.utils

import collections.abc
import dis
import functools
import inspect
import itertools
import json
import os
import queue
import threading
import types
from typing import Optional

import dask.distributed as dd
import git
import numpy as np
import toolz


def expand(func):
    """
    Decorator to expand an (args, kwargs) tuple in function calls

    Intended for use with the :py:func:`collapse` function

    Parameters
    ----------
    func : function
        Function to have arguments expanded. Func can have any number of
        positional and keyword arguments.

    Returns
    -------
    wrapped : function
        Wrapped version of ``func`` which accepts a single ``(args, kwargs)``
        tuple.

    Examples
    --------

    .. code-block:: python

        >>> @expand
        ... def my_func(a, b, exp=1):
        ...     return (a * b)**exp
        ...

        >>> my_func(((2, 3), {}))
        6

        >>> my_func(((2, 3, 2), {}))
        36

        >>> my_func((tuple([]), {'b': 4, 'exp': 2, 'a': 1}))
        16

    This function can be used in combination with the ``collapse`` helper
    function, which allows more natural parameter calls

    .. code-block:: python

        >>> my_func(collapse(2, 3, exp=2))
        36

    These can then be paired to enable many parameterized function calls:

    .. code-block:: python

        >>> func_calls = [collapse(a, a+1, exp=a) for a in range(5)]

        >>> list(map(my_func, func_calls))
        [1, 2, 36, 1728, 160000]
    """

    @functools.wraps(func)
    def inner(ak, *args, **kwargs):
        return func(*ak[0], *args, **ak[1], **kwargs)

    return inner


def collapse(*args, **kwargs):
    """
    Collapse positional and keyword arguments into an (args, kwargs) tuple

    Intended for use with the :py:func:`expand` decorator

    Parameters
    ----------
    *args
        Variable length argument list.
    **kwargs
        Arbitrary keyword arguments.

    Returns
    -------
    args : tuple
        Positional arguments tuple
    kwargs : dict
        Keyword argument dictionary
    """
    return (args, kwargs)


def collapse_product(*args, **kwargs):
    """
    Collapse the Cartesian product of iterable arguments into (args, kwargs) tuples

    Parameters
    ----------
    *args
        Variable length list of iterables
    **kwargs
        Keyword arguments, whose values must be iterables

    Returns
    -------
    iterator
        Generator with collapsed arguments

    See Also
    --------

    Function :py:func:`collapse`

    Examples
    --------

    .. code-block:: python

        >>> @expand
        ... def my_func(a, b, exp=1):
        ...     return (a * b)**exp
        ...

        >>> product_args = list(collapse_product(
        ...     [0, 1, 2],
        ...     [0.5, 2],
        ...     exp=[0, 1]))

        >>> product_args  # doctest: +NORMALIZE_WHITESPACE
        [((0, 0.5), {'exp': 0}),
         ((0, 0.5), {'exp': 1}),
         ((0, 2), {'exp': 0}),
         ((0, 2), {'exp': 1}),
         ((1, 0.5), {'exp': 0}),
         ((1, 0.5), {'exp': 1}),
         ((1, 2), {'exp': 0}),
         ((1, 2), {'exp': 1}),
         ((2, 0.5), {'exp': 0}),
         ((2, 0.5), {'exp': 1}),
         ((2, 2), {'exp': 0}),
         ((2, 2), {'exp': 1})]

        >>> list(map(my_func, product_args))
        [1.0, 0.0, 1, 0, 1.0, 0.5, 1, 2, 1.0, 1.0, 1, 4]
    """
    num_args = len(args)
    kwarg_keys = list(kwargs.keys())
    kwarg_vals = [kwargs[k] for k in kwarg_keys]

    format_iterations = lambda x: (
        tuple(x[:num_args]),
        dict(zip(kwarg_keys, x[num_args:])),
    )

    return map(format_iterations, itertools.product(*args, *kwarg_vals))


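# Hedged usage sketch (not part of the original module): ``expand`` and
# ``collapse_product`` pair naturally with ``dask.distributed.Client.map``,
# which maps over a single iterable, so each job becomes one (args, kwargs)
# tuple. The client below is an assumption for illustration.
def _example_collapse_product_with_dask():  # hypothetical demo, not in the API
    client = dd.get_client()  # assumes an existing dask.distributed Client

    @expand
    def power(a, b, exp=1):
        return (a * b) ** exp

    futures = client.map(power, list(collapse_product([1, 2], [3, 4], exp=[1, 2])))
    return client.gather(futures)  # [3, 9, 4, 16, 6, 36, 8, 64]

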
class NumpyEncoder(json.JSONEncoder):
    """
    Helper class for json.dumps to coerce numpy objects to native python
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, np.int64):
            return int(obj)
        elif isinstance(obj, np.float64):
            return float(obj)
        elif isinstance(obj, np.int32):
            return int(obj)
        elif isinstance(obj, np.float32):
            return float(obj)
        return json.JSONEncoder.default(self, obj)


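# Hedged usage sketch (not part of the original module): pass ``NumpyEncoder``
# to ``json.dumps`` via the ``cls`` argument so numpy arrays and scalars
# serialize as native python types instead of raising a TypeError.
def _example_numpy_encoder():  # hypothetical demo, not in the API
    payload = {"x": np.arange(3), "y": np.float32(0.5)}
    return json.dumps(payload, cls=NumpyEncoder)  # '{"x": [0, 1, 2], "y": 0.5}'

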
def checkpoint(
    jobs,
    futures,
    job_name,
    log_dir=".",
    extra_pending=None,
    extra_errors=None,
    extra_others=None,
):
    """
    checkpoint and save a job state to disk
    """
    err_msg = "lengths do not match: jobs [{}] != futures [{}]".format(
        len(jobs), len(futures)
    )
    assert len(jobs) == len(futures), err_msg

    if extra_pending is None:
        extra_pending = []

    if extra_errors is None:
        extra_errors = []

    if extra_others is None:
        extra_others = {}

    pending_jobs = [
        jobs[i] for i, f in enumerate(futures) if f.status == "pending"
    ] + extra_pending

    errored_jobs = [
        jobs[i] for i, f in enumerate(futures) if f.status == "error"
    ] + extra_errors

    other_jobs = {}
    for i, f in enumerate(futures):
        if f.status in ["pending", "error", "finished"]:
            continue
        if f.status not in other_jobs:
            other_jobs[f.status] = []
        other_jobs[f.status].append(jobs[i])

    for k, v in extra_others.items():
        if k not in other_jobs:
            other_jobs[k] = []
        other_jobs[k].append(v)

    with open(os.path.join(log_dir, "{}.pending".format(job_name)), "w+") as f:
        f.write(json.dumps(pending_jobs, cls=NumpyEncoder))

    with open(os.path.join(log_dir, "{}.err".format(job_name)), "w+") as f:
        f.write(json.dumps(errored_jobs, cls=NumpyEncoder))

    with open(os.path.join(log_dir, "{}.other".format(job_name)), "w+") as f:
        f.write(json.dumps(other_jobs, cls=NumpyEncoder))


def recover(job_name, log_dir="."):
    """
    recover pending, errored, other jobs from a checkpoint
    """
    with open(os.path.join(log_dir, "{}.pending".format(job_name)), "r") as f:
        content = f.read()
        # checkpoint() writes JSON lists for pending/errored jobs, so an
        # empty file recovers to an empty list
        if len(content) == 0:
            pending = []
        else:
            pending = json.loads(content)

    with open(os.path.join(log_dir, "{}.err".format(job_name)), "r") as f:
        content = f.read()
        if len(content) == 0:
            errored = []
        else:
            errored = json.loads(content)

    with open(os.path.join(log_dir, "{}.other".format(job_name)), "r") as f:
        content = f.read()
        if len(content) == 0:
            other = {}
        else:
            other = json.loads(content)

    return pending, errored, other


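# Hedged usage sketch (not part of the original module): checkpoint a batch of
# dask futures mid-run, then recover the unfinished jobs after a restart. The
# client, job payloads, and job name below are assumptions for illustration.
def _example_checkpoint_recover():  # hypothetical demo, not in the API
    client = dd.get_client()  # assumes an existing dask.distributed Client

    jobs = [{"param": p} for p in range(4)]  # JSON-serializable job specs
    futures = [client.submit(lambda job: job["param"] ** 2, j) for j in jobs]

    # writes my_run.pending / my_run.err / my_run.other in the current dir
    checkpoint(jobs, futures, "my_run")

    # ...after an interruption, resubmit only what didn't finish
    pending, errored, other = recover("my_run")
    return pending + errored

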
class html(object):
    """Wrap an HTML string so Jupyter renders it via ``_repr_html_``"""

    def __init__(self, body):
        self.body = body

    def _repr_html_(self):
        return self.body


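# Hedged usage sketch (not part of the original module): returning an ``html``
# instance as the last expression in a Jupyter cell renders the wrapped markup
# rather than the plain repr.
def _example_html_banner():  # hypothetical demo, not in the API
    return html("<b>done:</b> 12 of 12 jobs finished")

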
_default_allowed_types = (
    types.FunctionType,
    types.ModuleType,
    (type if not hasattr(types, "ClassType") else types.ClassType),
    types.MethodType,
    types.BuiltinMethodType,
    types.BuiltinFunctionType,
)


@toolz.functoolz.curry
def block_globals(obj, allowed_types=None, include_defaults=True, whitelist=None):
    """
    Decorator to prevent globals and undefined closures in functions and classes

    Parameters
    ----------
    obj : function
        Function to decorate. All globals not matching one of the allowed
        types will raise a TypeError
    allowed_types : type or tuple of types, optional
        Types which are allowed as globals. By default, functions and
        modules are allowed. The full set of allowed types is drawn from
        the ``types`` module, and includes :py:class:`~types.FunctionType`,
        :py:class:`~types.ModuleType`, :py:class:`~types.MethodType`,
        :py:class:`~types.ClassType`, :py:class:`~types.BuiltinMethodType`,
        and :py:class:`~types.BuiltinFunctionType`.
    include_defaults : bool, optional
        If allowed_types is provided, setting ``include_defaults`` to True
        will append the default list of functions, modules, and methods to
        the user-passed list of allowed types. Default is True, in which
        case any user-passed elements will be added to the defaults
        described above. Setting to False will allow only the types passed
        in ``allowed_types``.
    whitelist : list of str, optional
        Optional list of variable names to whitelist. If a list is
        provided, global variables will be compared to elements of this
        list based on their string names. Default (None) is no whitelist.

    Examples
    --------

    Wrap a function to block globals:

    .. code-block:: python

        >>> my_data = 10

        >>> @block_globals
        ... def add_5(data):
        ...     ''' can you spot the global? '''
        ...     a_number = 5
        ...     result = a_number + my_data
        ...     return result  # doctest: +ELLIPSIS
        Traceback (most recent call last):
        ...
        TypeError: Illegal <class 'int'> global found in add_5: my_data

    Wrapping a class will prevent globals from being used in all methods:

    .. code-block:: python

        >>> @block_globals
        ... class MyClass:
        ...
        ...     @staticmethod
        ...     def add_5(data):
        ...         ''' can you spot the global? '''
        ...         a_number = 5
        ...         result = a_number + my_data
        ...         return result  # doctest: +ELLIPSIS
        Traceback (most recent call last):
        ...
        TypeError: Illegal <class 'int'> global found in add_5: my_data

    By default, functions and modules are allowed in the list of globals.
    You can modify this list with the ``allowed_types`` argument:

    .. code-block:: python

        >>> result_formatter = 'my number is {}'

        >>> @block_globals(allowed_types=str)
        ... def add_5(data):
        ...     ''' only allowed globals here! '''
        ...     a_number = 5
        ...     result = a_number + data
        ...     return result_formatter.format(result)
        ...

        >>> add_5(3)
        'my number is 8'

    block_globals will also catch undefined references:

    .. code-block:: python

        >>> @block_globals
        ... def get_mean(df):
        ...     return da.mean()  # doctest: +ELLIPSIS
        Traceback (most recent call last):
        ...
        TypeError: Undefined global in get_mean: da
    """
    if allowed_types is None:
        allowed_types = _default_allowed_types

    if (allowed_types is not None) and include_defaults:
        if not isinstance(allowed_types, collections.abc.Sequence):
            allowed_types = [allowed_types]
        allowed_types = tuple(list(allowed_types) + list(_default_allowed_types))

    if whitelist is None:
        whitelist = []

    # classes: recursively wrap all callable attributes
    if isinstance(obj, type):
        for attr in obj.__dict__:
            if callable(getattr(obj, attr)):
                setattr(obj, attr, block_globals(getattr(obj, attr)))
        return obj

    closurevars = inspect.getclosurevars(obj)
    for instr in dis.get_instructions(obj):
        if instr.opname == "LOAD_GLOBAL":
            if instr.argval in closurevars.builtins:
                continue
            elif (instr.argval in closurevars.globals) or (
                instr.argval in closurevars.nonlocals
            ):
                if instr.argval in whitelist:
                    continue

                if instr.argval in closurevars.globals:
                    g = closurevars.globals[instr.argval]
                else:
                    g = closurevars.nonlocals[instr.argval]

                if not isinstance(g, allowed_types):
                    raise TypeError(
                        "Illegal {} global found in {}: {}".format(
                            type(g), obj.__name__, instr.argval
                        )
                    )
            else:
                raise TypeError(
                    "Undefined global in {}: {}".format(obj.__name__, instr.argval)
                )

    @functools.wraps(obj)
    def inner(*args, **kwargs):
        return obj(*args, **kwargs)

    return inner


@toolz.functoolz.curry
def retry_with_timeout(func, retry_freq=10, n_tries=1, use_dask=True):
    """Execute ``func`` ``n_tries`` times, each time only allowing
    ``retry_freq`` seconds for the function to complete.

    There are two main cases where this could be useful:

    1. You have a function that you know should execute quickly, but you may
       get occasional errors when running it simultaneously on a large number
       of workers. An example of this is massively parallelized I/O operations
       of netcdfs on GCS.
    2. You have a function that may or may not take a long time, but you want
       to skip it if it takes too long.

    There are two possible ways that this timeout function is implemented,
    each with pros and cons:

    1. Using python's native ``threading`` module. If you are executing
       ``func`` outside of a ``dask`` worker, you likely will want this
       approach. It may be slightly faster and has the benefit of starting the
       timeout clock when the function starts executing (rather than when the
       function is *submitted* to a dask scheduler). **Note**: This approach
       will also work if calling ``func`` *from* a dask worker, but only if
       the cluster was set up such that ``threads_per_worker=1``. Otherwise,
       this may cause issues if used from a dask worker.
    2. Using ``dask``. If you would like a dask worker to execute this
       function, you likely will want this approach. It can be executed from a
       dask worker regardless of the number of threads per worker (see above),
       but has the downside that the timeout clock begins once ``func`` is
       submitted, rather than when it begins executing.

    Parameters
    ----------
    func : callable
        The function you would like to execute with a timeout backoff.
    retry_freq : float
        The number of seconds allowed for each try of ``func`` before it is
        timed out and retried.
    n_tries : int
        The number of tries to attempt before raising an error if none were
        successful
    use_dask : bool
        If true, will try to use the ``dask``-based implementation (see
        description above). If no ``Client`` instance is present, will fall
        back to ``use_dask=False``.

    Returns
    -------
    The return value of ``func``

    Raises
    ------
    dask.distributed.TimeoutError :
        If the function does not execute successfully in the specified
        ``retry_freq``, after trying ``n_tries`` times.
    ValueError :
        If ``use_dask=True``, and a ``Client`` instance is present, but this
        function is executed from the client (rather than as a task submitted
        to a worker), you will get ``ValueError("No workers found")``.

    Examples
    --------

    .. code-block:: python

        >>> import time

        >>> @retry_with_timeout(retry_freq=.5, n_tries=1)
        ... def wait_func(timeout):
        ...     time.sleep(timeout)

        >>> wait_func(.1)

        >>> wait_func(1)
        Traceback (most recent call last):
        ...
        asyncio.exceptions.TimeoutError: Func did not complete successfully in allowed time/number of retries.
    """
    # if use_dask is specified, check for an active client; otherwise fall
    # back to the threading-based implementation
    if use_dask:
        try:
            dd.get_client()
        except ValueError:
            use_dask = False

    @functools.wraps(func)
    def inner(*args, **kwargs):
        if use_dask:
            # dask version
            with dd.worker_client() as client:
                for try_n in range(n_tries):
                    fut = client.submit(func, *args, **kwargs)
                    try:
                        return fut.result(timeout=retry_freq)
                    except dd.TimeoutError:
                        ...
        else:
            # non-dask version
            def this_func(q):
                args = q.get_nowait()
                kwargs = q.get_nowait()
                out = func(*args, **kwargs)
                q.put(out)

            for try_n in range(n_tries):
                q = queue.Queue()
                p = threading.Thread(target=this_func, args=(q,))
                q.put_nowait(args)
                q.put_nowait(kwargs)
                p.start()
                p.join(timeout=retry_freq)
                if p.is_alive():
                    # timed out: drop this attempt and retry
                    del p, q
                    continue
                elif q.qsize() == 0:
                    # the worker thread exited without putting a result, so
                    # ``func`` itself must have raised
                    raise RuntimeError(
                        "Queue is empty. Something malfunctioned in ``func``"
                    )
                return q.get()

        raise dd.TimeoutError(
            "Func did not complete successfully in allowed time/number of retries."
        )

    return inner


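# Hedged usage sketch (not part of the original module): the dask code path is
# exercised by submitting the *wrapped* function to workers, so the timeout is
# enforced on the worker via ``worker_client``. The flaky reader below is a
# hypothetical stand-in for, e.g., an occasionally-hanging netcdf read on GCS.
def _example_retry_on_workers():  # hypothetical demo, not in the API
    import time

    @retry_with_timeout(retry_freq=5, n_tries=3)
    def flaky_read(path):
        time.sleep(1)  # stand-in for occasionally-hanging I/O
        return path

    client = dd.get_client()  # assumes an existing dask.distributed Client
    futures = client.map(flaky_read, ["a.nc", "b.nc"])
    return client.gather(futures)

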
def get_repo_state(repository_root: Optional[str] = None) -> dict:
    """
    Get a dictionary summarizing the current state of a repository

    Parameters
    ----------
    repository_root : str or None
        Path to the root of the repository to document. If ``None``
        (default), the current directory will be used, and parent directories
        will be searched for a git repository. If a string is passed, parent
        directories will not be searched - the directory must be a repository
        root which contains a ``.git`` directory.

    Returns
    -------
    repo_state : dict
        Dictionary of repository information documenting the current state
    """
    if repository_root is None:
        repo = git.Repo(".", search_parent_directories=True)
    else:
        repo = git.Repo(str(repository_root))

    c = repo.commit()

    state = {}
    state["repo_last_commit_hexsha"] = str(c.hexsha)
    state["repo_last_commit_summary"] = str(c.summary)
    state["repo_last_commit_author_name"] = str(c.author.name)
    state["repo_last_commit_author_email"] = str(c.author.email)
    state["repo_last_commit_timestamp"] = str(c.authored_datetime.strftime("%c (%z)"))
    state["repo_remote_url"] = str(repo.remote("origin").url)

    try:
        # this fails on shallow clones
        state["repo_active_branch"] = str(repo.active_branch.name)
    except Exception:
        pass

    return state
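

# Hedged usage sketch (not part of the original module): record the repo state
# alongside a run's outputs so results can be traced back to a commit. The
# output filename is a hypothetical choice for illustration.
def _example_log_repo_state():  # hypothetical demo, not in the API
    state = get_repo_state()  # searches upward from the current directory
    with open("run_provenance.json", "w") as f:
        json.dump(state, f, indent=2)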