Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions graph_net/fault_locator/bi_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ def bi_search(
predicator, # Signature: (ES, tolerance) -> bool
stoper, # Signature: (history_list) -> bool
tolerance=0,
) -> list[(int, bool)]:
) -> (list[(int, bool)], int):
"""
Binary Search Algorithm for Automatic Fault Location.
Binary Search Algorithm for Automatic Fault Location with Faulty Operator Detection.

This algorithm locates the first faulty operation in a computational graph
by iteratively narrowing the search range through graph truncation and
Expand All @@ -24,9 +24,12 @@ def bi_search(
tolerance (int): Numerical threshold for fault detection.

Returns:
list: Search history as a list of (split_point, is_fault) tuples.
tuple: (search_history, faulty_operator_index)
- search_history: list of (split_point, is_fault) tuples
- faulty_operator_index: index of the first faulty operator, or -1 if no fault found
"""
search_history = []
faulty_operator_index = -1 # Initialize as -1 meaning no fault found

# Initialize boundaries.
# 'high' usually represents the total number of operators in the graph.
Expand Down Expand Up @@ -73,7 +76,19 @@ def bi_search(
if not any(h[0] == low for h in search_history):
truncated_model_path = truncator(relative_model_path, low)
final_es = es_scores_calculator(evaluator(truncated_model_path))
search_history.append((low, predicator(final_es, tolerance)))
final_is_fault = predicator(final_es, tolerance)
search_history.append((low, final_is_fault))

if final_is_fault:
faulty_operator_index = low
break

return search_history
faulty_positions = [pos for pos, is_fault in search_history if is_fault]
if faulty_positions:
faulty_operator_index = min(faulty_positions)
faulty_model_path = truncator(relative_model_path, faulty_operator_index)
else:
faulty_operator_index = -1
faulty_model_path = ""

return search_history, faulty_operator_index, faulty_model_path
6 changes: 5 additions & 1 deletion graph_net/fault_locator/terminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@ def __call__(self, history: list[(int, float)], high: int):
from pprint import pprint

pprint(history)
print(f"{high=}")
return bi_search_terminator(history, high)


def bi_search_terminator(history: list[(int, float)], high: int):
"""Stops when the search interval converges (range is 0 or 1)."""
last_idx, is_broken = history[-1]
if last_idx == 1 and is_broken:
return True
if last_idx == high and not is_broken:
return True
Comment on lines +14 to +18
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add boundary stopping conditions

if len(history) == 1 and history[0][0] == high and not history[0][1]:
return True
if len(history) < 2:
Expand Down
107 changes: 107 additions & 0 deletions graph_net/fault_locator/torch/device_evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import sys
import subprocess
import time
from pathlib import Path
from graph_net.declare_config_mixin import DeclareConfigMixin


class DeviceEvaluator(DeclareConfigMixin):
    """
    Cross-device evaluator.

    Establishes a ground-truth baseline on a reference device (e.g. CPU),
    then runs the same model on a target device (e.g. CUDA) and returns the
    captured target log. All target executions use the 'default' operator
    library.
    """

    def __init__(self, config=None):
        self.init_config(config)

    def declare_config(
        self,
        model_path_prefix: str,
        output_dir: str,
        ref_device: str = "cpu",
        target_device: str = "cuda",
        compiler: str = "nope",
    ):
        """
        Configuration schema for cross-device benchmarking.
        """
        pass

    def __call__(self, rel_model_path: str) -> str:
        """
        Run the two-step evaluation pipeline for one model:
        1. Generate ground-truth data on the reference device.
        2. Validate performance/accuracy on the target device.

        Returns the full text of the target run's log file.
        """
        cfg = self.config
        out_root = Path(cfg["output_dir"])
        model_dir = Path(cfg["model_path_prefix"]) / rel_model_path

        # Per-model workspace holding the target-device logs.
        run_dir = out_root / cfg["target_device"] / rel_model_path
        run_dir.mkdir(parents=True, exist_ok=True)

        # Ground-truth data shared between the reference and target runs.
        gt_dir = out_root / "reference_data"
        gt_dir.mkdir(parents=True, exist_ok=True)

        log_path = run_dir / "validation.log"

        # Step 1: baseline on the reference device.
        print(f"Generating reference data on: {cfg['ref_device']}")
        self._run_reference_test(model_dir, gt_dir)

        # Step 2: evaluate on the target device and return its log.
        print(f"Running target evaluation on: {cfg['target_device']}")
        return self._run_target_test(model_dir, gt_dir, log_path)

    def _run_reference_test(self, full_model_path: Path, reference_dir: Path):
        """
        Invoke the reference module to generate expected outputs (ground truth).
        """
        cmd = [sys.executable, "-m", "graph_net.torch.test_reference_device"]
        cmd += ["--model-path", str(full_model_path)]
        cmd += ["--reference-dir", str(reference_dir)]
        cmd += ["--compiler", self.config["compiler"]]
        cmd += ["--device", self.config["ref_device"]]
        # Quiet run; a non-zero exit status raises CalledProcessError.
        subprocess.run(cmd, check=True, capture_output=True, text=True)

    def _run_target_test(
        self, full_model_path: Path, reference_dir: Path, log_file: Path
    ) -> str:
        """
        Execute the model on the target device with the 'default' op_lib,
        persisting all process output to log_file and returning its contents.
        """
        argv = [sys.executable, "-m", "graph_net.torch.test_target_device"]
        argv += ["--model-path", str(full_model_path)]
        argv += ["--reference-dir", str(reference_dir)]
        argv += ["--device", self.config["target_device"]]
        argv += ["--op-lib", "default"]

        print(" ".join(argv))
        # Redirect stdout+stderr into the log file for later analysis.
        with log_file.open("w") as sink:
            t0 = time.perf_counter()
            subprocess.run(argv, stdout=sink, stderr=subprocess.STDOUT, check=True)
            elapsed = time.perf_counter() - t0
            print(f"Target execution completed in {elapsed:.4f} seconds")

        return log_file.read_text()
14 changes: 13 additions & 1 deletion graph_net/sample_pass/auto_fault_bisearcher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import graph_net
import shutil
from pathlib import Path
from typing import List, Tuple
from graph_net.sample_pass.sample_pass import SamplePass
Expand Down Expand Up @@ -81,7 +82,10 @@ def __call__(self, rel_model_path: str):
"""
# 2. Invoke the core binary search algorithm
# history type: list[tuple[int, bool]]
history: List[Tuple[int, bool]] = bi_search(
history: List[Tuple[int, bool]]
faulty_operator_index: int
faulty_model_path: str
history, faulty_operator_index, faulty_model_path = bi_search(
relative_model_path=rel_model_path,
truncator=self.truncator,
evaluator=self.evaluator,
Expand All @@ -100,13 +104,21 @@ def __call__(self, rel_model_path: str):
output_base.mkdir(parents=True, exist_ok=True)

result_file = output_base / file_name
test_file = (
Path(self.config["truncator_config"]["output_dir"]) / faulty_model_path
)

# Write history entries in the format: {truncate_size} {has_fault}
with result_file.open("w", encoding="utf-8") as f:
for trunc_size, has_fault in history:
f.write(f"{trunc_size} {has_fault}\n")

save_base = Path(self.config["output_dir"]) / "faulty_test"
save_base.mkdir(parents=True, exist_ok=True)
shutil.copytree(test_file, save_base / test_file.name, dirs_exist_ok=True)
Comment on lines +116 to +118
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Support saving the faulty/erroneous operator subgraph under the faulty_test directory

print(
f"[AutoFault] Search history for {rel_model_path} saved to: {result_file}"
)
print(f"First faulty operator index: {faulty_operator_index}")
print(f"Faulty operator model path: {test_file}")
return history
28 changes: 14 additions & 14 deletions graph_net/test/bi_search_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def mock_evaluator(self, sub_model_id):
def mock_predicator(self, es_scores, tolerance):
return any(score > tolerance for score in es_scores)

def mock_stoper(self, history):
def mock_stoper(self, history, high=None):
"""Stops when the search interval converges (range is 0 or 1)."""
if len(history) < 2:
return False
Expand All @@ -40,26 +40,21 @@ def test_bi_search_finds_correct_index(self):
truncator = self.mock_truncator()
setattr(truncator, "total_steps", 9)

history = bi_search(
model_path=self.model_path,
history, faulty_operator_index, faulty_model_path = bi_search(
relative_model_path=self.model_path,
truncator=truncator,
evaluator=self.mock_evaluator,
es_scores_calculator=lambda x: x, # Mock ES calculator
predicator=self.mock_predicator,
stoper=self.mock_stoper,
tolerance=0.5,
)

print(f"\nFault Test History: {history}")
print(f"Detected faulty operator index: {faulty_operator_index}")

# Filter history for all occurrences where a fault was detected
faulty_steps = [step for step in history if step[1] is True]

# The result of the fault localization is the minimum index with is_fault=True
if faulty_steps:
# Sort by index to find the first occurrence
actual_fault_index = min(faulty_steps, key=lambda x: x[0])[0]
else:
actual_fault_index = None
# The result of the fault localization is directly provided by the function
actual_fault_index = faulty_operator_index

print(f"\nIdentified Fault Index: {actual_fault_index}")
self.assertEqual(actual_fault_index, self.fault_index)
Expand All @@ -76,18 +71,23 @@ def clean_truncator(path, split_point):
def healthy_evaluator(sub_model_id):
return [0.01]

history = bi_search(
model_path=self.model_path,
history, faulty_operator_index, faulty_model_path = bi_search(
relative_model_path=self.model_path,
truncator=clean_truncator,
evaluator=healthy_evaluator,
es_scores_calculator=lambda x: x, # Mock ES calculator
predicator=self.mock_predicator,
stoper=self.mock_stoper,
tolerance=0.5,
)

print(f"No-Fault Test History: {history}")
print(f"Detected faulty operator index: {faulty_operator_index}")

# No fault should be detected
final_status = history[-1][1]
self.assertFalse(final_status)
self.assertEqual(faulty_operator_index, -1) # -1 indicates no fault found


if __name__ == "__main__":
Expand Down
38 changes: 38 additions & 0 deletions graph_net/test/device_fault_bisearcher_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/bin/bash

# End-to-end test for the device-based fault bisearcher sample pass.
# Runs AutoFaultBisearcher over a small list of torch sample models,
# using DeviceEvaluator to compare a reference device (cpu) against a
# target device (cuda).

# Resolve the root directory of the project
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")

# Test Environment Setup
MODEL_LIST="$GRAPH_NET_ROOT/graph_net/test/small10_torch_samples_list.txt"
MODEL_PREFIX="$GRAPH_NET_ROOT"
OUTPUT_DIR="/tmp/workspace_auto_fault_bisearcher"

# Execute the SamplePass via the standard CLI entry point.
# The JSON config below is base64-encoded via a heredoc and passed inline.
# NOTE(review): "tolerance": -9 presumably forces every ES score to register
# as a fault so the search always triggers — confirm this is intended.
python3 -m graph_net.apply_sample_pass \
    --model-path-list "$MODEL_LIST" \
    --sample-pass-file-path "$GRAPH_NET_ROOT/graph_net/sample_pass/auto_fault_bisearcher.py" \
    --sample-pass-class-name AutoFaultBisearcher \
    --sample-pass-config $(base64 -w 0 <<EOF
{
  "model_path_prefix": "$MODEL_PREFIX",
  "output_dir": "$OUTPUT_DIR",
  "output_file_name": "truncate_size_has_fault.txt",

  "truncator_config": {
    "model_path_prefix": "$MODEL_PREFIX",
    "output_dir": "$OUTPUT_DIR/workspace_truncator/"
  },
  "evaluator_file_path": "$GRAPH_NET_ROOT/graph_net/fault_locator/torch/device_evaluator.py",
  "evaluator_class_name": "DeviceEvaluator",
  "evaluator_config": {
    "model_path_prefix": "$OUTPUT_DIR/workspace_truncator/",
    "output_dir": "$OUTPUT_DIR/device_evaluator",
    "compiler": "nope",
    "ref_device": "cpu",
    "target_device": "cuda"
  },
  "tolerance": -9
}
EOF
)
19 changes: 16 additions & 3 deletions graph_net/torch/test_target_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ def test_single_model(args):
target_time_stats = eval_backend_diff.parse_time_stats_from_reference_log(
target_log
)

eval_backend_diff.compare_correctness(ref_out, target_out, eval_args)
eval_backend_diff.compare_correctness(
list(flatten_tensor(ref_out)), list(flatten_tensor(target_out)), eval_args
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some intermediate results are returned as tuple(tensor) and need to be flattened before being passed in

)
test_compiler_util.print_times_and_speedup(args, ref_time_stats, target_time_stats)


Expand All @@ -85,6 +86,14 @@ def is_reference_log_exist(reference_dir, model_path):
return os.path.isfile(log_path)


def flatten_tensor(lst):
    """Yield the leaf items of an arbitrarily nested list/tuple structure.

    Items that are not lists or tuples (e.g. tensors) are yielded as-is,
    preserving the original left-to-right order.
    """
    for item in lst:
        if not isinstance(item, (list, tuple)):
            yield item
        else:
            yield from flatten_tensor(item)


def test_multi_models(args):
assert os.path.isdir(args.reference_dir)

Expand Down Expand Up @@ -144,7 +153,11 @@ def main(args):
)
else:
eval_backend_perf.register_op_lib(args.op_lib)

print(
f"[Processing] model_path: {args.model_path}",
file=sys.stderr,
flush=True,
)
test_single_model(args)
else:
test_multi_models(args)
Expand Down