Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 38 additions & 20 deletions test/wycheproof/wycheproof_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ def err(msg, **kwargs):
print(msg, file=sys.stderr, **kwargs)


def fail(msg):
err(msg)
sys.exit(1)


def require(condition, msg):
if not condition:
fail(msg)


def info(msg, **kwargs):
print(msg, **kwargs)

Expand Down Expand Up @@ -130,18 +140,22 @@ def check_sign_result(tc, out):
if "InvalidPrivateKey" in flags:
return TestResult.SKIPPED
# If new invalid-flag classes appear, fail loudly so we can handle them.
assert "IncorrectPrivateKeyLength" in flags or "InvalidContext" in flags, (
f"unhandled invalid flag(s) {flags} for tcId={tc['tcId']}"
require(
"IncorrectPrivateKeyLength" in flags or "InvalidContext" in flags,
f"unhandled invalid flag(s) {flags} for tcId={tc['tcId']}",
)
assert "decode_error" in out, (
f"expected decode_error on invalid tcId={tc['tcId']}"
require(
"decode_error" in out,
f"expected decode_error on invalid tcId={tc['tcId']}",
)
elif tc["result"] == "valid":
assert out["signature"].upper() == tc["sig"].upper(), (
f"signature mismatch tcId={tc['tcId']}"
require("signature" in out, f"missing signature for tcId={tc['tcId']}")
require(
out["signature"].upper() == tc["sig"].upper(),
f"signature mismatch tcId={tc['tcId']}",
)
else:
assert False, f"Unsupported test result '{tc['result']}' for tcId={tc['tcId']}"
fail(f"Unsupported test result '{tc['result']}' for tcId={tc['tcId']}")
return TestResult.OK


Expand Down Expand Up @@ -247,17 +261,17 @@ def run_verify_test(data_file):
)
if tc["result"] == "invalid":
# _error: non-zero exit code; decode_error: explicit validation failure
assert "_error" in out or "decode_error" in out, (
f"binary success on invalid tcId={tc['tcId']}"
require(
"_error" in out or "decode_error" in out,
f"binary success on invalid tcId={tc['tcId']}",
)
elif tc["result"] == "valid":
assert "_error" not in out and "decode_error" not in out, (
f"binary failure on valid tcId={tc['tcId']}"
require(
"_error" not in out and "decode_error" not in out,
f"binary failure on valid tcId={tc['tcId']}",
)
else:
assert False, (
f"Unsupported test result '{tc['result']}' for tcId={tc['tcId']}"
)
fail(f"Unsupported test result '{tc['result']}' for tcId={tc['tcId']}")
info("ok")
count += 1
info(f" {count} verify tests passed")
Expand All @@ -266,15 +280,19 @@ def run_verify_test(data_file):
def check_pk_from_sk_result(tc, tg, out):
flags = tc.get("flags", [])
if "IncorrectPrivateKeyLength" in flags:
assert "decode_error" in out, (
f"expected decode_error for IncorrectPrivateKeyLength tcId={tc['tcId']}"
require(
"decode_error" in out,
f"expected decode_error for IncorrectPrivateKeyLength tcId={tc['tcId']}",
)
elif "InvalidPrivateKey" in flags:
assert "_error" in out, f"pk_from_sk accepted invalid SK for tcId={tc['tcId']}"
require(
"_error" in out, f"pk_from_sk accepted invalid SK for tcId={tc['tcId']}"
)
else:
assert "pk" in out, f"pk_from_sk failed on valid SK for tcId={tc['tcId']}"
assert out["pk"].upper() == tg["publicKey"].upper(), (
f"pk_from_sk derived wrong PK for tcId={tc['tcId']}"
require("pk" in out, f"pk_from_sk failed on valid SK for tcId={tc['tcId']}")
require(
out["pk"].upper() == tg["publicKey"].upper(),
f"pk_from_sk derived wrong PK for tcId={tc['tcId']}",
)


Expand Down