diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 412696d..f88ef75 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -1,11 +1,17 @@ version: 2 +build: + os: ubuntu-22.04 + tools: + python: "3.11" + apt_packages: + - pandoc + python: - version: "3.8" install: - requirements: docs/requirements.txt - method: pip path: . - extra_requirements: - - docs - system_packages: true + +sphinx: + configuration: docs/conf.py diff --git a/docs/requirements.txt b/docs/requirements.txt index 10e010e..e7f2a3e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,7 +1,4 @@ -Babel==2.9.1 -imagesize==1.3.0 -readme-renderer>=32.0 -sphinx==4.2.0 -sphinx_rtd_theme==1.0.0 -readthedocs-sphinx-search>=0.3.2 -nbsphinx==0.8.8 +sphinx>=5.0,<8 +sphinx_rtd_theme>=1.0 +nbsphinx>=0.9 +readthedocs-sphinx-search>=0.3 diff --git a/setup.py b/setup.py index 84c3a62..2523750 100644 --- a/setup.py +++ b/setup.py @@ -12,8 +12,12 @@ from setuptools import find_packages, setup, Command import platform -from distutils.sysconfig import get_config_var -from distutils.version import LooseVersion +try: + from sysconfig import get_config_var +except ImportError: + from distutils.sysconfig import get_config_var + +from packaging.version import Version # Package meta-data. 
@@ -48,8 +52,9 @@ if sys.platform == "darwin": if "MACOSX_DEPLOYMENT_TARGET" not in os.environ: - current_system = LooseVersion(platform.mac_ver()[0]) - python_target = LooseVersion(get_config_var("MACOSX_DEPLOYMENT_TARGET")) - if python_target < "10.9" and current_system >= "10.9": + current_system = Version(platform.mac_ver()[0]) + python_target = Version(get_config_var("MACOSX_DEPLOYMENT_TARGET")) + # packaging's Version only orders against other Version objects; comparing with a str raises TypeError. + if python_target < Version("10.9") and current_system >= Version("10.9"): os.environ["MACOSX_DEPLOYMENT_TARGET"] = "10.9" diff --git a/skexplain/common/multiprocessing_utils.py b/skexplain/common/multiprocessing_utils.py index fa3540e..cd238f0 100644 --- a/skexplain/common/multiprocessing_utils.py +++ b/skexplain/common/multiprocessing_utils.py @@ -1,20 +1,24 @@ -import multiprocessing as mp -import itertools -from multiprocessing.pool import Pool -from datetime import datetime +"""Parallelization utilities for scikit-explain. -from tqdm import tqdm +Uses joblib.Parallel as the single backend for all parallel computation. +Provides tqdm progress bars and structured logging for failures. 
+""" -# from tqdm.notebook import tqdm +import multiprocessing as mp +import itertools +import logging +import time import traceback -from collections import ChainMap import warnings +import contextlib from copy import copy +from tqdm import tqdm from joblib import delayed, Parallel import joblib -import time -import contextlib + + +logger = logging.getLogger("skexplain") # Ignore the warning for joblib to set njobs=1 for # models like RandomForest @@ -23,7 +27,7 @@ @contextlib.contextmanager def tqdm_joblib(tqdm_object): - """Context manager to patch joblib to report into tqdm progress bar given as argument""" + """Context manager to patch joblib to report into tqdm progress bar given as argument.""" class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): def __call__(self, *args, **kwargs): @@ -39,75 +43,51 @@ def __call__(self, *args, **kwargs): tqdm_object.close() -def text_progessbar(seq, total=None): - step = 1 - tick = time.time() - while True: - time_diff = time.time() - tick - avg_speed = time_diff / step - total_str = f"of {total if total else ''}" - print( - "step", - step, - "%.2f" % time_diff, - "avg: %.2f iter/sec" % avg_speed, - total_str, - ) - step += 1 - yield next(seq) - - -all_bar_funcs = { - "tqdm": lambda args: lambda x: tqdm(x, **args), - "txt": lambda args: lambda x: text_progessbar(x, **args), - "False": lambda args: iter, - "None": lambda args: iter, -} - - -def ParallelExecutor(use_bar="tqdm", joblib_args={}, tqdm_args={}): - def aprun(bar=use_bar, **tqdm_args): - def tmp(op_iter): - if str(bar) in all_bar_funcs.keys(): - bar_func = all_bar_funcs[str(bar)](tqdm_args) - else: - raise ValueError("Value %s not supported as bar type" % bar) - return Parallel(**joblib_args)(bar_func(op_iter)) - - return tmp - - return aprun - +def to_iterator(*lists): + """Create a Cartesian product iterator from multiple lists.""" + return itertools.product(*lists) -class LogExceptions(object): - def __init__(self, func): - self.func = 
func - def error(self, msg, *args, **kwargs): - """Shortcut to multiprocessing's logger""" - return mp.get_logger().error(msg, *args, **kwargs) +def _resolve_n_jobs(n_jobs): + """Resolve n_jobs to a concrete positive integer. - def __call__(self, *args, **kwargs): - try: - result = self.func(*args, **kwargs) + Follows the sklearn convention: + - n_jobs=1: serial execution + - n_jobs=-1: use all CPUs + - n_jobs=-2: use all CPUs except one + - 0 < n_jobs < 1: fraction of available CPUs + - n_jobs > 1: literal number of CPUs - except Exception as e: - # Here we add some debugging help. If multiprocessing's - # debugging is on, it will arrange to log the traceback - self.error(traceback.format_exc()) - # Re-raise the original exception so the Pool worker can - # clean up - raise + Parameters + ---------- + n_jobs : int or float + Number of jobs specification. - # It was fine, give a normal answer - return result + Returns + ------- + int + Resolved number of jobs (>= 1). + """ + cpu_count = mp.cpu_count() + if n_jobs == -1: + return cpu_count + elif n_jobs < -1: + return max(1, cpu_count + 1 + n_jobs) + elif 0 < n_jobs < 1: + return max(1, int(n_jobs * cpu_count)) + else: + n_jobs = int(n_jobs) -def to_iterator(*lists): - """ - turn list - """ - return itertools.product(*lists) + if n_jobs < 1: + return 1 + if n_jobs > cpu_count: + logger.info( + "Requested %d jobs but only %d CPUs available. Using %d.", + n_jobs, cpu_count, cpu_count, + ) + return cpu_count + return n_jobs def run_parallel( @@ -115,79 +95,134 @@ def run_parallel( args_iterator, n_jobs, description=None, - kwargs={}, + kwargs=None, nprocs_to_use=None, total=None, ): + """Run a function over an iterator of arguments, optionally in parallel. + + Uses joblib.Parallel with the 'loky' backend for fork-safe parallelism. + Displays a tqdm progress bar during execution. + + Parameters + ---------- + func : callable + The function to execute. 
Called as ``func(*args, **kwargs)`` + for each item in ``args_iterator``. + args_iterator : iterable + Each element is a tuple of positional arguments for ``func``. + If an element is a string, it is wrapped in a tuple. + n_jobs : int or float + Number of parallel jobs. See ``_resolve_n_jobs`` for conventions. + n_jobs=1 runs in serial. + description : str, optional + Label for the tqdm progress bar. + kwargs : dict, optional + Keyword arguments passed to every call of ``func``. + nprocs_to_use : int, optional + Deprecated. Use ``n_jobs`` instead. + total : int, optional + Ignored (computed from args_iterator). + + Returns + ------- + list + Results from each call to ``func``, in order. """ - Runs a series of python scripts in parallel. Scripts uses the tqdm to create a - progress bar. If n_jobs == 1, then process is run in serial. - Args: - ------------------------- - func : callable - python function, the function to be parallelized; can be a function which issues a series of python scripts - args_iterator : iterable, list, - python iterator, the arguments of func to be iterated over - it can be the iterator itself or a series of list - n_jobs : int or float, - if int, taken as the literal number of processors to use - if float (between 0 and 1), taken as the percentage of available processors to use - kwargs : dict - keyword arguments to be passed to the func - """ + if kwargs is None: + kwargs = {} + if nprocs_to_use is not None: - warnings.warn("nprocs_to_use will deprecated and replaced by n_jobs.", DeprecationWarning) + warnings.warn( + "nprocs_to_use is deprecated; use n_jobs instead.", + DeprecationWarning, + stacklevel=2, + ) n_jobs = nprocs_to_use - iter_copy = copy(args_iterator) + # Materialize the iterator to get total count + args_list = list(args_iterator) + total = len(args_list) + n_jobs = _resolve_n_jobs(n_jobs) - total = len(list(iter_copy)) - pbar = tqdm(total=total, desc=description) - results = [] + is_parallel = n_jobs != 1 - def update(*a): 
- # This is called whenever a process returns a result. - # results is modified only by the main process, not by the pool workers. - pbar.update() + logger.debug( + "run_parallel: %s (%d tasks, n_jobs=%d, parallel=%s)", + description or "unnamed", total, n_jobs, is_parallel, + ) - if n_jobs == -1: - # Use all available CPUs (sklearn convention) - n_jobs = mp.cpu_count() - elif n_jobs < -1: - # Use (cpu_count + 1 + n_jobs) CPUs, e.g. -2 means all but one - n_jobs = max(1, mp.cpu_count() + 1 + n_jobs) - elif 0 < n_jobs < 1: - n_jobs = max(1, int(n_jobs * mp.cpu_count())) + start_time = time.perf_counter() + + if is_parallel: + with tqdm_joblib(tqdm(total=total, desc=description)): + results = Parallel(n_jobs=n_jobs, backend="loky")( + delayed(_safe_call)(func, _ensure_tuple(args), kwargs) + for args in args_list + ) else: - n_jobs = int(n_jobs) + results = [] + pbar = tqdm(total=total, desc=description) + for args in args_list: + results.append(_safe_call(func, _ensure_tuple(args), kwargs)) + pbar.update() + pbar.close() + + elapsed = time.perf_counter() - start_time + logger.info( + "run_parallel: %s completed in %.2fs (%d tasks, n_jobs=%d)", + description or "unnamed", elapsed, total, n_jobs, + ) - if n_jobs < 1: - n_jobs = 1 + return results - if n_jobs > mp.cpu_count(): - n_jobs = mp.cpu_count() - is_parallel = True if n_jobs != 1 else False +def _ensure_tuple(args): + """Wrap a single string arg in a tuple.""" + if isinstance(args, str): + return (args,) + return args - if is_parallel: - pool = Pool(processes=n_jobs) - ps = [] - results = [] - for args in args_iterator: - if isinstance(args, str): - args = (args,) +def _safe_call(func, args, kwargs): + """Call func with logging on failure.""" + try: + return func(*args, **kwargs) + except Exception: + logger.error( + "Parallel task failed:\n func: %s\n args: %s\n%s", + func.__name__ if hasattr(func, '__name__') else str(func), + str(args)[:200], + traceback.format_exc(), + ) + raise - if is_parallel: - p = 
pool.apply_async(LogExceptions(func), args=args, kwds=kwargs, callback=update) - ps.append(p) - else: - results.append(LogExceptions(func)(*args, **kwargs)) - update() - if is_parallel: - pool.close() - pool.join() - results = [p.get() for p in ps] +# Keep backward-compatible imports +def ParallelExecutor(use_bar="tqdm", joblib_args=None, tqdm_args=None): + """Create a parallel executor with a progress bar. - return results + .. deprecated:: + Use ``run_parallel`` instead. + """ + if joblib_args is None: + joblib_args = {} + if tqdm_args is None: + tqdm_args = {} + + all_bar_funcs = { + "tqdm": lambda args: lambda x: tqdm(x, **args), + "False": lambda args: iter, + "None": lambda args: iter, + } + + def aprun(bar=use_bar, **tqdm_args): + def tmp(op_iter): + if str(bar) in all_bar_funcs.keys(): + bar_func = all_bar_funcs[str(bar)](tqdm_args) + else: + raise ValueError("Value %s not supported as bar type" % bar) + return Parallel(**joblib_args)(bar_func(op_iter)) + return tmp + + return aprun diff --git a/skexplain/main/_attribution_mixin.py b/skexplain/main/_attribution_mixin.py index 03304c2..de2fa0b 100644 --- a/skexplain/main/_attribution_mixin.py +++ b/skexplain/main/_attribution_mixin.py @@ -2,6 +2,7 @@ import warnings from ..common.utils import to_xarray, is_str, is_list, is_dataset +from ._validation import track_timing class AttributionMixin: @@ -26,6 +27,7 @@ def local_contributions( lime_kws=lime_kws, ) + @track_timing def local_attributions(self, method, shap_kws=None, lime_kws=None, n_jobs=1): """ Compute the SHapley Additive Explanations (SHAP) values [13]_ [14]_ [15]_, diff --git a/skexplain/main/_curves_mixin.py b/skexplain/main/_curves_mixin.py index 6fe8963..711ee60 100644 --- a/skexplain/main/_curves_mixin.py +++ b/skexplain/main/_curves_mixin.py @@ -1,10 +1,11 @@ from ..common.utils import to_xarray, check_all_features_for_ale -from ._validation import normalize_features, normalize_estimator_names +from ._validation import normalize_features, 
normalize_estimator_names, track_timing class CurvesMixin: """Mixin providing ICE / PD / ALE curve methods and main-effect complexity.""" + @track_timing def ice( self, features, @@ -91,6 +92,7 @@ def ice( return results_ds + @track_timing def pd( self, features, @@ -173,6 +175,7 @@ def pd( return results_ds + @track_timing def ale( self, features=None, @@ -271,6 +274,7 @@ def ale( return results_ds + @track_timing def main_effect_complexity(self, ale, estimator_names=None, max_segments=10, approx_error=0.05): """ Compute the Main Effect Complexity (MEC; Molnar et al. 2019) [5]_. diff --git a/skexplain/main/_importance_mixin.py b/skexplain/main/_importance_mixin.py index e91ac17..bf8f17a 100644 --- a/skexplain/main/_importance_mixin.py +++ b/skexplain/main/_importance_mixin.py @@ -3,12 +3,13 @@ from ..common.utils import is_str, to_xarray, check_all_features_for_ale from ..common.importance_utils import retrieve_important_vars, combine_top_features, compute_importance -from ._validation import normalize_features, normalize_estimator_names +from ._validation import normalize_features, normalize_estimator_names, track_timing class ImportanceMixin: """Mixin providing feature-importance methods for ExplainToolkit.""" + @track_timing def permutation_importance( self, n_vars, @@ -188,6 +189,7 @@ def permutation_importance( return results_ds + @track_timing def grouped_permutation_importance( self, perm_method, @@ -336,6 +338,7 @@ def grouped_permutation_importance( return results_ds + @track_timing def ale_variance( self, ale, diff --git a/skexplain/main/_interaction_mixin.py b/skexplain/main/_interaction_mixin.py index b52d158..e6b69b2 100644 --- a/skexplain/main/_interaction_mixin.py +++ b/skexplain/main/_interaction_mixin.py @@ -2,12 +2,13 @@ import xarray as xr from ..common.utils import check_all_features_for_ale -from ._validation import normalize_estimator_names +from ._validation import normalize_estimator_names, track_timing class InteractionMixin: """Mixin 
providing feature-interaction methods for ExplainToolkit.""" + @track_timing def perm_based_interaction( self, features, @@ -117,6 +118,7 @@ def perm_based_interaction( return results_ds + @track_timing def friedman_h_stat( self, dataset_1d=None, dataset_2d=None, features=None, estimator_names=None, **kwargs ): @@ -217,6 +219,7 @@ def friedman_h_stat( return results_ds + @track_timing def interaction_strength(self, ale, estimator_names=None, **kwargs): """ Compute the InterAction Strength (IAS) statistic from Molnar et al. (2019) [5]_. @@ -289,6 +292,7 @@ def interaction_strength(self, ale, estimator_names=None, **kwargs): return results_ds + @track_timing def sobol_indices(self, n_bootstrap=5000, class_index=1): """ Compute the 1st Order and Total order Sobol Indices. Useful for diagnosing feature diff --git a/skexplain/main/_validation.py b/skexplain/main/_validation.py index 8c29624..cc6a14a 100644 --- a/skexplain/main/_validation.py +++ b/skexplain/main/_validation.py @@ -1,8 +1,13 @@ """Shared validation and normalization helpers for ExplainToolkit methods.""" import itertools +import time +import functools +import logging from ..common.utils import is_str, is_list +logger = logging.getLogger("skexplain") + def normalize_features(features, all_features, allow_2d=False): """Normalize a features argument to a list. @@ -55,3 +60,27 @@ def normalize_estimator_names(names, default_names): if is_str(names): return [names] return list(names) + + +def track_timing(method): + """Decorator that records computation time in the returned dataset's attrs. + + Sets ``computation_time_seconds`` in ``self.attrs_dict`` after the wrapped + method returns, and on the result's ``attrs`` if present. Also logs the elapsed time. + + Only works on methods whose ``self`` has an ``attrs_dict`` attribute. 
+ """ + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + start = time.perf_counter() + result = method(self, *args, **kwargs) + elapsed = time.perf_counter() - start + self.attrs_dict["computation_time_seconds"] = round(elapsed, 3) + # Update attrs on the returned result if it has attrs (Dataset/DataFrame) + if hasattr(result, "attrs"): + result.attrs["computation_time_seconds"] = round(elapsed, 3) + logger.info( + "%s completed in %.2fs", method.__name__, elapsed, + ) + return result + return wrapper