diff --git a/compare.go b/compare.go index 802b860..ea5e7aa 100644 --- a/compare.go +++ b/compare.go @@ -152,6 +152,88 @@ func Only(platform specs.Platform) MatchComparer { return Ordered(platformVector(Normalize(platform))...) } +// OnlyOS returns a match comparer that matches only platforms with the same +// OS, OS version, and OS features, regardless of architecture. When comparing, +// it always ranks the best architecture match highest using the default +// platform resolution logic. +func OnlyOS(platform specs.Platform) MatchComparer { + normalized := Normalize(platform) + return onlyOSComparer{ + platform: normalized, + osvM: newOSVersionMatcher(normalized), + archOrder: orderedPlatformComparer{ + matchers: []Matcher{NewMatcher(normalized)}, + }, + } +} + +func newOSVersionMatcher(platform specs.Platform) osVerMatcher { + if platform.OS == "windows" { + return &windowsVersionMatcher{ + windowsOSVersion: getWindowsOSVersion(platform.OSVersion), + } + } + return nil +} + +type onlyOSComparer struct { + platform specs.Platform + osvM osVerMatcher + archOrder orderedPlatformComparer +} + +func (c onlyOSComparer) matchOS(platform specs.Platform) bool { + normalized := Normalize(platform) + if c.platform.OS != normalized.OS { + return false + } + if c.osvM != nil { + if !c.osvM.Match(platform.OSVersion) { + return false + } + } + if len(normalized.OSFeatures) > 0 { + if len(c.platform.OSFeatures) < len(normalized.OSFeatures) { + return false + } + j := 0 + for _, feature := range normalized.OSFeatures { + found := false + for ; j < len(c.platform.OSFeatures); j++ { + if feature == c.platform.OSFeatures[j] { + found = true + j++ + break + } + if feature < c.platform.OSFeatures[j] { + return false + } + } + if !found { + return false + } + } + } + return true +} + +func (c onlyOSComparer) Match(platform specs.Platform) bool { + return c.matchOS(platform) +} + +func (c onlyOSComparer) Less(p1, p2 specs.Platform) bool { + p1m := c.matchOS(p1) + p2m := c.matchOS(p2) + if p1m && !p2m { + return true + } + if !p1m { + return false + } + // Both match — rank by architecture preference + return c.archOrder.Less(p1, p2) +} + // OnlyStrict returns a match comparer for a single platform. // // Unlike Only, OnlyStrict does not match sub platforms. diff --git a/compare_test.go b/compare_test.go index a5fddc1..1968e51 100644 --- a/compare_test.go +++ b/compare_test.go @@ -595,6 +595,126 @@ func TestOnlyStrict(t *testing.T) { } } +func TestOnlyOS(t *testing.T) { + for _, tc := range []struct { + platform string + matches map[bool][]string + }{ + { + platform: "linux/amd64", + matches: map[bool][]string{ + true: { + "linux/amd64", + "linux/arm64", + "linux/arm/v7", + "linux/386", + }, + false: { + "windows/amd64", + "darwin/arm64", + }, + }, + }, + { + platform: "windows(10.0.17763)/amd64", + matches: map[bool][]string{ + true: { + "windows/amd64", + "windows/arm64", + "windows(10.0.17763)/amd64", + "windows(10.0.17763)/arm64", + }, + false: { + "linux/amd64", + }, + }, + }, + { + platform: "linux(+gpu)/amd64", + matches: map[bool][]string{ + true: { + "linux/arm64", + "linux(+gpu)/amd64", + }, + false: { + "windows/amd64", + "linux(+gpu+simd)/amd64", + }, + }, + }, + } { + testcase := tc + t.Run(testcase.platform, func(t *testing.T) { + p, err := Parse(testcase.platform) + if err != nil { + t.Fatal(err) + } + m := OnlyOS(p) + for shouldMatch, platforms := range testcase.matches { + for _, matchPlatform := range platforms { + mp, err := Parse(matchPlatform) + if err != nil { + t.Fatal(err) + } + if match := m.Match(mp); shouldMatch != match { + t.Errorf("OnlyOS(%q).Match(%q) should return %v, but returns %v", testcase.platform, matchPlatform, shouldMatch, match) + } + } + } + }) + } +} + +func TestOnlyOSLess(t *testing.T) { + for _, tc := range []struct { + platform string + platforms []string + expected []string + }{ + { + // Exact architecture match ranks first, others are unordered but before non-OS matches + platform: "linux/amd64", + platforms: []string{"linux/arm64", "linux/386", "linux/amd64", "windows/amd64"}, + expected: []string{"linux/amd64", "linux/arm64", "linux/386", "windows/amd64"}, + }, + { + // Strict: only exact arch/variant match is preferred + platform: "linux/arm64", + platforms: []string{"linux/amd64", "linux/arm/v7", "linux/arm64", "windows/arm64"}, + expected: []string{"linux/arm64", "linux/amd64", "linux/arm/v7", "windows/arm64"}, + }, + { + // Non-matching OS should always sort last + platform: "linux/amd64", + platforms: []string{"windows/amd64", "darwin/amd64", "linux/arm64", "linux/amd64"}, + expected: []string{"linux/amd64", "linux/arm64", "windows/amd64", "darwin/amd64"}, + }, + } { + testcase := tc + t.Run(testcase.platform, func(t *testing.T) { + p, err := Parse(testcase.platform) + if err != nil { + t.Fatal(err) + } + mc := OnlyOS(p) + platforms, err := ParseAll(testcase.platforms) + if err != nil { + t.Fatal(err) + } + sort.Slice(platforms, func(i, j int) bool { + return mc.Less(platforms[i], platforms[j]) + }) + actual := make([]string, len(platforms)) + for i, ps := range platforms { + actual[i] = FormatAll(ps) + } + if !reflect.DeepEqual(testcase.expected, actual) { + t.Errorf("Wrong platform order:\nExpected: %#v\nActual: %#v", testcase.expected, actual) + } + }) + } +} + func TestCompareOSFeatures(t *testing.T) { for _, tc := range []struct { platform string