Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions .github/triage/jax_toolbox_triage/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,23 @@ def parse_args(args=None) -> argparse.Namespace:
)
parser.add_argument(
"--output-prefix",
default=datetime.datetime.now().strftime("triage-%Y-%m-%d-%H-%M-%S"),
default=None,
help="""
Prefix for output log and JSON files. Default: triage-YYYY-MM-DD-HH-MM-SS.
An INFO-and-above log is written as PREFIX.log, a DEBUG-and-above log is
written as PREFIX-debug.log, and a JSON summary is written as
PREFIX-summary.json""",
type=pathlib.Path,
)
parser.add_argument(
"--restart-folder",
type=pathlib.Path,
help="""
Restart a previous triage run from this output folder. The folder must
contain summary.json, which is used as the source of truth for completed
records. Incomplete output directories without summary records are ignored.
""",
)
parser.add_argument(
"--skip-precondition-checks",
action="store_true",
Expand Down Expand Up @@ -308,6 +317,27 @@ def parse_args(args=None) -> argparse.Namespace:
help="The name of the main branch (e.g. main) to derive cherry-picks from",
)
args = parser.parse_args(args=args)
if args.restart_folder is not None:
args.restart_folder = args.restart_folder.resolve()
summary_file = args.restart_folder / "summary.json"
if not summary_file.exists():
raise Exception(
f"--restart-folder must contain summary.json: {summary_file}"
)
if (
args.output_prefix is not None
and args.output_prefix.resolve() != args.restart_folder
):
raise Exception(
"--output-prefix must match --restart-folder when restarting"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With these semantics, isn't a boolean --restart sufficient?

)
args.output_prefix = args.restart_folder
else:
if args.output_prefix is None:
args.output_prefix = pathlib.Path(
datetime.datetime.now().strftime("triage-%Y-%m-%d-%H-%M-%S")
)

