diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c763b3..aefe535 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # 0.3.1 +* Fixed some cases of comparisons that contain '-' or '+' against versions that are missing those parts. + # 0.3.0 * Added `skip_build` parameter to control whether comparisons adhere to Semver-10. diff --git a/bzl/versions/versions.bzl b/bzl/versions/versions.bzl index c8579cd..a1ab661 100644 --- a/bzl/versions/versions.bzl +++ b/bzl/versions/versions.bzl @@ -169,7 +169,7 @@ def _cmp(lhs, rhs): return 0 lhs = _maybe_int(lhs) rhs = _maybe_int(rhs) - if type(lhs) == "int" or type(rhs) == "string": + if type(lhs) == "int" or type(rhs) == "int": # Purely numeric identifiers have lower precedence than strings. if type(lhs) == "int" and type(rhs) == "string": return -1 @@ -225,20 +225,20 @@ def _version_cmp(version_lhs, version_rhs, *, skip_build = True): typ = type(version_rhs), ), ) - part = 0 - for part in range(min(len(lhs), len(rhs))): - if lhs in ["-", "+"] or rhs in ["-", "+"]: - part -= 1 # Since we later increase + for part in range(max(len(lhs), len(rhs))): + lhs_part = _at_or(lhs, part) + rhs_part = _at_or(rhs, part) + if lhs_part in ["-", "+"] or rhs_part in ["-", "+"]: + res = _extra_cmp(lhs_part, rhs_part) + if res != 0: + return res + continue + if part >= len(lhs) or part >= len(rhs): break - res = _cmp(lhs[part], rhs[part]) + res = _cmp(lhs_part, rhs_part) if res != 0: return res - part += 1 # Skip the already compared part even if that moves beyond end. - res = _extra_cmp(_at_or(lhs, part), _at_or(rhs, part)) - if res != 0: - return res - # All parts available on both sides are the same. return _cmp(len(lhs), len(rhs)) @@ -304,20 +304,22 @@ def _check_one_requirement(version, requirement, *, skip_build = True): At most 3 parts (major, minor, patch) plus the lengths are considered. """ if type(requirement) == "string": - if requirement.startswith(">="): - return _version_ge(version, requirement[2:].strip(), skip_build = skip_build) - elif requirement.startswith("<="): - return _version_le(version, requirement[2:].strip(), skip_build = skip_build) - elif requirement.startswith(">"): - return not _version_le(version, requirement[1:].strip(), skip_build = skip_build) - elif requirement.startswith("<"): - return not _version_ge(version, requirement[1:].strip(), skip_build = skip_build) - elif requirement.startswith("=="): - return _version_eq(version, requirement[2:].strip(), skip_build = skip_build) - elif requirement.startswith("!="): - return not _version_eq(version, requirement[2:].strip(), skip_build = skip_build) + if len(requirement) >= 2 and requirement[0:2] in [">=", "<=", "==", "!="]: + op = requirement[0:2] + rhs = requirement[2:].strip() + elif len(requirement) >= 1 and requirement[0] in [">", "<"]: + op = requirement[0] + rhs = requirement[1:].strip() else: - return _version_eq(version, requirement, skip_build = skip_build) + op = "==" + rhs = requirement.strip() + return _version_compare( + version, + op, + rhs, + skip_build = skip_build, + error = "Bad requirement: '{requirement}'.".format(requirement = requirement), + ) return _check_one_requirement_struct(version, requirement, skip_build = skip_build) def _check_all_requirements(version, requirements, *, skip_build = True): @@ -348,7 +350,7 @@ def _parse_split_requirement(req): elif req.startswith("!="): return struct(op = "!=", version = _parse_version(req[2:].strip())) else: - return struct(op = "==", version = _parse_version(req)) + return struct(op = "==", version = _parse_version(req.strip())) def _parse_requirements(requirements): """Splits the `requirements` string for use in `check_all_requirements`.""" diff --git a/bzl/versions/versions_test.bzl b/bzl/versions/versions_test.bzl index aab2459..9f724f1 100644 --- a/bzl/versions/versions_test.bzl +++ b/bzl/versions/versions_test.bzl @@ -403,7 +403,11 @@ def _versions_check_one_requirement_test(ctx): _assert_eq(env, versions.check_one_requirement([26], "42"), False) _assert_eq(env, versions.check_one_requirement(27, ">=26"), True) _assert_eq(env, versions.check_one_requirement(28, "<=26"), False) - + _assert_eq(env, versions.check_one_requirement([29, 0], "<=29"), False) + _assert_eq(env, versions.check_one_requirement([30, 0, 0], "<30"), False) + _assert_eq(env, versions.check_one_requirement([31, 0], ">31"), True) + _assert_eq(env, versions.check_one_requirement([31, 0], ">31"), True) + _assert_eq(env, versions.check_one_requirement([32, 0], ">=32.0-rc1"), True) return unittest.end(env) def _versions_check_all_requirements_test(ctx):