Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions src/trace_tests/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@

from trace_tests import __version__
from trace_tests.loader import LoadError, load_record
from trace_tests.modules.tr_env import DEFAULT_MAX_AGE_SECONDS
from trace_tests.result import Status
from trace_tests.runner import run


def _fmt_status(status: Status) -> str:
return status.value.upper().ljust(4)
return status.value.upper().ljust(10)


def _print_report(path: str, fmt: str, level: int, results: dict[str, list[Any]]) -> int:
Expand All @@ -27,6 +28,7 @@ def _print_report(path: str, fmt: str, level: int, results: dict[str, list[Any]]
failures = 0
skips = 0
passes = 0
unverified = 0

for module, findings in results.items():
for f in findings:
Expand All @@ -36,13 +38,26 @@ def _print_report(path: str, fmt: str, level: int, results: dict[str, list[Any]]
failures += 1
elif f.passed():
passes += 1
elif f.unverified():
unverified += 1
else:
skips += 1

total = passes + failures + skips
# Defense in depth: unverified findings must fail the run at any level that
# requires signatures, even if a module forgot to emit a hard FAIL.
if unverified and level >= 1:
failures += unverified

total = passes + failures + skips + (unverified if level == 0 else 0)
click.echo("")
if failures == 0:
click.echo(f"Result: PASS ({total} checks, {skips} skipped)")
if unverified:
click.echo(
f"Result: PASS ({total} checks, {skips} skipped, {unverified} UNVERIFIED "
f"-- record is NOT cryptographically verified)"
)
else:
click.echo(f"Result: PASS ({total} checks, {skips} skipped)")
return 0
else:
click.echo(f"Result: FAIL ({total} checks, {failures} failure(s), {skips} skipped)")
Expand All @@ -58,15 +73,23 @@ def main() -> None:
@main.command()
@click.option("--record", required=True, type=click.Path(), help="Path to the trust record (JSON)")
@click.option("--level", default=0, type=click.IntRange(0, 2), show_default=True, help="Conformance level to check (0, 1, or 2)")
def verify(record: str, level: int) -> None:
@click.option(
"--max-age",
"max_age",
default=DEFAULT_MAX_AGE_SECONDS,
type=click.IntRange(min=1),
show_default=True,
help="Maximum allowed record age in seconds (iat freshness window)",
)
def verify(record: str, level: int, max_age: int) -> None:
"""Verify a TRACE trust record against the conformance suite."""
try:
data, fmt = load_record(record)
except LoadError as exc:
click.echo(f"Error: {exc}", err=True)
sys.exit(2)

results = run(data, fmt, level)
results = run(data, fmt, level, max_age_seconds=max_age)
exit_code = _print_report(record, fmt, level, results)
sys.exit(exit_code)

Expand Down
26 changes: 23 additions & 3 deletions src/trace_tests/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@ def load_record(path: str) -> tuple[dict[str, Any], str]:
"""Load a trust record from *path*.

Returns ``(record_dict, format_string)`` where format is one of:
- ``"cmcp-runtime"``: cmcp RuntimeClaim envelope (has ``gateway`` + ``trace`` + ``signature``)
- ``"cmcp-runtime"``: cmcp RuntimeClaim envelope (positive marker: ``cmcp_version``)
- ``"trace"``: canonical TRACE Trust Record (fields at top level)

Format detection is based on positive structural markers so an attacker cannot
downgrade a cmcp envelope to the weaker plain-trace path by stripping fields.
Records that look like partial cmcp envelopes are rejected outright.
"""
p = pathlib.Path(path)
if not p.exists():
Expand All @@ -30,8 +34,24 @@ def load_record(path: str) -> tuple[dict[str, Any], str]:
if not isinstance(data, dict):
raise LoadError("Record must be a JSON object")

fmt = "cmcp-runtime" if ("gateway" in data and "trace" in data) else "trace"
return data, fmt
if "cmcp_version" in data:
if not isinstance(data.get("trace"), dict):
raise LoadError(
"Record declares cmcp_version but has no 'trace' object; refusing malformed cmcp-runtime envelope"
)
return data, "cmcp-runtime"

# Envelope-only keys present without cmcp_version: this is a partial/stripped
# cmcp envelope, not a canonical TRACE record. Reject rather than silently
# downgrading to the weaker plain-trace verification path.
partial_markers = sorted(k for k in ("trace", "gateway", "signature") if k in data)
if partial_markers:
raise LoadError(
f"Record contains cmcp envelope field(s) {partial_markers} but no 'cmcp_version'; "
"refusing to treat a partial cmcp-runtime envelope as a plain trace record"
)

return data, "trace"


def extract_trace(record: dict[str, Any], fmt: str) -> dict[str, Any]:
Expand Down
2 changes: 1 addition & 1 deletion src/trace_tests/modules/tr_anc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def check(trace: dict[str, Any]) -> list[Finding]:

try:
parsed = urlparse(transparency)
if parsed.scheme in ("https", "http") and parsed.netloc:
if parsed.scheme == "https" and parsed.netloc:
findings.append(Finding("TR-ANC-001", Status.PASS, f"transparency is a valid URI ({transparency[:80]})"))
else:
findings.append(Finding(
Expand Down
23 changes: 18 additions & 5 deletions src/trace_tests/modules/tr_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,16 @@
_PROFILE = "tag:agentrust.io,2026:trace-v0.1"
_IAT_MIN = 1_700_000_000

#: Default maximum record age (seconds). Records older than this fail freshness.
DEFAULT_MAX_AGE_SECONDS = 24 * 60 * 60

def check(trace: dict[str, Any]) -> list[Finding]:
"""Return findings for the EAT envelope structure."""

def check(trace: dict[str, Any], max_age_seconds: int = DEFAULT_MAX_AGE_SECONDS) -> list[Finding]:
"""Return findings for the EAT envelope structure.

*max_age_seconds* bounds how old ``iat`` may be; without an upper bound any
historical record would pass freshness forever and be trivially replayable.
"""
findings: list[Finding] = []

profile = trace.get("eat_profile")
Expand All @@ -24,10 +31,16 @@ def check(trace: dict[str, Any]) -> list[Finding]:
iat = trace.get("iat")
if isinstance(iat, int) and iat >= _IAT_MIN:
now = int(time.time())
if iat <= now + 60:
findings.append(Finding("TR-ENV-002", Status.PASS, f"iat is valid ({iat})"))
else:
if iat > now + 60:
findings.append(Finding("TR-ENV-002", Status.FAIL, f"iat {iat} is in the future (now={now})"))
elif now - iat > max_age_seconds:
findings.append(Finding(
"TR-ENV-002", Status.FAIL,
f"TR-ENV-002: record is stale: iat {iat} is {now - iat}s old, "
f"exceeding the maximum allowed age of {max_age_seconds}s",
))
else:
findings.append(Finding("TR-ENV-002", Status.PASS, f"iat is valid and fresh ({iat})"))
else:
findings.append(Finding("TR-ENV-002", Status.FAIL, f"iat must be a Unix timestamp >= {_IAT_MIN}, got {iat!r}"))

Expand Down
4 changes: 2 additions & 2 deletions src/trace_tests/modules/tr_rte.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def check(trace: dict[str, Any]) -> list[Finding]:
rim_uri = runtime.get("rim_uri")
if rim_uri is None:
findings.append(Finding("TR-RTE-003", Status.SKIP, "runtime.rim_uri not present (optional)"))
elif isinstance(rim_uri, str) and rim_uri.startswith(("https://", "http://")):
findings.append(Finding("TR-RTE-003", Status.PASS, f"runtime.rim_uri is a URI ({rim_uri[:60]})"))
elif isinstance(rim_uri, str) and rim_uri.startswith("https://"):
findings.append(Finding("TR-RTE-003", Status.PASS, f"runtime.rim_uri is an https URI ({rim_uri[:60]})"))
else:
findings.append(Finding("TR-RTE-003", Status.FAIL, f"TR-RTE-003: runtime.rim_uri must be an https URI, got {rim_uri!r}"))

Expand Down
55 changes: 37 additions & 18 deletions src/trace_tests/modules/tr_sig.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
For cmcp-runtime records: Ed25519 over canonical JSON (sorted keys, no whitespace,
excluding the ``signature`` field). Key is in ``trace.cnf.jwk``.

For plain trace records: checks key type and algorithm alignment. Full JWS/COSE
verification is not yet implemented for this format.
For plain trace records no signature can be cryptographically verified, so TR-SIG
fails closed: at any level that requires signatures (level >= 1) the result is FAIL;
at level 0 the result is an explicit UNVERIFIED finding so the record can never be
reported as cryptographically verified.
"""

from __future__ import annotations
Expand Down Expand Up @@ -84,28 +86,45 @@ def check_cmcp_runtime(record: dict[str, Any]) -> list[Finding]:
return findings


def check(trace: dict[str, Any], record: dict[str, Any], fmt: str) -> list[Finding]:
"""Return TR-SIG findings. *record* is the full raw dict, *trace* is the extracted TRACE fields."""
def check(trace: dict[str, Any], record: dict[str, Any], fmt: str, level: int = 0) -> list[Finding]:
"""Return TR-SIG findings. *record* is the full raw dict, *trace* is the extracted TRACE fields.

*level* is the conformance level being checked. Plain trace records carry no
verifiable signature, so they FAIL at level >= 1 and are reported UNVERIFIED
(never PASS) at level 0.
"""
if fmt == "cmcp-runtime":
return check_cmcp_runtime(record)

# Plain trace format: verify key type alignment only; JWS/COSE verification not yet implemented.
# Plain trace format: no signature can be cryptographically verified.
findings: list[Finding] = []
jwk = trace.get("cnf", {}).get("jwk", {})
kty = jwk.get("kty")
crv = jwk.get("crv")

if kty in _SUPPORTED_KTY:
label = f"kty={kty!r}" + (f", crv={crv!r}" if crv else "")
return [
Finding("TR-SIG-004", Status.PASS, f"cnf.jwk key type is supported ({label})"),
Finding(
"TR-SIG-005",
Status.SKIP,
"Full JWS/COSE signature verification requires a signed EAT token (not a plain JSON record)",
),
]

if kty is None:
return [Finding("TR-SIG-004", Status.FAIL, "TR-SIG-004: cnf.jwk.kty is missing")]

return [Finding("TR-SIG-004", Status.FAIL, f"TR-SIG-004: unsupported key type {kty!r}; expected one of {sorted(_SUPPORTED_KTY)}")]
findings.append(Finding("TR-SIG-004", Status.PASS, f"cnf.jwk key type is supported ({label})"))
elif kty is None:
findings.append(Finding("TR-SIG-004", Status.FAIL, "TR-SIG-004: cnf.jwk.kty is missing"))
else:
findings.append(Finding(
"TR-SIG-004", Status.FAIL,
f"TR-SIG-004: unsupported key type {kty!r}; expected one of {sorted(_SUPPORTED_KTY)}",
))

if level >= 1:
findings.append(Finding(
"TR-SIG-005",
Status.FAIL,
f"TR-SIG-005: plain trace records carry no verifiable signature; "
f"Level {level} requires cryptographic signature verification (use a signed envelope, e.g. cmcp-runtime)",
))
else:
findings.append(Finding(
"TR-SIG-005",
Status.UNVERIFIED,
"TR-SIG-005: no signature present; this record is NOT cryptographically verified",
))

return findings
7 changes: 7 additions & 0 deletions src/trace_tests/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ class Status(StrEnum):
PASS = "pass"
FAIL = "fail"
SKIP = "skip"
# No cryptographic verification was possible. Distinct from SKIP so callers
# can never mistake an unverified record for a benign omission. Treated as
# a failure at any conformance level that requires signatures (level >= 1).
UNVERIFIED = "unverified"


@dataclass
Expand All @@ -26,3 +30,6 @@ def failed(self) -> bool:

def skipped(self) -> bool:
return self.status == Status.SKIP

def unverified(self) -> bool:
return self.status == Status.UNVERIFIED
11 changes: 8 additions & 3 deletions src/trace_tests/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
}


def run(record: dict[str, Any], fmt: str, level: int) -> dict[str, list[Finding]]:
def run(
record: dict[str, Any],
fmt: str,
level: int,
max_age_seconds: int = tr_env.DEFAULT_MAX_AGE_SECONDS,
) -> dict[str, list[Finding]]:
"""Run all modules required for *level* and return findings keyed by module ID."""
if level not in _LEVEL_MODULES:
raise ValueError(f"Unknown conformance level {level!r}; valid: 0, 1, 2")
Expand All @@ -27,10 +32,10 @@ def run(record: dict[str, Any], fmt: str, level: int) -> dict[str, list[Finding]
active = set(_LEVEL_MODULES[level])

if "TR-ENV" in active:
results["TR-ENV"] = tr_env.check(trace)
results["TR-ENV"] = tr_env.check(trace, max_age_seconds=max_age_seconds)

if "TR-SIG" in active:
results["TR-SIG"] = tr_sig.check(trace, record, fmt)
results["TR-SIG"] = tr_sig.check(trace, record, fmt, level)

if "TR-POL" in active:
results["TR-POL"] = tr_pol.check(trace)
Expand Down
59 changes: 59 additions & 0 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""End-to-end CLI tests for fail-closed behavior."""

import json
import pathlib
import time

import pytest
from click.testing import CliRunner

from trace_tests.cli import main

VECTORS_DIR = pathlib.Path(__file__).parent.parent / "vectors"


@pytest.fixture
def fresh_level0_path(tmp_path):
vector = json.loads((VECTORS_DIR / "valid_level0.json").read_text())
vector["iat"] = int(time.time()) - 60
p = tmp_path / "record.json"
p.write_text(json.dumps(vector))
return str(p)


def test_unsigned_record_fails_level_2(fresh_level0_path):
"""Regression: unsigned plain JSON must never pass `verify --level 2`."""
result = CliRunner().invoke(main, ["verify", "--record", fresh_level0_path, "--level", "2"])
assert result.exit_code == 1, result.output
assert "Result: FAIL" in result.output


def test_unsigned_record_fails_level_1(fresh_level0_path):
result = CliRunner().invoke(main, ["verify", "--record", fresh_level0_path, "--level", "1"])
assert result.exit_code == 1, result.output


def test_unsigned_record_level_0_reports_unverified(fresh_level0_path):
result = CliRunner().invoke(main, ["verify", "--record", fresh_level0_path, "--level", "0"])
assert result.exit_code == 0, result.output
assert "UNVERIFIED" in result.output
assert "NOT cryptographically verified" in result.output


def test_stale_record_fails(tmp_path):
vector = json.loads((VECTORS_DIR / "valid_level0.json").read_text())
vector["iat"] = int(time.time()) - (25 * 3600)
p = tmp_path / "stale.json"
p.write_text(json.dumps(vector))
result = CliRunner().invoke(main, ["verify", "--record", str(p), "--level", "0"])
assert result.exit_code == 1, result.output


def test_partial_cmcp_envelope_is_rejected(tmp_path):
vector = json.loads((VECTORS_DIR / "valid_cmcp_runtime.json").read_text())
del vector["cmcp_version"]
p = tmp_path / "partial.json"
p.write_text(json.dumps(vector))
result = CliRunner().invoke(main, ["verify", "--record", str(p), "--level", "0"])
assert result.exit_code == 2, result.output
assert "partial cmcp-runtime envelope" in result.output
Loading
Loading