diff --git a/src/cloudai/models/scenario.py b/src/cloudai/models/scenario.py index 276cac2e5..c0c5abf63 100644 --- a/src/cloudai/models/scenario.py +++ b/src/cloudai/models/scenario.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -102,7 +102,7 @@ def tdef_model_dump(self, by_alias: bool) -> dict: "test_template_name": self.test_template_name, "agent": self.agent, "agent_steps": self.agent_steps, - "agent_metrics": self.agent_metrics, + "agent_metrics": self.agent_metrics if "agent_metrics" in self.model_fields_set else None, "agent_reward_function": self.agent_reward_function, "extra_container_mounts": self.extra_container_mounts, "extra_env_vars": self.extra_env_vars if self.extra_env_vars else None, diff --git a/src/cloudai/registration.py b/src/cloudai/registration.py index f9be227e6..3b5bfc9af 100644 --- a/src/cloudai/registration.py +++ b/src/cloudai/registration.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -98,6 +98,7 @@ def register_all(): ) from cloudai.workloads.megatron_run import ( CheckpointTimingReportGenerationStrategy, + MegatronRunReportGenerationStrategy, MegatronRunSlurmCommandGenStrategy, MegatronRunTestDefinition, ) @@ -259,6 +260,7 @@ def register_all(): Registry().add_report(GPTTestDefinition, JaxToolboxReportGenerationStrategy) Registry().add_report(GrokTestDefinition, JaxToolboxReportGenerationStrategy) Registry().add_report(MegatronRunTestDefinition, CheckpointTimingReportGenerationStrategy) + Registry().add_report(MegatronRunTestDefinition, MegatronRunReportGenerationStrategy) Registry().add_report(MegatronBridgeTestDefinition, MegatronBridgeReportGenerationStrategy) Registry().add_report(NCCLTestDefinition, NcclTestPerformanceReportGenerationStrategy) Registry().add_report(NeMoLauncherTestDefinition, NeMoLauncherReportGenerationStrategy) diff --git a/src/cloudai/workloads/megatron_run/__init__.py b/src/cloudai/workloads/megatron_run/__init__.py index 960461256..1f4f1fec9 100644 --- a/src/cloudai/workloads/megatron_run/__init__.py +++ b/src/cloudai/workloads/megatron_run/__init__.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,12 +15,13 @@ # limitations under the License. from .megatron_run import MegatronRunCmdArgs, MegatronRunTestDefinition -from .report_generation_strategy import CheckpointTimingReportGenerationStrategy +from .report_generation_strategy import CheckpointTimingReportGenerationStrategy, MegatronRunReportGenerationStrategy from .slurm_command_gen_strategy import MegatronRunSlurmCommandGenStrategy __all__ = [ "CheckpointTimingReportGenerationStrategy", "MegatronRunCmdArgs", + "MegatronRunReportGenerationStrategy", "MegatronRunSlurmCommandGenStrategy", "MegatronRunTestDefinition", ] diff --git a/src/cloudai/workloads/megatron_run/report_generation_strategy.py b/src/cloudai/workloads/megatron_run/report_generation_strategy.py index 50723a2ca..df9e64bf1 100644 --- a/src/cloudai/workloads/megatron_run/report_generation_strategy.py +++ b/src/cloudai/workloads/megatron_run/report_generation_strategy.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,13 +14,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import csv import logging import re +from pathlib import Path +from statistics import mean, median, pstdev +from typing import ClassVar -from cloudai.core import ReportGenerationStrategy +from cloudai.core import METRIC_ERROR, ReportGenerationStrategy CHECKPOINT_REGEX = re.compile(r"(save|load)-checkpoint\s.*:\s\((\d+\.\d+),\s(\d+\.\d+)\)") +# Pattern to match lines like: +# [2026-01-16 07:32:39] iteration 6/100 | ... | +# elapsed time per iteration (ms): 15639.0 | throughput per GPU (TFLOP/s/GPU): 494.6 | ... +ITERATION_REGEX = re.compile( + r"elapsed time per iteration \(ms\):\s*([0-9]+(?:\.[0-9]+)?)" + r".*?" + r"throughput per GPU \(TFLOP/s/GPU\):\s*([0-9]+(?:\.[0-9]+)?)", + re.IGNORECASE, +) + class CheckpointTimingReportGenerationStrategy(ReportGenerationStrategy): """Strategy for generating reports from Checkpoint Timing test outputs.""" @@ -59,3 +75,112 @@ def generate_report(self) -> None: for checkpoint_type, timings in [("save", save_timings), ("load", load_timings)]: for t in timings: file.write(f"{checkpoint_type},{t[0]},{t[1]}\n") + + +class MegatronRunReportGenerationStrategy(ReportGenerationStrategy): + """Parse Megatron-Run stdout.txt for iteration time and GPU TFLOP/s per GPU.""" + + metrics: ClassVar[list[str]] = ["default", "iteration-time", "tflops-per-gpu"] + + def get_log_file(self) -> Path | None: + log = self.test_run.output_path / "stdout.txt" + return log if log.is_file() else None + + @property + def results_file(self) -> Path: + return self.get_log_file() or (self.test_run.output_path / "stdout.txt") + + def can_handle_directory(self) -> bool: + log_file = self.get_log_file() + if not log_file: + return False + with log_file.open("r", encoding="utf-8", errors="ignore") as f: + for line in f: + if ITERATION_REGEX.search(line): + return True + return False + + def _extract(self, log_path: Path) -> tuple[list[float], list[float]]: + """Extract iteration times (ms) and GPU TFLOPS from the log file.""" + iter_times_ms: list[float] = [] + gpu_tflops: list[float] = [] + with log_path.open("r", encoding="utf-8", errors="ignore") as f: + for line in f: + m = ITERATION_REGEX.search(line) + if m: + try: + iter_times_ms.append(float(m.group(1))) + gpu_tflops.append(float(m.group(2))) + except (ValueError, TypeError): + logging.debug("Failed to parse iteration metrics line: %s", line.rstrip("\n")) + + # Skip the first 20 iterations for statistics (to exclude warmup) + if len(iter_times_ms) > 20: + iter_times_ms = iter_times_ms[20:] + gpu_tflops = gpu_tflops[20:] + return iter_times_ms, gpu_tflops + + def _get_extracted_data(self) -> tuple[Path | None, list[float], list[float]]: + log_file = self.get_log_file() + if not log_file: + return None, [], [] + iter_times_ms, gpu_tflops = self._extract(log_file) + return log_file, iter_times_ms, gpu_tflops + + def generate_report(self) -> None: + log_file, iter_times_ms, gpu_tflops = self._get_extracted_data() + if not log_file: + logging.error( + "No stdout.txt file found in: %s", + self.test_run.output_path, + ) + return + + report_file = self.test_run.output_path / "megatron_run_report.csv" + if not iter_times_ms: + with report_file.open("w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["metric_type", "avg", "median", "min", "max", "std"]) + writer.writerow(["error: No iteration timing lines were found.", "", "", "", "", ""]) + logging.warning("No iteration metrics found under %s (wrote %s)", self.test_run.output_path, report_file) + return + + iter_avg = mean(iter_times_ms) + iter_median = median(iter_times_ms) + iter_min = min(iter_times_ms) + iter_max = max(iter_times_ms) + iter_std = pstdev(iter_times_ms) if len(iter_times_ms) > 1 else 0.0 + + if gpu_tflops: + tflops_avg = mean(gpu_tflops) + tflops_median = median(gpu_tflops) + tflops_min = min(gpu_tflops) + tflops_max = max(gpu_tflops) + tflops_std = pstdev(gpu_tflops) if len(gpu_tflops) > 1 else 0.0 + else: + tflops_avg = tflops_median = tflops_min = tflops_max = tflops_std = 0.0 + + with report_file.open("w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["metric_type", "avg", "median", "min", "max", "std"]) + writer.writerow(["iteration_time_ms", iter_avg, iter_median, iter_min, iter_max, iter_std]) + writer.writerow(["tflops_per_gpu", tflops_avg, tflops_median, tflops_min, tflops_max, tflops_std]) + + def get_metric(self, metric: str) -> float: + if metric not in {"default", "iteration-time", "tflops-per-gpu"}: + return METRIC_ERROR + log_file, iter_times_ms, gpu_tflops = self._get_extracted_data() + if not log_file: + logging.error( + "No stdout.txt file found in: %s", + self.test_run.output_path, + ) + return METRIC_ERROR + if not iter_times_ms: + return METRIC_ERROR + + if metric in {"default", "iteration-time"}: + return float(mean(iter_times_ms)) + if metric == "tflops-per-gpu": + return float(mean(gpu_tflops)) if gpu_tflops else METRIC_ERROR + return METRIC_ERROR diff --git a/tests/report_generation_strategy/test_megatron_run_report_generation_strategy.py b/tests/report_generation_strategy/test_megatron_run_report_generation_strategy.py new file mode 100644 index 000000000..d472b6648 --- /dev/null +++ b/tests/report_generation_strategy/test_megatron_run_report_generation_strategy.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +from pathlib import Path + +import pytest + +from cloudai import TestRun +from cloudai.core import METRIC_ERROR +from cloudai.systems.slurm.slurm_system import SlurmSystem +from cloudai.workloads.megatron_run import ( + MegatronRunCmdArgs, + MegatronRunReportGenerationStrategy, + MegatronRunTestDefinition, +) + + +@pytest.fixture +def megatron_run_tr(tmp_path: Path) -> TestRun: + test = MegatronRunTestDefinition( + name="megatron_run", + description="desc", + test_template_name="t", + cmd_args=MegatronRunCmdArgs(docker_image_url="http://url", run_script=Path(__file__)), + ) + tr = TestRun(name="megatron_run_test", test=test, num_nodes=1, nodes=[], output_path=tmp_path) + + stdout_content = ( + "[2026-01-16 07:32:24] iteration 5/ 100 | consumed samples: 10240 | " + "elapsed time per iteration (ms): 15800.0 | throughput per GPU (TFLOP/s/GPU): 490.0 | " + "learning rate: 4.134000E-07 | global batch size: 2048 | lm loss: 1.344240E+01 | " + "seq_load_balancing_loss: 1.000203E+00 | loss scale: 1.0 | grad norm: 2.870 | " + "num zeros: 1174412544.0 | params norm: 8660.607 | " + "number of skipped iterations: 0 | number of nan iterations: 0 |\n" + "[2026-01-16 07:32:39] iteration 6/ 100 | consumed samples: 12288 | " + "elapsed time per iteration (ms): 15639.0 | throughput per GPU (TFLOP/s/GPU): 494.6 | " + "learning rate: 4.180800E-07 | global batch size: 2048 | lm loss: 1.342407E+01 | " + "seq_load_balancing_loss: 1.000202E+00 | loss scale: 1.0 | grad norm: 2.867 | " + "num zeros: 1174412672.0 | params norm: 8660.606 | " + "number of skipped iterations: 0 | number of nan iterations: 0 |\n" + "[2026-01-16 07:32:54] iteration 7/ 100 | consumed samples: 14336 | " + "elapsed time per iteration (ms): 15448.5 | throughput per GPU (TFLOP/s/GPU): 500.6 | " + "learning rate: 4.227600E-07 | global batch size: 2048 | lm loss: 1.340574E+01 | " + "seq_load_balancing_loss: 1.000201E+00 | loss scale: 1.0 | grad norm: 2.864 | " + "num zeros: 1174412800.0 | params norm: 8660.605 | " + "number of skipped iterations: 0 | number of nan iterations: 0 |\n" + ) + (tr.output_path / "stdout.txt").write_text(stdout_content) + + return tr + + +@pytest.fixture +def megatron_run_tr_no_data(tmp_path: Path) -> TestRun: + test = MegatronRunTestDefinition( + name="megatron_run", + description="desc", + test_template_name="t", + cmd_args=MegatronRunCmdArgs(docker_image_url="http://url", run_script=Path(__file__)), + ) + tr = TestRun(name="megatron_run_test", test=test, num_nodes=1, nodes=[], output_path=tmp_path) + + stdout_content = """ +Some random log output without iteration metrics +Starting training... +""" + (tr.output_path / "stdout.txt").write_text(stdout_content) + + return tr + + +def test_megatron_run_can_handle_directory(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None: + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr) + assert strategy.can_handle_directory() + + +def test_megatron_run_cannot_handle_directory_without_iteration_data( + slurm_system: SlurmSystem, megatron_run_tr_no_data: TestRun +) -> None: + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr_no_data) + assert not strategy.can_handle_directory() + + +def test_megatron_run_extract_and_generate_report(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None: + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr) + strategy.generate_report() + report_path = megatron_run_tr.output_path / "megatron_run_report.csv" + assert report_path.is_file() + + with report_path.open() as f: + reader = csv.DictReader(f) + rows = list(reader) + + # Should have 2 rows: iteration_time_ms and tflops_per_gpu + assert len(rows) == 2 + + expected_headers = {"metric_type", "avg", "median", "min", "max", "std"} + assert set(rows[0].keys()) == expected_headers + + data = {row["metric_type"]: row for row in rows} + + # Verify iteration_time_ms stats + assert "iteration_time_ms" in data + iter_stats = data["iteration_time_ms"] + expected_iter_avg = (15800.0 + 15639.0 + 15448.5) / 3 + assert abs(float(iter_stats["avg"]) - expected_iter_avg) < 0.1 + assert abs(float(iter_stats["median"]) - 15639.0) < 0.1 + assert abs(float(iter_stats["min"]) - 15448.5) < 0.1 + assert abs(float(iter_stats["max"]) - 15800.0) < 0.1 + + # Verify tflops_per_gpu stats + assert "tflops_per_gpu" in data + tflops_stats = data["tflops_per_gpu"] + expected_tflops_avg = (490.0 + 494.6 + 500.6) / 3 + assert abs(float(tflops_stats["avg"]) - expected_tflops_avg) < 0.1 + assert abs(float(tflops_stats["median"]) - 494.6) < 0.1 + assert abs(float(tflops_stats["min"]) - 490.0) < 0.1 + assert abs(float(tflops_stats["max"]) - 500.6) < 0.1 + + +def test_megatron_run_get_metric_iteration_time(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None: + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr) + # Expected: avg of [15800.0, 15639.0, 15448.5] + expected_avg = (15800.0 + 15639.0 + 15448.5) / 3 + metric = strategy.get_metric("iteration-time") + assert abs(metric - expected_avg) < 0.1 + + +def test_megatron_run_get_metric_default(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None: + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr) + # Default should return iteration-time + expected_avg = (15800.0 + 15639.0 + 15448.5) / 3 + metric = strategy.get_metric("default") + assert abs(metric - expected_avg) < 0.1 + + +def test_megatron_run_get_metric_tflops(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None: + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr) + # Expected: avg of [490.0, 494.6, 500.6] + expected_avg = (490.0 + 494.6 + 500.6) / 3 + metric = strategy.get_metric("tflops-per-gpu") + assert abs(metric - expected_avg) < 0.1 + + +def test_megatron_run_get_metric_invalid(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None: + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr) + metric = strategy.get_metric("invalid-metric") + assert metric == METRIC_ERROR + + +def test_megatron_run_get_metric_no_data(slurm_system: SlurmSystem, megatron_run_tr_no_data: TestRun) -> None: + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr_no_data) + metric = strategy.get_metric("iteration-time") + assert metric == METRIC_ERROR + + +def test_megatron_run_metrics_class_var() -> None: + assert MegatronRunReportGenerationStrategy.metrics == ["default", "iteration-time", "tflops-per-gpu"] diff --git a/tests/test_test_scenario.py b/tests/test_test_scenario.py index c2af1373b..007acc100 100644 --- a/tests/test_test_scenario.py +++ b/tests/test_test_scenario.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -53,6 +53,7 @@ from cloudai.workloads.megatron_run import ( CheckpointTimingReportGenerationStrategy, MegatronRunCmdArgs, + MegatronRunReportGenerationStrategy, MegatronRunTestDefinition, ) from cloudai.workloads.nccl_test import ( @@ -481,7 +482,10 @@ def test_default_reporters_size(self): (DeepEPTestDefinition, {DeepEPReportGenerationStrategy}), (GPTTestDefinition, {JaxToolboxReportGenerationStrategy}), (GrokTestDefinition, {JaxToolboxReportGenerationStrategy}), - (MegatronRunTestDefinition, {CheckpointTimingReportGenerationStrategy}), + ( + MegatronRunTestDefinition, + {CheckpointTimingReportGenerationStrategy, MegatronRunReportGenerationStrategy}, + ), (MegatronBridgeTestDefinition, {MegatronBridgeReportGenerationStrategy}), (NCCLTestDefinition, {NcclTestPerformanceReportGenerationStrategy}), (NeMoLauncherTestDefinition, {NeMoLauncherReportGenerationStrategy}),