diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py index 729e88167..7cf3b7cfc 100644 --- a/.github/triage/jax_toolbox_triage/args.py +++ b/.github/triage/jax_toolbox_triage/args.py @@ -100,7 +100,7 @@ def parse_args(args=None) -> argparse.Namespace: ) parser.add_argument( "--output-prefix", - default=datetime.datetime.now().strftime("triage-%Y-%m-%d-%H-%M-%S"), + default=None, help=""" Prefix for output log and JSON files. Default: triage-YYYY-MM-DD-HH-MM-SS. An INFO-and-above log is written as PREFIX.log, a DEBUG-and-above log is @@ -108,6 +108,15 @@ def parse_args(args=None) -> argparse.Namespace: PREFIX-summary.json""", type=pathlib.Path, ) + parser.add_argument( + "--restart-folder", + type=pathlib.Path, + help=""" + Restart a previous triage run from this output folder. The folder must + contain summary.json, which is used as the source of truth for completed + records. Incomplete output directories without summary records are ignored. + """, + ) parser.add_argument( "--skip-precondition-checks", action="store_true", @@ -308,6 +317,27 @@ def parse_args(args=None) -> argparse.Namespace: help="The name of the main branch (e.g. main) to derive cherry-picks from", ) args = parser.parse_args(args=args) + if args.restart_folder is not None: + args.restart_folder = args.restart_folder.resolve() + summary_file = args.restart_folder / "summary.json" + if not summary_file.exists(): + raise Exception( + f"--restart-folder must contain summary.json: {summary_file}" + ) + if ( + args.output_prefix is not None + and args.output_prefix.resolve() != args.restart_folder + ): + raise Exception( + "--output-prefix must match --restart-folder when restarting" + ) + args.output_prefix = args.restart_folder + else: + if args.output_prefix is None: + args.output_prefix = pathlib.Path( + datetime.datetime.now().strftime("triage-%Y-%m-%d-%H-%M-%S") + ) + assert args.container_runtime in { "docker", "pyxis", @@ -348,9 +378,7 @@ def parse_args(args=None) -> argparse.Namespace: if args.container_runtime == "local": assert ( args.passing_versions is not None and args.failing_versions is not None - ), ( - "For local runtime, --passing-versions and --failing-versions must be provided." - ) + ), "For local runtime, --passing-versions and --failing-versions must be provided." assert ( args.container is None and args.start_date is None @@ -395,9 +423,9 @@ def parse_args(args=None) -> argparse.Namespace: else: # None of --{passing,failing}-{versions,container} were passed, make sure the # compulsory arguments for the container-level search were passed - assert args.container is not None, ( - "--container must be passed for the container-level search" - ) + assert ( + args.container is not None + ), "--container must be passed for the container-level search" args.optional_software = optional_software.copy() if args.exclude_transformer_engine: diff --git a/.github/triage/jax_toolbox_triage/logic.py b/.github/triage/jax_toolbox_triage/logic.py index c63c43b6e..905c1167a 100644 --- a/.github/triage/jax_toolbox_triage/logic.py +++ b/.github/triage/jax_toolbox_triage/logic.py @@ -321,6 +321,17 @@ def remove_build_failures(ver): _REPETITION_KEY = "#rep" +def version_cache_key( + versions: typing.Dict[str, str], + *, + repetition: int = 0, +) -> FlatVersionDict: + return tuple( + sorted(_strip_build_failures(versions).items()) + + [(_REPETITION_KEY, str(repetition))] + ) + + def _version_search( *, versions: typing.OrderedDict[ @@ -342,28 +353,16 @@ def _version_search( versions, ) - def _cache_key( - versions: typing.Dict[str, str], - *, - repetition: int = 0, - ) -> FlatVersionDict: - return tuple( - sorted(_strip_build_failures(versions).items()) - + [(_REPETITION_KEY, str(repetition))] - ) - def build_cached( bisect_versions, *, - assert_miss: bool = False, test_output_log_level: int = logging.DEBUG, repetition: int = 0, ): - cache_key = _cache_key(bisect_versions, repetition=repetition) + cache_key = version_cache_key(bisect_versions, repetition=repetition) bisect_result = result_cache.get(cache_key) if bisect_result is not None: - if assert_miss: - raise Exception("Unexpected cache hit!") + logger.info(f"Reusing cached result for {dict(cache_key)}") return bisect_result bisect_result = build_and_test( versions=_strip_build_failures(bisect_versions), @@ -381,7 +380,6 @@ def _check_success(): for n in range(confirmation_iterations + 1): check_pass = build_cached( _earliest_versions(versions), - assert_miss=True, repetition=-n, ) if check_pass.result == TestExecutionOutcome.TEST_SUCCESS: @@ -403,7 +401,6 @@ def _check_failure(): for n in range(confirmation_iterations + 1): check_fail = build_cached( _latest_versions(versions), - assert_miss=True, repetition=-n, test_output_log_level=logging.INFO if n == 0 else logging.DEBUG, ) @@ -477,20 +474,20 @@ def find_successful_build(versions): # succeeds so we can do the same. The somewhat arbitrary logic here is to take # the shorter y??????n run and refine it in the hope one of the ? is a y. assert ( - result_cache.get(_cache_key(_earliest_versions(versions))).result + result_cache.get(version_cache_key(_earliest_versions(versions))).result == TestExecutionOutcome.TEST_SUCCESS ) build_statuses = [(True, {p: 0 for p in versions})] for n in range(1, len(versions[primary]) - 1): versions_n, indices_n = get_versions(primary_index=n, versions=versions) - result_n = result_cache.get(_cache_key(versions_n)) + result_n = result_cache.get(version_cache_key(versions_n)) assert ( result_n is None or result_n.result == TestExecutionOutcome.BUILD_FAILURE ) build_statuses.append((None if result_n is None else False, indices_n)) assert ( - result_cache.get(_cache_key(_latest_versions(versions))).result + result_cache.get(version_cache_key(_latest_versions(versions))).result == TestExecutionOutcome.TEST_FAILURE ) build_statuses.append((True, {p: len(vs) - 1 for p, vs in versions.items()})) @@ -565,9 +562,9 @@ def _index(pkg, ver): for package, index in indices.items(): versions[package] = versions[package][: index + 1] else: - assert bisect_result.result == TestExecutionOutcome.BUILD_FAILURE, ( - bisect_result - ) + assert ( + bisect_result.result == TestExecutionOutcome.BUILD_FAILURE + ), bisect_result # Did not succeed in finding a version of `primary` that builds. This does # not quite mean that all versions fail, as the algorithm will not try all # versions in ranges with failures at both ends @@ -586,11 +583,11 @@ def _index(pkg, ver): for n in range(1, n_primary - 1): versions_n, _ = get_versions(primary_index=n, versions=versions) # Should have been a build failure if tested. - result_n = result_cache.get(_cache_key(versions_n)) + result_n = result_cache.get(version_cache_key(versions_n)) if result_n is not None: - assert result_n.result == TestExecutionOutcome.BUILD_FAILURE, ( - result_n - ) + assert ( + result_n.result == TestExecutionOutcome.BUILD_FAILURE + ), result_n build_fail_commits.append(versions_n[primary]) logger.warning( f"Could not triage {primary} to a single version due to build " @@ -648,7 +645,7 @@ def _index(pkg, ver): # `blame` represents the last-known-good test result, first-known-bad was seen # earlier, or possibly not at all e.g. if `skip_precondition_checks` is True # and first-known-bad was the end of the search range. - first_known_bad_result = result_cache.get(_cache_key(first_known_bad)) + first_known_bad_result = result_cache.get(version_cache_key(first_known_bad)) if first_known_bad_result is None: if skip_precondition_checks: logger.info( @@ -665,13 +662,9 @@ def _index(pkg, ver): # `blame_versions` are last-known-good, `first_known_bad` are what they say. # We just tested `blame_versions`, so start with `first_known_bad`. for n in range(confirmation_iterations): - confirm_bad = build_cached( - first_known_bad, assert_miss=True, repetition=n + 1 - ) + confirm_bad = build_cached(first_known_bad, repetition=n + 1) assert confirm_bad.result == TestExecutionOutcome.TEST_FAILURE - confirm_good = build_cached( - blame_versions, assert_miss=True, repetition=n + 1 - ) + confirm_good = build_cached(blame_versions, repetition=n + 1) assert confirm_good.result == TestExecutionOutcome.TEST_SUCCESS return ret, blame, first_known_bad_result else: @@ -701,6 +694,7 @@ def version_search( skip_precondition_checks: bool, check_success_before_failure: bool = True, confirmation_iterations: int = 1, + result_cache: typing.Optional[typing.Dict[FlatVersionDict, TestResult]] = None, ) -> typing.Tuple[ typing.Dict[str, str], TestResult, @@ -728,6 +722,8 @@ def version_search( False, but the other fields can be used to obtain stdout+stderr and output files from those test invocations. """ + if result_cache is None: + result_cache = {} return _version_search( versions=versions, build_and_test=build_and_test, @@ -735,5 +731,5 @@ def version_search( skip_precondition_checks=skip_precondition_checks, check_success_before_failure=check_success_before_failure, confirmation_iterations=confirmation_iterations, - result_cache={}, + result_cache=result_cache, ) diff --git a/.github/triage/jax_toolbox_triage/main.py b/.github/triage/jax_toolbox_triage/main.py index 0217b70f8..cbecb94a6 100644 --- a/.github/triage/jax_toolbox_triage/main.py +++ b/.github/triage/jax_toolbox_triage/main.py @@ -8,7 +8,10 @@ def main() -> None: Main entry point for the triage tool. """ args = parse_args() - logger = get_logger(args.output_prefix) + restart = args.restart_folder is not None + logger = get_logger(args.output_prefix, append=restart) + if restart: + logger.info(f"Restarting from {args.restart_folder}") tool = TriageTool(args, logger) passing_url, failing_url = tool.find_container_range() passing_versions, failing_versions = tool.gather_version_info( diff --git a/.github/triage/jax_toolbox_triage/summary.py b/.github/triage/jax_toolbox_triage/summary.py index 8545d1447..c28d7be3b 100644 --- a/.github/triage/jax_toolbox_triage/summary.py +++ b/.github/triage/jax_toolbox_triage/summary.py @@ -2,7 +2,15 @@ import logging import pathlib import typing -from .logic import TestResult +from .logic import ( + TestExecutionOutcome, + TestResult, + version_cache_key, +) + +SummaryCacheKey = typing.Tuple[str, typing.Any] +CONTAINER_CACHE_SECTION = "container" +VERSION_CACHE_SECTION = "versions" def add_summary_record( @@ -46,6 +54,87 @@ def add_summary_record( return data +def load_summary(output_prefix: pathlib.Path) -> typing.Dict[str, typing.Any]: + """ + Load the JSON summary for a previous or current triage run. + """ + with open(output_prefix / "summary.json", "r") as ifile: + return json.load(ifile) + + +def _record_output_directory( + output_prefix: pathlib.Path, record: typing.Dict[str, typing.Any] +) -> pathlib.Path: + out_dir = pathlib.Path(record["output_directory"]) + if out_dir.exists() or not out_dir.is_absolute(): + return out_dir + copied_dir = output_prefix / out_dir.name + if copied_dir.exists(): + return copied_dir.resolve() + return out_dir + + +def result_cache_from_summary( + output_prefix: pathlib.Path, + packages: typing.Iterable[str] = (), + summary: typing.Optional[typing.Dict[str, typing.Any]] = None, +) -> typing.Dict[SummaryCacheKey, TestResult]: + """ + Reconstruct completed build/test results from summary.json. + + The summary file is treated as the transaction log. Output directories that exist + without a corresponding summary record are ignored by construction. + """ + if summary is None: + summary = load_summary(output_prefix) + cache = {} + for record in summary.get("container", []): + if not isinstance(record, dict): + continue + if not {"container", "result", "output_directory"} <= record.keys(): + logging.warning("Ignoring incomplete restart container record: %s", record) + continue + result = ( + TestExecutionOutcome.TEST_SUCCESS + if record["result"] + else TestExecutionOutcome.TEST_FAILURE + ) + cache[(CONTAINER_CACHE_SECTION, record["container"])] = TestResult( + build_stdouterr=None, + host_output_directory=_record_output_directory(output_prefix, record), + result=result, + stdouterr=None, + ) + + packages = set(packages) + if not packages: + return cache + + for record in summary.get("versions", []): + if not isinstance(record, dict): + continue + if not packages <= record.keys(): + logging.warning( + "Ignoring restart summary record that is missing package keys: %s", + sorted(packages - record.keys()), + ) + continue + if "result" not in record or "output_directory" not in record: + logging.warning("Ignoring incomplete restart summary record: %s", record) + continue + versions = {package: record[package] for package in packages} + repetition = int(record.get("test_repetition", 0)) + key = version_cache_key(versions, repetition=repetition) + result_name = record["result"].rsplit(".", 1)[-1] + cache[(VERSION_CACHE_SECTION, key)] = TestResult( + build_stdouterr=None, + host_output_directory=_record_output_directory(output_prefix, record), + result=TestExecutionOutcome[result_name], + stdouterr=None, + ) + return cache + + def create_output_symlinks( output_prefix: pathlib.Path, last_known_good: typing.Optional[TestResult], diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 870dff61f..5396b03e1 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -18,7 +18,14 @@ CouldNotReproduceSuccess, ) from .versions import get_versions_dirs_env -from .summary import add_summary_record, create_output_symlinks +from .summary import ( + add_summary_record, + CONTAINER_CACHE_SECTION, + create_output_symlinks, + load_summary, + result_cache_from_summary, + VERSION_CACHE_SECTION, +) from .bisect import get_commit_history from .utils import ( container_url as container_url_base, @@ -51,6 +58,16 @@ def __init__(self, args, logger): self.packages_with_scripts = set() self.bazel_cache_mounts = prepare_bazel_cache_mounts(self.args.bazel_cache) self.check_success_before_failure = True + self.restart_summary = ( + load_summary(self.args.restart_folder) if args.restart_folder else {} + ) + self.restart_cache = ( + result_cache_from_summary( + self.args.restart_folder, summary=self.restart_summary + ) + if args.restart_folder + else {} + ) self.logger.info("Arguments:") for k, v in vars(self.args).items(): @@ -80,9 +97,13 @@ def _test_output_directory( ) out_dir = self.args.output_prefix / out_dirname - assert not out_dir.exists(), ( - f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" - ) + if out_dir.exists() and self.args.restart_folder is not None: + base_out_dir = out_dir + n = 1 + while out_dir.exists(): + out_dir = base_out_dir.with_name(f"{base_out_dir.name}-restart-{n}") + n += 1 + assert not out_dir.exists(), f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" out_dir.mkdir(mode=0o755) return out_dir.resolve() @@ -126,9 +147,9 @@ def _get_versions( """ if explicit_versions is not None and container_url is None: return explicit_versions, None, None, None - assert container_url is not None, ( - "Container URL must be provided if explicit versions are not set." - ) + assert ( + container_url is not None + ), "Container URL must be provided if explicit versions are not set." with self._make_container(container_url) as worker: url_versions, dirs, env = get_versions_dirs_env( @@ -254,6 +275,11 @@ def _check_container_by_url( TestResult: The result of the test, including whether it passed and the output. """ before = time.monotonic() + cached_result = self.restart_cache.get((CONTAINER_CACHE_SECTION, container_url)) + if cached_result is not None: + self.logger.info(f"Reusing cached container result for {container_url}") + return cached_result + out_dir = self._test_output_directory(container_url, None) with self._make_container( @@ -480,6 +506,8 @@ def _build_and_test( summary = { "build_time": build_time, "container": self.bisection_url, + "output_directory": out_dir.as_posix(), + "test_repetition": test_repetition, } summary.update(versions) if build_pass: @@ -686,6 +714,24 @@ def run_version_bisection( # Run the version-level bisection self.logger.info("Running version-level bisection...") + result_cache = {} + if self.args.restart_folder is not None: + self.restart_cache.update( + result_cache_from_summary( + self.args.restart_folder, + package_versions.keys(), + self.restart_summary, + ) + ) + result_cache = { + key: result + for (section, key), result in self.restart_cache.items() + if section == VERSION_CACHE_SECTION + } + self.logger.info( + f"Loaded {len(result_cache)} completed version-level result(s) " + f"from {self.args.restart_folder / 'summary.json'}" + ) try: result, last_known_good, first_known_bad = version_search( versions=package_versions, @@ -694,6 +740,7 @@ def run_version_bisection( skip_precondition_checks=self.args.skip_precondition_checks, check_success_before_failure=self.check_success_before_failure, confirmation_iterations=self.args.confirmation_iterations, + result_cache=result_cache, ) except CouldNotReproduceFailure as e: if ( diff --git a/.github/triage/jax_toolbox_triage/utils.py b/.github/triage/jax_toolbox_triage/utils.py index d18aea784..b8568563b 100644 --- a/.github/triage/jax_toolbox_triage/utils.py +++ b/.github/triage/jax_toolbox_triage/utils.py @@ -25,16 +25,17 @@ def container_url( return f"ghcr.io/nvidia/{container}:nightly-{date_str}" -def get_logger(output_prefix: pathlib.Path) -> logging.Logger: - output_prefix.mkdir() +def get_logger(output_prefix: pathlib.Path, append: bool = False) -> logging.Logger: + output_prefix.mkdir(exist_ok=append) logger = logging.getLogger("triage") logger.setLevel(logging.DEBUG) formatter = logging.Formatter( fmt="[%(levelname)s] %(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) + mode = "a" if append else "w" console = logging.StreamHandler() - trace_file = logging.FileHandler(filename=output_prefix / "info.log", mode="w") - debug_file = logging.FileHandler(filename=output_prefix / "debug.log", mode="w") + trace_file = logging.FileHandler(filename=output_prefix / "info.log", mode=mode) + debug_file = logging.FileHandler(filename=output_prefix / "debug.log", mode=mode) console.setLevel(logging.INFO) trace_file.setLevel(logging.INFO) debug_file.setLevel(logging.DEBUG) diff --git a/.github/triage/tests/test_restart.py b/.github/triage/tests/test_restart.py new file mode 100644 index 000000000..ff62ead0f --- /dev/null +++ b/.github/triage/tests/test_restart.py @@ -0,0 +1,223 @@ +import collections +import datetime +import json +import logging + +import pytest + +from jax_toolbox_triage.args import parse_args +from jax_toolbox_triage.logic import ( + TestExecutionOutcome, + version_search, +) +from jax_toolbox_triage.summary import ( + CONTAINER_CACHE_SECTION, + result_cache_from_summary, + VERSION_CACHE_SECTION, +) +from jax_toolbox_triage.triage_tool import TriageTool + + +start_date = datetime.datetime(2026, 1, 1) + + +def make_commits(): + return collections.OrderedDict( + [ + ("xla", [("xla-good", start_date)]), + ( + "jax", + [ + ("jax-good", start_date), + ("jax-bad", start_date + datetime.timedelta(days=1)), + ], + ), + ] + ) + + +def test_restart_folder_requires_summary_json(tmp_path): + with pytest.raises(Exception, match="summary.json"): + parse_args( + [ + "--restart-folder", + str(tmp_path), + "--container-runtime=local", + "--passing-versions", + "jax:jax-good,xla:xla-good", + "--failing-versions", + "jax:jax-bad,xla:xla-good", + "test-command", + ] + ) + + +def test_restart_uses_restart_folder_as_output_prefix(tmp_path): + (tmp_path / "summary.json").write_text("{}") + + args = parse_args( + [ + "--restart-folder", + str(tmp_path), + "--container-runtime=local", + "--passing-versions", + "jax:jax-good,xla:xla-good", + "--failing-versions", + "jax:jax-bad,xla:xla-good", + "test-command", + ] + ) + + assert args.output_prefix == tmp_path.resolve() + + +def test_version_search_reuses_preloaded_restart_cache(tmp_path): + good_dir = tmp_path / "good" + bad_dir = tmp_path / "bad" + good_dir.mkdir() + bad_dir.mkdir() + summary = { + "versions": [ + { + "container": "container-url", + "xla": "xla-good", + "jax": "jax-good", + "output_directory": str(good_dir), + "result": "TestExecutionOutcome.TEST_SUCCESS", + }, + { + "container": "container-url", + "xla": "xla-good", + "jax": "jax-bad", + "output_directory": str(bad_dir), + "result": "TestExecutionOutcome.TEST_FAILURE", + }, + ] + } + (tmp_path / "summary.json").write_text(json.dumps(summary)) + summary_cache = result_cache_from_summary(tmp_path, make_commits().keys()) + result_cache = { + key: result + for (section, key), result in summary_cache.items() + if section == VERSION_CACHE_SECTION + } + + def build_and_test(**kwargs): + raise AssertionError(f"Unexpected build/test during restart: {kwargs}") + + result, last_known_good, first_known_bad = version_search( + versions=make_commits(), + build_and_test=build_and_test, + logger=logging.getLogger("triage-restart-test"), + skip_precondition_checks=False, + confirmation_iterations=0, + result_cache=result_cache, + ) + + assert result == { + "jax_bad": "jax-bad", + "jax_good": "jax-good", + "xla_ref": "xla-good", + } + assert last_known_good.result == TestExecutionOutcome.TEST_SUCCESS + assert first_known_bad.result == TestExecutionOutcome.TEST_FAILURE + + +def test_summary_cache_loads_container_records(tmp_path): + out_dir = tmp_path / "container-output" + out_dir.mkdir() + summary = { + "container": [ + { + "container": "container-url", + "output_directory": str(out_dir), + "result": True, + } + ] + } + (tmp_path / "summary.json").write_text(json.dumps(summary)) + + summary_cache = result_cache_from_summary(tmp_path) + + cached_result = summary_cache[(CONTAINER_CACHE_SECTION, "container-url")] + assert cached_result.result == TestExecutionOutcome.TEST_SUCCESS + assert cached_result.host_output_directory == out_dir + + +def test_restart_suffixed_stale_output_directory(tmp_path): + (tmp_path / "summary.json").write_text("{}") + args = parse_args( + [ + "--restart-folder", + str(tmp_path), + "--container-runtime=local", + "--passing-versions", + "jax:jax-good,xla:xla-good", + "--failing-versions", + "jax:jax-bad,xla:xla-good", + "test-command", + ] + ) + tool = TriageTool(args, logging.getLogger("triage-restart-test")) + + stale = tool._test_output_directory("local", {"jax": "jax-good"}) + retry = tool._test_output_directory("local", {"jax": "jax-good"}) + + assert stale.name in retry.name + assert retry.name.endswith("-restart-1") + + +def test_triage_tool_loads_restart_cache(tmp_path): + good_dir = tmp_path / "good" + bad_dir = tmp_path / "bad" + good_dir.mkdir() + bad_dir.mkdir() + summary = { + "versions": [ + { + "container": "local", + "xla": "xla-good", + "jax": "jax-good", + "output_directory": str(good_dir), + "result": "TestExecutionOutcome.TEST_SUCCESS", + }, + { + "container": "local", + "xla": "xla-good", + "jax": "jax-bad", + "output_directory": str(bad_dir), + "result": "TestExecutionOutcome.TEST_FAILURE", + }, + ] + } + (tmp_path / "summary.json").write_text(json.dumps(summary)) + args = parse_args( + [ + "--restart-folder", + str(tmp_path), + "--container-runtime=local", + "--passing-versions", + "jax:jax-good,xla:xla-good", + "--failing-versions", + "jax:jax-bad,xla:xla-good", + "--confirmation-iterations=0", + "test-command", + ] + ) + tool = TriageTool(args, logging.getLogger("triage-restart-test")) + tool.bisection_url = "local" + tool._gather_histories = lambda worker, passing, failing: make_commits() + tool._check_installation_scripts = lambda worker: set() + + def build_and_test(**kwargs): + raise AssertionError(f"Unexpected build/test during restart: {kwargs}") + + tool._build_and_test = build_and_test + + result = tool.run_version_bisection( + {"jax": "jax-good", "xla": "xla-good"}, + {"jax": "jax-bad", "xla": "xla-good"}, + ) + + assert result["result"]["jax_good"] == "jax-good" + assert result["result"]["jax_bad"] == "jax-bad"