assert args.container_runtime in {
"docker",
"pyxis",
Expand Down Expand Up @@ -348,9 +378,7 @@ def parse_args(args=None) -> argparse.Namespace:
if args.container_runtime == "local":
assert (
args.passing_versions is not None and args.failing_versions is not None
), (
"For local runtime, --passing-versions and --failing-versions must be provided."
)
), "For local runtime, --passing-versions and --failing-versions must be provided."
assert (
args.container is None
and args.start_date is None
Expand Down Expand Up @@ -395,9 +423,9 @@ def parse_args(args=None) -> argparse.Namespace:
else:
# None of --{passing,failing}-{versions,container} were passed, make sure the
# compulsory arguments for the container-level search were passed
assert args.container is not None, (
"--container must be passed for the container-level search"
)
assert (
args.container is not None
), "--container must be passed for the container-level search"

args.optional_software = optional_software.copy()
if args.exclude_transformer_engine:
Expand Down
64 changes: 30 additions & 34 deletions .github/triage/jax_toolbox_triage/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,17 @@ def remove_build_failures(ver):
_REPETITION_KEY = "#rep"


def version_cache_key(
    versions: typing.Dict[str, str],
    *,
    repetition: int = 0,
) -> FlatVersionDict:
    """
    Build the hashable key under which a build/test result is cached.

    The key is a tuple of ``(name, value)`` pairs: the items of
    ``_strip_build_failures(versions)`` in sorted order, followed by a final
    ``(_REPETITION_KEY, str(repetition))`` pair, so that repeated runs of the
    same versions get distinct keys.

    Args:
        versions: mapping to derive the key from; it is passed through
            ``_strip_build_failures`` before being flattened.
        repetition: repetition counter, stored in the key as a string.

    Returns:
        The flattened, order-independent tuple key.
    """
    # Sorting makes the key independent of the mapping's insertion order.
    pairs = sorted(_strip_build_failures(versions).items())
    # The repetition marker always comes last; it is deliberately not sorted in.
    pairs.append((_REPETITION_KEY, str(repetition)))
    return tuple(pairs)


def _version_search(
*,
versions: typing.OrderedDict[
Expand All @@ -342,28 +353,16 @@ def _version_search(
versions,
)

def _cache_key(
versions: typing.Dict[str, str],
*,
repetition: int = 0,
) -> FlatVersionDict:
return tuple(
sorted(_strip_build_failures(versions).items())
+ [(_REPETITION_KEY, str(repetition))]
)

def build_cached(
bisect_versions,
*,
assert_miss: bool = False,
test_output_log_level: int = logging.DEBUG,
repetition: int = 0,
):
cache_key = _cache_key(bisect_versions, repetition=repetition)
cache_key = version_cache_key(bisect_versions, repetition=repetition)
bisect_result = result_cache.get(cache_key)
if bisect_result is not None:
if assert_miss:
raise Exception("Unexpected cache hit!")
logger.info(f"Reusing cached result for {dict(cache_key)}")
return bisect_result
bisect_result = build_and_test(
versions=_strip_build_failures(bisect_versions),
Expand All @@ -381,7 +380,6 @@ def _check_success():
for n in range(confirmation_iterations + 1):
check_pass = build_cached(
_earliest_versions(versions),
assert_miss=True,
repetition=-n,
)
if check_pass.result == TestExecutionOutcome.TEST_SUCCESS:
Expand All @@ -403,7 +401,6 @@ def _check_failure():
for n in range(confirmation_iterations + 1):
check_fail = build_cached(
_latest_versions(versions),
assert_miss=True,
repetition=-n,
test_output_log_level=logging.INFO if n == 0 else logging.DEBUG,
)
Expand Down Expand Up @@ -477,20 +474,20 @@ def find_successful_build(versions):
# succeeds so we can do the same. The somewhat arbitrary logic here is to take
# the shorter y??????n run and refine it in the hope one of the ? is a y.
assert (
result_cache.get(_cache_key(_earliest_versions(versions))).result
result_cache.get(version_cache_key(_earliest_versions(versions))).result
== TestExecutionOutcome.TEST_SUCCESS
)
build_statuses = [(True, {p: 0 for p in versions})]
for n in range(1, len(versions[primary]) - 1):
versions_n, indices_n = get_versions(primary_index=n, versions=versions)
result_n = result_cache.get(_cache_key(versions_n))
result_n = result_cache.get(version_cache_key(versions_n))
assert (
result_n is None
or result_n.result == TestExecutionOutcome.BUILD_FAILURE
)
build_statuses.append((None if result_n is None else False, indices_n))
assert (
result_cache.get(_cache_key(_latest_versions(versions))).result
result_cache.get(version_cache_key(_latest_versions(versions))).result
== TestExecutionOutcome.TEST_FAILURE
)
build_statuses.append((True, {p: len(vs) - 1 for p, vs in versions.items()}))
Expand Down Expand Up @@ -565,9 +562,9 @@ def _index(pkg, ver):
for package, index in indices.items():
versions[package] = versions[package][: index + 1]
else:
assert bisect_result.result == TestExecutionOutcome.BUILD_FAILURE, (
bisect_result
)
assert (
bisect_result.result == TestExecutionOutcome.BUILD_FAILURE
), bisect_result
# Did not succeed in finding a version of `primary` that builds. This does
# not quite mean that all versions fail, as the algorithm will not try all
# versions in ranges with failures at both ends
Expand All @@ -586,11 +583,11 @@ def _index(pkg, ver):
for n in range(1, n_primary - 1):
versions_n, _ = get_versions(primary_index=n, versions=versions)
# Should have been a build failure if tested.
result_n = result_cache.get(_cache_key(versions_n))
result_n = result_cache.get(version_cache_key(versions_n))
if result_n is not None:
assert result_n.result == TestExecutionOutcome.BUILD_FAILURE, (
result_n
)
assert (
result_n.result == TestExecutionOutcome.BUILD_FAILURE
), result_n
build_fail_commits.append(versions_n[primary])
logger.warning(
f"Could not triage {primary} to a single version due to build "
Expand Down Expand Up @@ -648,7 +645,7 @@ def _index(pkg, ver):
# `blame` represents the last-known-good test result, first-known-bad was seen
# earlier, or possibly not at all e.g. if `skip_precondition_checks` is True
# and first-known-bad was the end of the search range.
first_known_bad_result = result_cache.get(_cache_key(first_known_bad))
first_known_bad_result = result_cache.get(version_cache_key(first_known_bad))
if first_known_bad_result is None:
if skip_precondition_checks:
logger.info(
Expand All @@ -665,13 +662,9 @@ def _index(pkg, ver):
# `blame_versions` are last-known-good, `first_known_bad` are what they say.
# We just tested `blame_versions`, so start with `first_known_bad`.
for n in range(confirmation_iterations):
confirm_bad = build_cached(
first_known_bad, assert_miss=True, repetition=n + 1
)
confirm_bad = build_cached(first_known_bad, repetition=n + 1)
assert confirm_bad.result == TestExecutionOutcome.TEST_FAILURE
confirm_good = build_cached(
blame_versions, assert_miss=True, repetition=n + 1
)
confirm_good = build_cached(blame_versions, repetition=n + 1)
assert confirm_good.result == TestExecutionOutcome.TEST_SUCCESS
return ret, blame, first_known_bad_result
else:
Expand Down Expand Up @@ -701,6 +694,7 @@ def version_search(
skip_precondition_checks: bool,
check_success_before_failure: bool = True,
confirmation_iterations: int = 1,
result_cache: typing.Optional[typing.Dict[FlatVersionDict, TestResult]] = None,
) -> typing.Tuple[
typing.Dict[str, str],
TestResult,
Expand Down Expand Up @@ -728,12 +722,14 @@ def version_search(
False, but the other fields can be used to obtain stdout+stderr and
output files from those test invocations.
"""
if result_cache is None:
result_cache = {}
return _version_search(
versions=versions,
build_and_test=build_and_test,
logger=logger,
skip_precondition_checks=skip_precondition_checks,
check_success_before_failure=check_success_before_failure,
confirmation_iterations=confirmation_iterations,
result_cache={},
result_cache=result_cache,
)
5 changes: 4 additions & 1 deletion .github/triage/jax_toolbox_triage/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ def main() -> None:
Main entry point for the triage tool.
"""
args = parse_args()
logger = get_logger(args.output_prefix)
restart = args.restart_folder is not None
logger = get_logger(args.output_prefix, append=restart)
if restart:
logger.info(f"Restarting from {args.restart_folder}")
tool = TriageTool(args, logger)
passing_url, failing_url = tool.find_container_range()
passing_versions, failing_versions = tool.gather_version_info(
Expand Down
91 changes: 90 additions & 1 deletion .github/triage/jax_toolbox_triage/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
import logging
import pathlib
import typing
from .logic import TestResult
from .logic import (
TestExecutionOutcome,
TestResult,
version_cache_key,
)

SummaryCacheKey = typing.Tuple[str, typing.Any]
CONTAINER_CACHE_SECTION = "container"
VERSION_CACHE_SECTION = "versions"


def add_summary_record(
Expand Down Expand Up @@ -46,6 +54,87 @@ def add_summary_record(
return data


def load_summary(output_prefix: pathlib.Path) -> typing.Dict[str, typing.Any]:
    """
    Load the JSON summary for a previous or current triage run.

    Args:
        output_prefix: directory that contains ``summary.json``.

    Returns:
        The parsed contents of ``output_prefix / "summary.json"``.

    Raises:
        OSError: if the file cannot be opened (e.g. it does not exist).
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # JSON text is UTF-8 by specification; be explicit rather than depending on
    # the locale's preferred encoding, which differs between machines.
    with open(output_prefix / "summary.json", "r", encoding="utf-8") as ifile:
        return json.load(ifile)


def _record_output_directory(
    output_prefix: pathlib.Path, record: typing.Dict[str, typing.Any]
) -> pathlib.Path:
    """
    Work out which directory to use for a summary record's output.

    The path stored in ``record["output_directory"]`` is returned unchanged if
    it exists or is relative. If it is absolute and missing, a directory with
    the same final component directly under ``output_prefix`` is used instead,
    provided it exists; that path is returned resolved. If neither exists, the
    recorded path is returned as-is.

    Args:
        output_prefix: directory searched for a same-named fallback directory.
        record: summary record; must contain the ``"output_directory"`` key.

    Returns:
        The recorded path, or the resolved fallback under ``output_prefix``.
    """
    recorded = pathlib.Path(record["output_directory"])
    if recorded.exists() or not recorded.is_absolute():
        return recorded
    # Absolute path that is gone: look for a same-named directory under the
    # output prefix (presumably the results were copied or moved there).
    relocated = output_prefix / recorded.name
    return relocated.resolve() if relocated.exists() else recorded


def result_cache_from_summary(
    output_prefix: pathlib.Path,
    packages: typing.Iterable[str] = (),
    summary: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> typing.Dict[SummaryCacheKey, TestResult]:
    """
    Rebuild the cache of completed build/test results from summary.json.

    Only records present in the summary are restored; the summary acts as the
    transaction log, so output directories without a matching record never
    enter the cache.

    Args:
        output_prefix: directory of the run; used to load ``summary.json`` when
            ``summary`` is not given and to locate record output directories.
        packages: package names whose values identify a "versions" record. If
            empty, only container records are restored.
        summary: already-loaded summary document; loaded from disk if None.

    Returns:
        A dict keyed by ``(section, key)``. ``section`` is
        ``CONTAINER_CACHE_SECTION`` with the container name as key, or
        ``VERSION_CACHE_SECTION`` with a ``version_cache_key`` tuple as key.
        The restored ``TestResult`` objects carry no stdout/stderr text.
    """
    if summary is None:
        summary = load_summary(output_prefix)

    restored = {}

    # --- container-level records -------------------------------------------
    container_fields = {"container", "result", "output_directory"}
    for entry in summary.get("container", []):
        if not isinstance(entry, dict):
            continue
        if not container_fields <= entry.keys():
            logging.warning("Ignoring incomplete restart container record: %s", entry)
            continue
        # Container records store a truthy/falsy result rather than an enum name.
        if entry["result"]:
            outcome = TestExecutionOutcome.TEST_SUCCESS
        else:
            outcome = TestExecutionOutcome.TEST_FAILURE
        restored[(CONTAINER_CACHE_SECTION, entry["container"])] = TestResult(
            build_stdouterr=None,
            host_output_directory=_record_output_directory(output_prefix, entry),
            result=outcome,
            stdouterr=None,
        )

    # --- version-level records ---------------------------------------------
    wanted = set(packages)
    if not wanted:
        # Without package names there is no way to form version cache keys.
        return restored

    for entry in summary.get("versions", []):
        if not isinstance(entry, dict):
            continue
        absent = wanted - entry.keys()
        if absent:
            logging.warning(
                "Ignoring restart summary record that is missing package keys: %s",
                sorted(absent),
            )
            continue
        if not ("result" in entry and "output_directory" in entry):
            logging.warning("Ignoring incomplete restart summary record: %s", entry)
            continue
        pinned = {name: entry[name] for name in wanted}
        run_index = int(entry.get("test_repetition", 0))
        key = version_cache_key(pinned, repetition=run_index)
        # Keep only the text after the last "." so that a qualified name such as
        # "TestExecutionOutcome.TEST_SUCCESS" is looked up by its member name.
        outcome_name = entry["result"].rsplit(".", 1)[-1]
        restored[(VERSION_CACHE_SECTION, key)] = TestResult(
            build_stdouterr=None,
            host_output_directory=_record_output_directory(output_prefix, entry),
            result=TestExecutionOutcome[outcome_name],
            stdouterr=None,
        )
    return restored


def create_output_symlinks(
output_prefix: pathlib.Path,
last_known_good: typing.Optional[TestResult],
Expand Down
Loading