From ba8050d2725ca0e278a783a9649be57bbf7cc820 Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 18 May 2026 12:22:45 +0100 Subject: [PATCH 1/4] triage with cache and restart --- .github/triage/jax_toolbox_triage/args.py | 61 ++++- .github/triage/jax_toolbox_triage/logic.py | 59 +++-- .github/triage/jax_toolbox_triage/main.py | 4 +- .github/triage/jax_toolbox_triage/summary.py | 122 +++++++++- .../triage/jax_toolbox_triage/triage_tool.py | 56 ++++- .github/triage/jax_toolbox_triage/utils.py | 12 +- .github/triage/tests/test_restart.py | 220 ++++++++++++++++++ 7 files changed, 485 insertions(+), 49 deletions(-) create mode 100644 .github/triage/tests/test_restart.py diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py index 729e88167..38b836c49 100644 --- a/.github/triage/jax_toolbox_triage/args.py +++ b/.github/triage/jax_toolbox_triage/args.py @@ -3,6 +3,7 @@ import getpass import os import pathlib +import sys import tempfile import typing import warnings @@ -53,6 +54,11 @@ def parse_override_remotes(s: str) -> typing.Dict[str, str]: def parse_args(args=None) -> argparse.Namespace: + raw_args = sys.argv[1:] if args is None else args + output_prefix_supplied = any( + arg == "--output-prefix" or arg.startswith("--output-prefix=") + for arg in raw_args + ) parser = argparse.ArgumentParser( description=""" Triage failures in JAX/XLA-related tests. The expectation is that the given @@ -100,7 +106,7 @@ def parse_args(args=None) -> argparse.Namespace: ) parser.add_argument( "--output-prefix", - default=datetime.datetime.now().strftime("triage-%Y-%m-%d-%H-%M-%S"), + default=None, help=""" Prefix for output log and JSON files. Default: triage-YYYY-MM-DD-HH-MM-SS. An INFO-and-above log is written as PREFIX.log, a DEBUG-and-above log is @@ -108,6 +114,23 @@ def parse_args(args=None) -> argparse.Namespace: PREFIX-summary.json""", type=pathlib.Path, ) + parser.add_argument( + "--restart", + action="store_true", + help=""" + Restart a previous triage run by loading completed records from + --restart-folder/summary.json. The summary file is the source of truth; + incomplete output directories without summary records are ignored. + """, + ) + parser.add_argument( + "--restart-folder", + type=pathlib.Path, + help=""" + Output folder from a previous triage run. Must contain summary.json and is + used as --output-prefix during --restart. + """, + ) parser.add_argument( "--skip-precondition-checks", action="store_true", @@ -308,6 +331,32 @@ def parse_args(args=None) -> argparse.Namespace: help="The name of the main branch (e.g. main) to derive cherry-picks from", ) args = parser.parse_args(args=args) + if args.restart: + if args.restart_folder is None: + raise Exception("--restart requires --restart-folder") + args.restart_folder = args.restart_folder.resolve() + summary_file = args.restart_folder / "summary.json" + if not summary_file.exists(): + raise Exception( + f"--restart-folder must contain summary.json: {summary_file}" + ) + if ( + output_prefix_supplied + and args.output_prefix is not None + and args.output_prefix.resolve() != args.restart_folder + ): + raise Exception( + "--output-prefix must match --restart-folder when restarting" + ) + args.output_prefix = args.restart_folder + else: + if args.restart_folder is not None: + raise Exception("--restart-folder requires --restart") + if args.output_prefix is None: + args.output_prefix = pathlib.Path( + datetime.datetime.now().strftime("triage-%Y-%m-%d-%H-%M-%S") + ) + assert args.container_runtime in { "docker", "pyxis", @@ -348,9 +397,7 @@ def parse_args(args=None) -> argparse.Namespace: if args.container_runtime == "local": assert ( args.passing_versions is not None and args.failing_versions is not None - ), ( - "For local runtime, --passing-versions and --failing-versions must be provided." - ) + ), "For local runtime, --passing-versions and --failing-versions must be provided." assert ( args.container is None and args.start_date is None @@ -395,9 +442,9 @@ def parse_args(args=None) -> argparse.Namespace: else: # None of --{passing,failing}-{versions,container} were passed, make sure the # compulsory arguments for the container-level search were passed - assert args.container is not None, ( - "--container must be passed for the container-level search" - ) + assert ( + args.container is not None + ), "--container must be passed for the container-level search" args.optional_software = optional_software.copy() if args.exclude_transformer_engine: diff --git a/.github/triage/jax_toolbox_triage/logic.py b/.github/triage/jax_toolbox_triage/logic.py index c63c43b6e..e6fe92b50 100644 --- a/.github/triage/jax_toolbox_triage/logic.py +++ b/.github/triage/jax_toolbox_triage/logic.py @@ -321,6 +321,17 @@ def remove_build_failures(ver): _REPETITION_KEY = "#rep" +def version_cache_key( + versions: typing.Dict[str, str], + *, + repetition: int = 0, +) -> FlatVersionDict: + return tuple( + sorted(_strip_build_failures(versions).items()) + + [(_REPETITION_KEY, str(repetition))] + ) + + def _version_search( *, versions: typing.OrderedDict[ @@ -330,6 +341,7 @@ def _version_search( logger: logging.Logger, skip_precondition_checks: bool, result_cache: typing.Dict[FlatVersionDict, TestResult], + preloaded_cache_keys: typing.Set[FlatVersionDict], confirmation_iterations: int, check_success_before_failure: bool = True, ) -> typing.Tuple[typing.Dict[str, str], TestResult, typing.Optional[TestResult]]: @@ -342,16 +354,6 @@ def _version_search( versions, ) - def _cache_key( - versions: typing.Dict[str, str], - *, - repetition: int = 0, - ) -> FlatVersionDict: - return tuple( - sorted(_strip_build_failures(versions).items()) - + [(_REPETITION_KEY, str(repetition))] - ) - def build_cached( bisect_versions, *, @@ -359,11 +361,12 @@ def build_cached( test_output_log_level: int = logging.DEBUG, repetition: int = 0, ): - cache_key = _cache_key(bisect_versions, repetition=repetition) + cache_key = version_cache_key(bisect_versions, repetition=repetition) bisect_result = result_cache.get(cache_key) if bisect_result is not None: - if assert_miss: + if assert_miss and cache_key not in preloaded_cache_keys: raise Exception("Unexpected cache hit!") + logger.info(f"Reusing cached result for {dict(cache_key)}") return bisect_result bisect_result = build_and_test( versions=_strip_build_failures(bisect_versions), @@ -477,20 +480,20 @@ def find_successful_build(versions): # succeeds so we can do the same. The somewhat arbitrary logic here is to take # the shorter y??????n run and refine it in the hope one of the ? is a y. assert ( - result_cache.get(_cache_key(_earliest_versions(versions))).result + result_cache.get(version_cache_key(_earliest_versions(versions))).result == TestExecutionOutcome.TEST_SUCCESS ) build_statuses = [(True, {p: 0 for p in versions})] for n in range(1, len(versions[primary]) - 1): versions_n, indices_n = get_versions(primary_index=n, versions=versions) - result_n = result_cache.get(_cache_key(versions_n)) + result_n = result_cache.get(version_cache_key(versions_n)) assert ( result_n is None or result_n.result == TestExecutionOutcome.BUILD_FAILURE ) build_statuses.append((None if result_n is None else False, indices_n)) assert ( - result_cache.get(_cache_key(_latest_versions(versions))).result + result_cache.get(version_cache_key(_latest_versions(versions))).result == TestExecutionOutcome.TEST_FAILURE ) build_statuses.append((True, {p: len(vs) - 1 for p, vs in versions.items()})) @@ -565,9 +568,9 @@ def _index(pkg, ver): for package, index in indices.items(): versions[package] = versions[package][: index + 1] else: - assert bisect_result.result == TestExecutionOutcome.BUILD_FAILURE, ( - bisect_result - ) + assert ( + bisect_result.result == TestExecutionOutcome.BUILD_FAILURE + ), bisect_result # Did not succeed in finding a version of `primary` that builds. This does # not quite mean that all versions fail, as the algorithm will not try all # versions in ranges with failures at both ends @@ -586,11 +589,11 @@ def _index(pkg, ver): for n in range(1, n_primary - 1): versions_n, _ = get_versions(primary_index=n, versions=versions) # Should have been a build failure if tested. - result_n = result_cache.get(_cache_key(versions_n)) + result_n = result_cache.get(version_cache_key(versions_n)) if result_n is not None: - assert result_n.result == TestExecutionOutcome.BUILD_FAILURE, ( - result_n - ) + assert ( + result_n.result == TestExecutionOutcome.BUILD_FAILURE + ), result_n build_fail_commits.append(versions_n[primary]) logger.warning( f"Could not triage {primary} to a single version due to build " @@ -648,7 +651,7 @@ def _index(pkg, ver): # `blame` represents the last-known-good test result, first-known-bad was seen # earlier, or possibly not at all e.g. if `skip_precondition_checks` is True # and first-known-bad was the end of the search range. - first_known_bad_result = result_cache.get(_cache_key(first_known_bad)) + first_known_bad_result = result_cache.get(version_cache_key(first_known_bad)) if first_known_bad_result is None: if skip_precondition_checks: logger.info( @@ -687,6 +690,7 @@ def _index(pkg, ver): logger=logger, skip_precondition_checks=True, result_cache=result_cache, + preloaded_cache_keys=preloaded_cache_keys, confirmation_iterations=confirmation_iterations, ) @@ -701,6 +705,8 @@ def version_search( skip_precondition_checks: bool, check_success_before_failure: bool = True, confirmation_iterations: int = 1, + result_cache: typing.Optional[typing.Dict[FlatVersionDict, TestResult]] = None, + preloaded_cache_keys: typing.Optional[typing.Set[FlatVersionDict]] = None, ) -> typing.Tuple[ typing.Dict[str, str], TestResult, @@ -728,6 +734,10 @@ def version_search( False, but the other fields can be used to obtain stdout+stderr and output files from those test invocations. """ + if result_cache is None: + result_cache = {} + if preloaded_cache_keys is None: + preloaded_cache_keys = set(result_cache.keys()) return _version_search( versions=versions, build_and_test=build_and_test, @@ -735,5 +745,6 @@ def version_search( skip_precondition_checks=skip_precondition_checks, check_success_before_failure=check_success_before_failure, confirmation_iterations=confirmation_iterations, - result_cache={}, + result_cache=result_cache, + preloaded_cache_keys=preloaded_cache_keys, ) diff --git a/.github/triage/jax_toolbox_triage/main.py b/.github/triage/jax_toolbox_triage/main.py index 0217b70f8..0a82b8a97 100644 --- a/.github/triage/jax_toolbox_triage/main.py +++ b/.github/triage/jax_toolbox_triage/main.py @@ -8,7 +8,9 @@ def main() -> None: Main entry point for the triage tool. """ args = parse_args() - logger = get_logger(args.output_prefix) + logger = get_logger(args.output_prefix, append=args.restart) + if args.restart: + logger.info(f"Restarting from {args.restart_folder}") tool = TriageTool(args, logger) passing_url, failing_url = tool.find_container_range() passing_versions, failing_versions = tool.gather_version_info( diff --git a/.github/triage/jax_toolbox_triage/summary.py b/.github/triage/jax_toolbox_triage/summary.py index 8545d1447..5a02ea0d7 100644 --- a/.github/triage/jax_toolbox_triage/summary.py +++ b/.github/triage/jax_toolbox_triage/summary.py @@ -2,7 +2,12 @@ import logging import pathlib import typing -from .logic import TestResult +from .logic import ( + FlatVersionDict, + TestExecutionOutcome, + TestResult, + version_cache_key, +) def add_summary_record( @@ -46,6 +51,105 @@ def add_summary_record( return data +def load_summary(output_prefix: pathlib.Path) -> typing.Dict[str, typing.Any]: + """ + Load the JSON summary for a previous or current triage run. + """ + with open(output_prefix / "summary.json", "r") as ifile: + return json.load(ifile) + + +def _parse_result(value) -> TestExecutionOutcome: + if isinstance(value, TestExecutionOutcome): + return value + if isinstance(value, bool): + return ( + TestExecutionOutcome.TEST_SUCCESS + if value + else TestExecutionOutcome.TEST_FAILURE + ) + if isinstance(value, str): + name = value.rsplit(".", 1)[-1] + return TestExecutionOutcome[name] + raise ValueError(f"Cannot parse test result from {value!r}") + + +def _record_output_directory( + output_prefix: pathlib.Path, record: typing.Dict[str, typing.Any] +) -> pathlib.Path: + out_dir = pathlib.Path(record["output_directory"]) + if out_dir.exists() or not out_dir.is_absolute(): + return out_dir + copied_dir = output_prefix / out_dir.name + if copied_dir.exists(): + return copied_dir.resolve() + return out_dir + + +def version_result_cache_from_summary( + output_prefix: pathlib.Path, + packages: typing.Iterable[str], + summary: typing.Optional[typing.Dict[str, typing.Any]] = None, +) -> typing.Dict[FlatVersionDict, TestResult]: + """ + Reconstruct completed version-level build/test results from summary.json. + + The summary file is treated as the transaction log. Output directories that exist + without a corresponding summary record are ignored by construction. + """ + if summary is None: + summary = load_summary(output_prefix) + packages = set(packages) + cache = {} + for record in summary.get("versions", []): + if not isinstance(record, dict): + continue + if not packages <= record.keys(): + logging.warning( + "Ignoring restart summary record that is missing package keys: %s", + sorted(packages - record.keys()), + ) + continue + if "result" not in record or "output_directory" not in record: + logging.warning("Ignoring incomplete restart summary record: %s", record) + continue + versions = {package: record[package] for package in packages} + repetition = int(record.get("test_repetition", 0)) + key = version_cache_key(versions, repetition=repetition) + cache[key] = TestResult( + build_stdouterr=None, + host_output_directory=_record_output_directory(output_prefix, record), + result=_parse_result(record["result"]), + stdouterr=None, + ) + return cache + + +def container_result_cache_from_summary( + output_prefix: pathlib.Path, + summary: typing.Optional[typing.Dict[str, typing.Any]] = None, +) -> typing.Dict[str, TestResult]: + """ + Reconstruct completed container-level test results from summary.json. + """ + if summary is None: + summary = load_summary(output_prefix) + cache = {} + for record in summary.get("container", []): + if not isinstance(record, dict): + continue + if not {"container", "result", "output_directory"} <= record.keys(): + logging.warning("Ignoring incomplete restart container record: %s", record) + continue + cache[record["container"]] = TestResult( + build_stdouterr=None, + host_output_directory=_record_output_directory(output_prefix, record), + result=_parse_result(record["result"]), + stdouterr=None, + ) + return cache + + def create_output_symlinks( output_prefix: pathlib.Path, last_known_good: typing.Optional[TestResult], @@ -67,13 +171,19 @@ def create_output_symlinks( def symlink(result: typing.Optional[TestResult], symlink_name: str) -> None: if result is None: return - symlink_path = (output_prefix / symlink_name).resolve() - assert not symlink_path.exists(), symlink_path - assert symlink_path.parent == result.host_output_directory.parent, ( - symlink_path, + symlink_path = output_prefix / symlink_name + if symlink_path.exists() or symlink_path.is_symlink(): + assert symlink_path.resolve() == result.host_output_directory.resolve(), ( + symlink_path, + result.host_output_directory, + ) + return + absolute_symlink_path = symlink_path.resolve() + assert absolute_symlink_path.parent == result.host_output_directory.parent, ( + absolute_symlink_path, result.host_output_directory, ) - symlink_path.symlink_to(result.host_output_directory) + absolute_symlink_path.symlink_to(result.host_output_directory) symlink(last_known_good, "last-known-good") symlink(first_known_bad, "first-known-bad") diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 870dff61f..38386feb9 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -18,7 +18,13 @@ CouldNotReproduceSuccess, ) from .versions import get_versions_dirs_env -from .summary import add_summary_record, create_output_symlinks +from .summary import ( + add_summary_record, + container_result_cache_from_summary, + create_output_symlinks, + load_summary, + version_result_cache_from_summary, +) from .bisect import get_commit_history from .utils import ( container_url as container_url_base, @@ -51,6 +57,16 @@ def __init__(self, args, logger): self.packages_with_scripts = set() self.bazel_cache_mounts = prepare_bazel_cache_mounts(self.args.bazel_cache) self.check_success_before_failure = True + self.restart_summary = ( + load_summary(self.args.restart_folder) if args.restart else {} + ) + self.container_result_cache = ( + container_result_cache_from_summary( + self.args.restart_folder, self.restart_summary + ) + if args.restart + else {} + ) self.logger.info("Arguments:") for k, v in vars(self.args).items(): @@ -80,9 +96,13 @@ def _test_output_directory( ) out_dir = self.args.output_prefix / out_dirname - assert not out_dir.exists(), ( - f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" - ) + if out_dir.exists() and self.args.restart: + base_out_dir = out_dir + n = 1 + while out_dir.exists(): + out_dir = base_out_dir.with_name(f"{base_out_dir.name}-restart-{n}") + n += 1 + assert not out_dir.exists(), f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" out_dir.mkdir(mode=0o755) return out_dir.resolve() @@ -126,9 +146,9 @@ def _get_versions( """ if explicit_versions is not None and container_url is None: return explicit_versions, None, None, None - assert container_url is not None, ( - "Container URL must be provided if explicit versions are not set." - ) + assert ( + container_url is not None + ), "Container URL must be provided if explicit versions are not set." with self._make_container(container_url) as worker: url_versions, dirs, env = get_versions_dirs_env( @@ -254,6 +274,11 @@ def _check_container_by_url( TestResult: The result of the test, including whether it passed and the output. """ before = time.monotonic() + cached_result = self.container_result_cache.get(container_url) + if cached_result is not None: + self.logger.info(f"Reusing cached container result for {container_url}") + return cached_result + out_dir = self._test_output_directory(container_url, None) with self._make_container( @@ -480,6 +505,8 @@ def _build_and_test( summary = { "build_time": build_time, "container": self.bisection_url, + "output_directory": out_dir.as_posix(), + "test_repetition": test_repetition, } summary.update(versions) if build_pass: @@ -686,6 +713,19 @@ def run_version_bisection( # Run the version-level bisection self.logger.info("Running version-level bisection...") + result_cache = {} + preloaded_cache_keys = set() + if self.args.restart: + result_cache = version_result_cache_from_summary( + self.args.restart_folder, + package_versions.keys(), + self.restart_summary, + ) + preloaded_cache_keys = set(result_cache.keys()) + self.logger.info( + f"Loaded {len(result_cache)} completed version-level result(s) " + f"from {self.args.restart_folder / 'summary.json'}" + ) try: result, last_known_good, first_known_bad = version_search( versions=package_versions, @@ -694,6 +734,8 @@ def run_version_bisection( skip_precondition_checks=self.args.skip_precondition_checks, check_success_before_failure=self.check_success_before_failure, confirmation_iterations=self.args.confirmation_iterations, + result_cache=result_cache, + preloaded_cache_keys=preloaded_cache_keys, ) except CouldNotReproduceFailure as e: if ( diff --git a/.github/triage/jax_toolbox_triage/utils.py b/.github/triage/jax_toolbox_triage/utils.py index d18aea784..623c5d596 100644 --- a/.github/triage/jax_toolbox_triage/utils.py +++ b/.github/triage/jax_toolbox_triage/utils.py @@ -25,16 +25,20 @@ def container_url( return f"ghcr.io/nvidia/{container}:nightly-{date_str}" -def get_logger(output_prefix: pathlib.Path) -> logging.Logger: - output_prefix.mkdir() +def get_logger(output_prefix: pathlib.Path, append: bool = False) -> logging.Logger: + output_prefix.mkdir(exist_ok=append) logger = logging.getLogger("triage") logger.setLevel(logging.DEBUG) + for handler in logger.handlers: + handler.close() + logger.handlers.clear() formatter = logging.Formatter( fmt="[%(levelname)s] %(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) + mode = "a" if append else "w" console = logging.StreamHandler() - trace_file = logging.FileHandler(filename=output_prefix / "info.log", mode="w") - debug_file = logging.FileHandler(filename=output_prefix / "debug.log", mode="w") + trace_file = logging.FileHandler(filename=output_prefix / "info.log", mode=mode) + debug_file = logging.FileHandler(filename=output_prefix / "debug.log", mode=mode) console.setLevel(logging.INFO) trace_file.setLevel(logging.INFO) debug_file.setLevel(logging.DEBUG) diff --git a/.github/triage/tests/test_restart.py b/.github/triage/tests/test_restart.py new file mode 100644 index 000000000..52900f7b7 --- /dev/null +++ b/.github/triage/tests/test_restart.py @@ -0,0 +1,220 @@ +import collections +import datetime +import json +import logging + +import pytest + +from jax_toolbox_triage.args import parse_args +from jax_toolbox_triage.logic import ( + TestExecutionOutcome, + TestResult, + version_search, +) +from jax_toolbox_triage.summary import version_result_cache_from_summary +from jax_toolbox_triage.triage_tool import TriageTool + + +start_date = datetime.datetime(2026, 1, 1) + + +def make_commits(): + return collections.OrderedDict( + [ + ("xla", [("xla-good", start_date)]), + ( + "jax", + [ + ("jax-good", start_date), + ("jax-bad", start_date + datetime.timedelta(days=1)), + ], + ), + ] + ) + + +def make_result(output_directory, outcome): + return TestResult( + build_stdouterr=None, + host_output_directory=output_directory, + result=outcome, + stdouterr=None, + ) + + +def test_restart_requires_summary_json(tmp_path): + with pytest.raises(Exception, match="--restart requires --restart-folder"): + parse_args( + [ + "--restart", + "--container-runtime=local", + "--passing-versions", + "jax:jax-good,xla:xla-good", + "--failing-versions", + "jax:jax-bad,xla:xla-good", + "test-command", + ] + ) + + with pytest.raises(Exception, match="summary.json"): + parse_args( + [ + "--restart", + "--restart-folder", + str(tmp_path), + "--container-runtime=local", + "--passing-versions", + "jax:jax-good,xla:xla-good", + "--failing-versions", + "jax:jax-bad,xla:xla-good", + "test-command", + ] + ) + + +def test_restart_uses_restart_folder_as_output_prefix(tmp_path): + (tmp_path / "summary.json").write_text("{}") + + args = parse_args( + [ + "--restart", + "--restart-folder", + str(tmp_path), + "--container-runtime=local", + "--passing-versions", + "jax:jax-good,xla:xla-good", + "--failing-versions", + "jax:jax-bad,xla:xla-good", + "test-command", + ] + ) + + assert args.output_prefix == tmp_path.resolve() + + +def test_version_search_reuses_preloaded_restart_cache(tmp_path): + good_dir = tmp_path / "good" + bad_dir = tmp_path / "bad" + good_dir.mkdir() + bad_dir.mkdir() + summary = { + "versions": [ + { + "container": "container-url", + "xla": "xla-good", + "jax": "jax-good", + "output_directory": str(good_dir), + "result": "TestExecutionOutcome.TEST_SUCCESS", + }, + { + "container": "container-url", + "xla": "xla-good", + "jax": "jax-bad", + "output_directory": str(bad_dir), + "result": "TestExecutionOutcome.TEST_FAILURE", + }, + ] + } + (tmp_path / "summary.json").write_text(json.dumps(summary)) + result_cache = version_result_cache_from_summary(tmp_path, make_commits().keys()) + + def build_and_test(**kwargs): + raise AssertionError(f"Unexpected build/test during restart: {kwargs}") + + result, last_known_good, first_known_bad = version_search( + versions=make_commits(), + build_and_test=build_and_test, + logger=logging.getLogger("triage-restart-test"), + skip_precondition_checks=False, + confirmation_iterations=0, + result_cache=result_cache, + ) + + assert result == { + "jax_bad": "jax-bad", + "jax_good": "jax-good", + "xla_ref": "xla-good", + } + assert last_known_good.result == TestExecutionOutcome.TEST_SUCCESS + assert first_known_bad.result == TestExecutionOutcome.TEST_FAILURE + + +def test_restart_suffixed_stale_output_directory(tmp_path): + (tmp_path / "summary.json").write_text("{}") + args = parse_args( + [ + "--restart", + "--restart-folder", + str(tmp_path), + "--container-runtime=local", + "--passing-versions", + "jax:jax-good,xla:xla-good", + "--failing-versions", + "jax:jax-bad,xla:xla-good", + "test-command", + ] + ) + tool = TriageTool(args, logging.getLogger("triage-restart-test")) + + stale = tool._test_output_directory("local", {"jax": "jax-good"}) + retry = tool._test_output_directory("local", {"jax": "jax-good"}) + + assert stale.name in retry.name + assert retry.name.endswith("-restart-1") + + +def test_triage_tool_loads_restart_cache(tmp_path): + good_dir = tmp_path / "good" + bad_dir = tmp_path / "bad" + good_dir.mkdir() + bad_dir.mkdir() + summary = { + "versions": [ + { + "container": "local", + "xla": "xla-good", + "jax": "jax-good", + "output_directory": str(good_dir), + "result": "TestExecutionOutcome.TEST_SUCCESS", + }, + { + "container": "local", + "xla": "xla-good", + "jax": "jax-bad", + "output_directory": str(bad_dir), + "result": "TestExecutionOutcome.TEST_FAILURE", + }, + ] + } + (tmp_path / "summary.json").write_text(json.dumps(summary)) + args = parse_args( + [ + "--restart", + "--restart-folder", + str(tmp_path), + "--container-runtime=local", + "--passing-versions", + "jax:jax-good,xla:xla-good", + "--failing-versions", + "jax:jax-bad,xla:xla-good", + "--confirmation-iterations=0", + "test-command", + ] + ) + tool = TriageTool(args, logging.getLogger("triage-restart-test")) + tool.bisection_url = "local" + tool._gather_histories = lambda worker, passing, failing: make_commits() + tool._check_installation_scripts = lambda worker: set() + + def build_and_test(**kwargs): + raise AssertionError(f"Unexpected build/test during restart: {kwargs}") + + tool._build_and_test = build_and_test + + result = tool.run_version_bisection( + {"jax": "jax-good", "xla": "xla-good"}, + {"jax": "jax-bad", "xla": "xla-good"}, + ) + + assert result["result"]["jax_good"] == "jax-good" + assert result["result"]["jax_bad"] == "jax-bad" From f0bab186e1a72d5e141d1d8f506f9c4da9aa199b Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 18 May 2026 16:29:24 +0100 Subject: [PATCH 2/4] fix args --- .github/triage/jax_toolbox_triage/args.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py index 38b836c49..028450689 100644 --- a/.github/triage/jax_toolbox_triage/args.py +++ b/.github/triage/jax_toolbox_triage/args.py @@ -3,7 +3,6 @@ import getpass import os import pathlib -import sys import tempfile import typing import warnings @@ -54,11 +53,6 @@ def parse_override_remotes(s: str) -> typing.Dict[str, str]: def parse_args(args=None) -> argparse.Namespace: - raw_args = sys.argv[1:] if args is None else args - output_prefix_supplied = any( - arg == "--output-prefix" or arg.startswith("--output-prefix=") - for arg in raw_args - ) parser = argparse.ArgumentParser( description=""" Triage failures in JAX/XLA-related tests. The expectation is that the given @@ -341,8 +335,7 @@ def parse_args(args=None) -> argparse.Namespace: f"--restart-folder must contain summary.json: {summary_file}" ) if ( - output_prefix_supplied - and args.output_prefix is not None + args.output_prefix is not None and args.output_prefix.resolve() != args.restart_folder ): raise Exception( From 82d71dd152271cbddc4623ed6b4548850a777fc1 Mon Sep 17 00:00:00 2001 From: Steboss Date: Thu, 28 May 2026 13:57:52 +0100 Subject: [PATCH 3/4] Update .github/triage/jax_toolbox_triage/logic.py Co-authored-by: Olli Lupton --- .github/triage/jax_toolbox_triage/logic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/triage/jax_toolbox_triage/logic.py b/.github/triage/jax_toolbox_triage/logic.py index e6fe92b50..2ee05d092 100644 --- a/.github/triage/jax_toolbox_triage/logic.py +++ b/.github/triage/jax_toolbox_triage/logic.py @@ -705,7 +705,7 @@ def version_search( skip_precondition_checks: bool, check_success_before_failure: bool = True, confirmation_iterations: int = 1, - result_cache: typing.Optional[typing.Dict[FlatVersionDict, TestResult]] = None, + result_cache: typing.Dict[FlatVersionDict, TestResult] = {}, preloaded_cache_keys: typing.Optional[typing.Set[FlatVersionDict]] = None, ) -> typing.Tuple[ typing.Dict[str, str], From 48b65e2eb882461fc877c1dd27a2f21930838ef3 Mon Sep 17 00:00:00 2001 From: Steboss Date: Thu, 28 May 2026 14:24:26 +0100 Subject: [PATCH 4/4] fix based on comments --- .github/triage/jax_toolbox_triage/args.py | 20 +--- .github/triage/jax_toolbox_triage/logic.py | 21 +--- .github/triage/jax_toolbox_triage/main.py | 5 +- .github/triage/jax_toolbox_triage/summary.py | 97 ++++++++----------- .../triage/jax_toolbox_triage/triage_tool.py | 39 ++++---- .github/triage/jax_toolbox_triage/utils.py | 3 - .github/triage/tests/test_restart.py | 63 ++++++------ 7 files changed, 103 insertions(+), 145 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py index 028450689..7cf3b7cfc 100644 --- a/.github/triage/jax_toolbox_triage/args.py +++ b/.github/triage/jax_toolbox_triage/args.py @@ -108,21 +108,13 @@ def parse_args(args=None) -> argparse.Namespace: PREFIX-summary.json""", type=pathlib.Path, ) - parser.add_argument( - "--restart", - action="store_true", - help=""" - Restart a previous triage run by loading completed records from - --restart-folder/summary.json. The summary file is the source of truth; - incomplete output directories without summary records are ignored. - """, - ) parser.add_argument( "--restart-folder", type=pathlib.Path, help=""" - Output folder from a previous triage run. Must contain summary.json and is - used as --output-prefix during --restart. + Restart a previous triage run from this output folder. The folder must + contain summary.json, which is used as the source of truth for completed + records. Incomplete output directories without summary records are ignored. """, ) parser.add_argument( @@ -325,9 +317,7 @@ def parse_args(args=None) -> argparse.Namespace: help="The name of the main branch (e.g. main) to derive cherry-picks from", ) args = parser.parse_args(args=args) - if args.restart: - if args.restart_folder is None: - raise Exception("--restart requires --restart-folder") + if args.restart_folder is not None: args.restart_folder = args.restart_folder.resolve() summary_file = args.restart_folder / "summary.json" if not summary_file.exists(): @@ -343,8 +333,6 @@ def parse_args(args=None) -> argparse.Namespace: ) args.output_prefix = args.restart_folder else: - if args.restart_folder is not None: - raise Exception("--restart-folder requires --restart") if args.output_prefix is None: args.output_prefix = pathlib.Path( datetime.datetime.now().strftime("triage-%Y-%m-%d-%H-%M-%S") diff --git a/.github/triage/jax_toolbox_triage/logic.py b/.github/triage/jax_toolbox_triage/logic.py index 2ee05d092..905c1167a 100644 --- a/.github/triage/jax_toolbox_triage/logic.py +++ b/.github/triage/jax_toolbox_triage/logic.py @@ -341,7 +341,6 @@ def _version_search( logger: logging.Logger, skip_precondition_checks: bool, result_cache: typing.Dict[FlatVersionDict, TestResult], - preloaded_cache_keys: typing.Set[FlatVersionDict], confirmation_iterations: int, check_success_before_failure: bool = True, ) -> typing.Tuple[typing.Dict[str, str], TestResult, typing.Optional[TestResult]]: @@ -357,15 +356,12 @@ def _version_search( def build_cached( bisect_versions, *, - assert_miss: bool = False, test_output_log_level: int = logging.DEBUG, repetition: int = 0, ): cache_key = version_cache_key(bisect_versions, repetition=repetition) bisect_result = result_cache.get(cache_key) if bisect_result is not None: - if assert_miss and cache_key not in preloaded_cache_keys: - raise Exception("Unexpected cache hit!") logger.info(f"Reusing cached result for {dict(cache_key)}") return bisect_result bisect_result = build_and_test( @@ -384,7 +380,6 @@ def _check_success(): for n in range(confirmation_iterations + 1): check_pass = build_cached( _earliest_versions(versions), - assert_miss=True, repetition=-n, ) if check_pass.result == TestExecutionOutcome.TEST_SUCCESS: @@ -406,7 +401,6 @@ def _check_failure(): for n in range(confirmation_iterations + 1): check_fail = build_cached( _latest_versions(versions), - assert_miss=True, repetition=-n, test_output_log_level=logging.INFO if n == 0 else logging.DEBUG, ) @@ -668,13 +662,9 @@ def _index(pkg, ver): # `blame_versions` are last-known-good, `first_known_bad` are what they say. # We just tested `blame_versions`, so start with `first_known_bad`. for n in range(confirmation_iterations): - confirm_bad = build_cached( - first_known_bad, assert_miss=True, repetition=n + 1 - ) + confirm_bad = build_cached(first_known_bad, repetition=n + 1) assert confirm_bad.result == TestExecutionOutcome.TEST_FAILURE - confirm_good = build_cached( - blame_versions, assert_miss=True, repetition=n + 1 - ) + confirm_good = build_cached(blame_versions, repetition=n + 1) assert confirm_good.result == TestExecutionOutcome.TEST_SUCCESS return ret, blame, first_known_bad_result else: @@ -690,7 +680,6 @@ def _index(pkg, ver): logger=logger, skip_precondition_checks=True, result_cache=result_cache, - preloaded_cache_keys=preloaded_cache_keys, confirmation_iterations=confirmation_iterations, ) @@ -705,8 +694,7 @@ def version_search( skip_precondition_checks: bool, check_success_before_failure: bool = True, confirmation_iterations: int = 1, - result_cache: typing.Dict[FlatVersionDict, TestResult] = {}, - preloaded_cache_keys: typing.Optional[typing.Set[FlatVersionDict]] = None, + result_cache: typing.Optional[typing.Dict[FlatVersionDict, TestResult]] = None, ) -> typing.Tuple[ typing.Dict[str, str], TestResult, @@ -736,8 +724,6 @@ def version_search( """ if result_cache is None: result_cache = {} - if preloaded_cache_keys is None: - preloaded_cache_keys = set(result_cache.keys()) return _version_search( versions=versions, build_and_test=build_and_test, @@ -746,5 +732,4 @@ def version_search( check_success_before_failure=check_success_before_failure, confirmation_iterations=confirmation_iterations, result_cache=result_cache, - preloaded_cache_keys=preloaded_cache_keys, ) diff --git a/.github/triage/jax_toolbox_triage/main.py b/.github/triage/jax_toolbox_triage/main.py index 0a82b8a97..cbecb94a6 100644 --- a/.github/triage/jax_toolbox_triage/main.py +++ b/.github/triage/jax_toolbox_triage/main.py @@ -8,8 +8,9 @@ def main() -> None: Main entry point for the triage tool. """ args = parse_args() - logger = get_logger(args.output_prefix, append=args.restart) - if args.restart: + restart = args.restart_folder is not None + logger = get_logger(args.output_prefix, append=restart) + if restart: logger.info(f"Restarting from {args.restart_folder}") tool = TriageTool(args, logger) passing_url, failing_url = tool.find_container_range() diff --git a/.github/triage/jax_toolbox_triage/summary.py b/.github/triage/jax_toolbox_triage/summary.py index 5a02ea0d7..c28d7be3b 100644 --- a/.github/triage/jax_toolbox_triage/summary.py +++ b/.github/triage/jax_toolbox_triage/summary.py @@ -3,12 +3,15 @@ import pathlib import typing from .logic import ( - FlatVersionDict, TestExecutionOutcome, TestResult, version_cache_key, ) +SummaryCacheKey = typing.Tuple[str, typing.Any] +CONTAINER_CACHE_SECTION = "container" +VERSION_CACHE_SECTION = "versions" + def add_summary_record( output_prefix: pathlib.Path, @@ -59,21 +62,6 @@ def load_summary(output_prefix: pathlib.Path) -> typing.Dict[str, typing.Any]: return json.load(ifile) -def _parse_result(value) -> TestExecutionOutcome: - if isinstance(value, TestExecutionOutcome): - return value - if isinstance(value, bool): - return ( - TestExecutionOutcome.TEST_SUCCESS - if value - else TestExecutionOutcome.TEST_FAILURE - ) - if isinstance(value, str): - name = value.rsplit(".", 1)[-1] - return TestExecutionOutcome[name] - raise ValueError(f"Cannot parse test result from {value!r}") - - def _record_output_directory( output_prefix: pathlib.Path, record: typing.Dict[str, typing.Any] ) -> pathlib.Path: @@ -86,21 +74,42 @@ def _record_output_directory( return out_dir -def version_result_cache_from_summary( +def result_cache_from_summary( output_prefix: pathlib.Path, - packages: typing.Iterable[str], + packages: typing.Iterable[str] = (), summary: typing.Optional[typing.Dict[str, typing.Any]] = None, -) -> typing.Dict[FlatVersionDict, TestResult]: +) -> typing.Dict[SummaryCacheKey, TestResult]: """ - Reconstruct completed version-level build/test results from summary.json. + Reconstruct completed build/test results from summary.json. The summary file is treated as the transaction log. Output directories that exist without a corresponding summary record are ignored by construction. """ if summary is None: summary = load_summary(output_prefix) - packages = set(packages) cache = {} + for record in summary.get("container", []): + if not isinstance(record, dict): + continue + if not {"container", "result", "output_directory"} <= record.keys(): + logging.warning("Ignoring incomplete restart container record: %s", record) + continue + result = ( + TestExecutionOutcome.TEST_SUCCESS + if record["result"] + else TestExecutionOutcome.TEST_FAILURE + ) + cache[(CONTAINER_CACHE_SECTION, record["container"])] = TestResult( + build_stdouterr=None, + host_output_directory=_record_output_directory(output_prefix, record), + result=result, + stdouterr=None, + ) + + packages = set(packages) + if not packages: + return cache + for record in summary.get("versions", []): if not isinstance(record, dict): continue @@ -116,35 +125,11 @@ def version_result_cache_from_summary( versions = {package: record[package] for package in packages} repetition = int(record.get("test_repetition", 0)) key = version_cache_key(versions, repetition=repetition) - cache[key] = TestResult( + result_name = record["result"].rsplit(".", 1)[-1] + cache[(VERSION_CACHE_SECTION, key)] = TestResult( build_stdouterr=None, host_output_directory=_record_output_directory(output_prefix, record), - result=_parse_result(record["result"]), - stdouterr=None, - ) - return cache - - -def container_result_cache_from_summary( - output_prefix: pathlib.Path, - summary: typing.Optional[typing.Dict[str, typing.Any]] = None, -) -> typing.Dict[str, TestResult]: - """ - Reconstruct completed container-level test results from summary.json. - """ - if summary is None: - summary = load_summary(output_prefix) - cache = {} - for record in summary.get("container", []): - if not isinstance(record, dict): - continue - if not {"container", "result", "output_directory"} <= record.keys(): - logging.warning("Ignoring incomplete restart container record: %s", record) - continue - cache[record["container"]] = TestResult( - build_stdouterr=None, - host_output_directory=_record_output_directory(output_prefix, record), - result=_parse_result(record["result"]), + result=TestExecutionOutcome[result_name], stdouterr=None, ) return cache @@ -171,19 +156,13 @@ def create_output_symlinks( def symlink(result: typing.Optional[TestResult], symlink_name: str) -> None: if result is None: return - symlink_path = output_prefix / symlink_name - if symlink_path.exists() or symlink_path.is_symlink(): - assert symlink_path.resolve() == result.host_output_directory.resolve(), ( - symlink_path, - result.host_output_directory, - ) - return - absolute_symlink_path = symlink_path.resolve() - assert absolute_symlink_path.parent == result.host_output_directory.parent, ( - absolute_symlink_path, + symlink_path = (output_prefix / symlink_name).resolve() + assert not symlink_path.exists(), symlink_path + assert symlink_path.parent == result.host_output_directory.parent, ( + symlink_path, result.host_output_directory, ) - absolute_symlink_path.symlink_to(result.host_output_directory) + symlink_path.symlink_to(result.host_output_directory) symlink(last_known_good, "last-known-good") symlink(first_known_bad, "first-known-bad") diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 38386feb9..5396b03e1 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -20,10 +20,11 @@ from .versions import get_versions_dirs_env from .summary import ( add_summary_record, - container_result_cache_from_summary, + CONTAINER_CACHE_SECTION, create_output_symlinks, load_summary, - version_result_cache_from_summary, + result_cache_from_summary, + VERSION_CACHE_SECTION, ) from .bisect import get_commit_history from .utils import ( @@ -58,13 +59,13 @@ def __init__(self, args, logger): self.bazel_cache_mounts = prepare_bazel_cache_mounts(self.args.bazel_cache) self.check_success_before_failure = True self.restart_summary = ( - load_summary(self.args.restart_folder) if args.restart else {} + load_summary(self.args.restart_folder) if args.restart_folder else {} ) - self.container_result_cache = ( - container_result_cache_from_summary( - self.args.restart_folder, self.restart_summary + self.restart_cache = ( + result_cache_from_summary( + self.args.restart_folder, summary=self.restart_summary ) - if args.restart + if args.restart_folder else {} ) @@ -96,7 +97,7 @@ def _test_output_directory( ) out_dir = self.args.output_prefix / out_dirname - if out_dir.exists() and self.args.restart: + if out_dir.exists() and self.args.restart_folder is not None: base_out_dir = out_dir n = 1 while out_dir.exists(): @@ -274,7 +275,7 @@ def _check_container_by_url( TestResult: The result of the test, including whether it passed and the output. """ before = time.monotonic() - cached_result = self.container_result_cache.get(container_url) + cached_result = self.restart_cache.get((CONTAINER_CACHE_SECTION, container_url)) if cached_result is not None: self.logger.info(f"Reusing cached container result for {container_url}") return cached_result @@ -714,14 +715,19 @@ def run_version_bisection( # Run the version-level bisection self.logger.info("Running version-level bisection...") result_cache = {} - preloaded_cache_keys = set() - if self.args.restart: - result_cache = version_result_cache_from_summary( - self.args.restart_folder, - package_versions.keys(), - self.restart_summary, + if self.args.restart_folder is not None: + self.restart_cache.update( + result_cache_from_summary( + self.args.restart_folder, + package_versions.keys(), + self.restart_summary, + ) ) - preloaded_cache_keys = set(result_cache.keys()) + result_cache = { + key: result + for (section, key), result in self.restart_cache.items() + if section == VERSION_CACHE_SECTION + } self.logger.info( f"Loaded {len(result_cache)} completed version-level result(s) " f"from {self.args.restart_folder / 'summary.json'}" @@ -735,7 +741,6 @@ def run_version_bisection( check_success_before_failure=self.check_success_before_failure, confirmation_iterations=self.args.confirmation_iterations, result_cache=result_cache, - preloaded_cache_keys=preloaded_cache_keys, ) except CouldNotReproduceFailure as e: if ( diff --git a/.github/triage/jax_toolbox_triage/utils.py b/.github/triage/jax_toolbox_triage/utils.py index 623c5d596..b8568563b 100644 --- a/.github/triage/jax_toolbox_triage/utils.py +++ b/.github/triage/jax_toolbox_triage/utils.py @@ -29,9 +29,6 @@ def get_logger(output_prefix: pathlib.Path, append: bool = False) -> logging.Log output_prefix.mkdir(exist_ok=append) logger = logging.getLogger("triage") logger.setLevel(logging.DEBUG) - for handler in logger.handlers: - handler.close() - logger.handlers.clear() formatter = logging.Formatter( fmt="[%(levelname)s] %(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) diff --git a/.github/triage/tests/test_restart.py b/.github/triage/tests/test_restart.py index 52900f7b7..ff62ead0f 100644 --- a/.github/triage/tests/test_restart.py +++ b/.github/triage/tests/test_restart.py @@ -8,10 +8,13 @@ from jax_toolbox_triage.args import parse_args from jax_toolbox_triage.logic import ( TestExecutionOutcome, - TestResult, version_search, ) -from jax_toolbox_triage.summary import version_result_cache_from_summary +from jax_toolbox_triage.summary import ( + CONTAINER_CACHE_SECTION, + result_cache_from_summary, + VERSION_CACHE_SECTION, +) from jax_toolbox_triage.triage_tool import TriageTool @@ -33,33 +36,10 @@ def make_commits(): ) -def make_result(output_directory, outcome): - return TestResult( - build_stdouterr=None, - host_output_directory=output_directory, - result=outcome, - stdouterr=None, - ) - - -def test_restart_requires_summary_json(tmp_path): - with pytest.raises(Exception, match="--restart requires --restart-folder"): - parse_args( - [ - "--restart", - "--container-runtime=local", - "--passing-versions", - "jax:jax-good,xla:xla-good", - "--failing-versions", - "jax:jax-bad,xla:xla-good", - "test-command", - ] - ) - +def test_restart_folder_requires_summary_json(tmp_path): with pytest.raises(Exception, match="summary.json"): parse_args( [ - "--restart", "--restart-folder", str(tmp_path), "--container-runtime=local", @@ -77,7 +57,6 @@ def test_restart_uses_restart_folder_as_output_prefix(tmp_path): args = parse_args( [ - "--restart", "--restart-folder", str(tmp_path), "--container-runtime=local", @@ -116,7 +95,12 @@ def test_version_search_reuses_preloaded_restart_cache(tmp_path): ] } (tmp_path / "summary.json").write_text(json.dumps(summary)) - result_cache = version_result_cache_from_summary(tmp_path, make_commits().keys()) + summary_cache = result_cache_from_summary(tmp_path, make_commits().keys()) + result_cache = { + key: result + for (section, key), result in summary_cache.items() + if section == VERSION_CACHE_SECTION + } def build_and_test(**kwargs): raise AssertionError(f"Unexpected build/test during restart: {kwargs}") @@ -139,11 +123,31 @@ def build_and_test(**kwargs): assert first_known_bad.result == TestExecutionOutcome.TEST_FAILURE +def test_summary_cache_loads_container_records(tmp_path): + out_dir = tmp_path / "container-output" + out_dir.mkdir() + summary = { + "container": [ + { + "container": "container-url", + "output_directory": str(out_dir), + "result": True, + } + ] + } + (tmp_path / "summary.json").write_text(json.dumps(summary)) + + summary_cache = result_cache_from_summary(tmp_path) + + cached_result = summary_cache[(CONTAINER_CACHE_SECTION, "container-url")] + assert cached_result.result == TestExecutionOutcome.TEST_SUCCESS + assert cached_result.host_output_directory == out_dir + + def test_restart_suffixed_stale_output_directory(tmp_path): (tmp_path / "summary.json").write_text("{}") args = parse_args( [ - "--restart", "--restart-folder", str(tmp_path), "--container-runtime=local", @@ -189,7 +193,6 @@ def test_triage_tool_loads_restart_cache(tmp_path): (tmp_path / "summary.json").write_text(json.dumps(summary)) args = parse_args( [ - "--restart", "--restart-folder", str(tmp_path), "--container-runtime=local",