diff --git a/internal/cuetools/testdata/assertionmodes/tests/incorrect_unification.cue b/internal/cuetools/testdata/assertionmodes/tests/incorrect_unification.cue new file mode 100644 index 0000000..83b9f44 --- /dev/null +++ b/internal/cuetools/testdata/assertionmodes/tests/incorrect_unification.cue @@ -0,0 +1,13 @@ + +@if(incorrect_unification) +package tests + +#request: observed: composite: resource: { + foo: "bar" +} + +response: { + desired: resources: main: resource: { + foo: "baz" + } +} @assertionMode(unification) diff --git a/internal/cuetools/testdata/assertionmodes/tests/incorrect_unification_extra.cue b/internal/cuetools/testdata/assertionmodes/tests/incorrect_unification_extra.cue new file mode 100644 index 0000000..3a8c224 --- /dev/null +++ b/internal/cuetools/testdata/assertionmodes/tests/incorrect_unification_extra.cue @@ -0,0 +1,15 @@ + +@if(incorrect_unification_extra) +package tests + +#request: observed: composite: resource: { + foo: "bar" +} + +#mode: "unification" + +response: { + desired: resources: main: resource: { + extra: "value" + } +} @assertionMode(unification) diff --git a/internal/cuetools/testdata/assertionmodes/tests/unification.cue b/internal/cuetools/testdata/assertionmodes/tests/unification.cue new file mode 100644 index 0000000..7db962b --- /dev/null +++ b/internal/cuetools/testdata/assertionmodes/tests/unification.cue @@ -0,0 +1,12 @@ +@if(unification) +package tests + +#request: observed: composite: resource: { + foo: "bar" +} + +response: { + desired: resources: main: resource: { + foo: "bar" + } +} @assertionMode(unification) diff --git a/internal/cuetools/tester.go b/internal/cuetools/tester.go index ef3427a..b50689c 100644 --- a/internal/cuetools/tester.go +++ b/internal/cuetools/tester.go @@ -27,6 +27,7 @@ import ( "strings" "cuelang.org/go/cue" + "cuelang.org/go/cue/cuecontext" "cuelang.org/go/cue/load" "cuelang.org/go/cue/parser" "github.com/crossplane-contrib/function-cue/internal/fn" @@ -53,6 +54,32 @@ type Tester struct { config *TestConfig } +type assertionMode string + +const ( + AssertionModeDiff assertionMode = "diff" + AssertionModeUnification assertionMode = "unification" +) + +type ErrUnknownAssertionMode struct { + Mode string +} + +func (e *ErrUnknownAssertionMode) Error() string { + return fmt.Sprintf("unknown assertion mode: %s", e.Mode) +} + +func assertionModeFromString(mode string) (assertionMode, error) { + switch mode { + case "diff": + return AssertionModeDiff, nil + case "unification": + return AssertionModeUnification, nil + default: + return "", &ErrUnknownAssertionMode{Mode: mode} + } +} + // NewTester returns a test for the supplied configuration. It auto-discovers tags from test file names if needed. func NewTester(config TestConfig) (*Tester, error) { ret := &Tester{config: &config} @@ -102,16 +129,18 @@ func (t *Tester) discoverTags() error { return nil } -func evalPackage(pkg string, tag string, expr string, into proto.Message) (finalErr error) { +// evalPackage evaluates a CUE package with a specific tag and returns the value of the given expression. +func evalPackage(pkg string, tag string, expr string) (cue.Value, error) { iv, err := loadSingleInstanceValue(pkg, &load.Config{Tags: []string{tag}}) if err != nil { - return err + return cue.Value{}, err } + val := iv.value if expr != "" { e, err := parser.ParseExpr("expression", expr) if err != nil { - return errors.Wrap(err, "parse expression") + return val, errors.Wrap(err, "parse expression") } val = iv.value.Context().BuildExpr(e, cue.Scope(iv.value), @@ -119,9 +148,14 @@ func evalPackage(pkg string, tag string, expr string, into proto.Message) (final cue.InferBuiltins(true), ) if val.Err() != nil { - return errors.Wrap(val.Err(), "build expression") + return val, errors.Wrap(val.Err(), "build expression") } } + return val, nil +} + +// marshalValueIntoProtoMessage marshals a CUE value into a proto message. +func marshalValueIntoProtoMessage(val cue.Value, into proto.Message) error { b, err := val.MarshalJSON() if err != nil { return errors.Wrap(err, "marshal json") @@ -199,22 +233,29 @@ func (t *Tester) runTest(f *fn.Cue, codeBytes []byte, tag string) (finalErr erro } var expected fnv1.RunFunctionResponse - var err error + expectedVal, err := evalPackage(t.config.TestPackage, tag, responseVar) + if err != nil { + return errors.Wrap(err, "evaluate expected") + } if t.config.LegacyDesiredOnlyResponse { expected.Desired = &fnv1.State{} - err = evalPackage(t.config.TestPackage, tag, responseVar, expected.Desired) + if err := marshalValueIntoProtoMessage(expectedVal, expected.Desired); err != nil { + return errors.Wrap(err, "marshal expected") + } } else { - err = evalPackage(t.config.TestPackage, tag, responseVar, &expected) - } - if err != nil { - return errors.Wrap(err, "evaluate expected") + if err := marshalValueIntoProtoMessage(expectedVal, &expected); err != nil { + return errors.Wrap(err, "marshal expected") + } } var req fnv1.RunFunctionRequest - err = evalPackage(t.config.TestPackage, tag, requestVar, &req) + requestVal, err := evalPackage(t.config.TestPackage, tag, requestVar) if err != nil { return errors.Wrap(err, "evaluate request") } + if err := marshalValueIntoProtoMessage(requestVal, &req); err != nil { + return errors.Wrap(err, "marshal request") + } opts := fn.EvalOptions{ RequestVar: requestVar, @@ -227,21 +268,54 @@ func (t *Tester) runTest(f *fn.Cue, codeBytes []byte, tag string) (finalErr erro return errors.Wrap(err, "evaluate package with test request") } - expectedString, err := canonicalYAML(&expected) - if err != nil { - return errors.Wrap(err, "serialize expected") - } - actualString, err := canonicalYAML(actual) - if err != nil { - return errors.Wrap(err, "serialize actual") - } - if expectedString == actualString { - return nil + assertionMode := AssertionModeDiff + + attr := expectedVal.Attribute("assertionMode") + if attr.Err() == nil { + assertionMode, err = assertionModeFromString(attr.Contents()) + if err != nil { + return err + } } - err = printDiffs(expectedString, actualString) - if err != nil { - _, _ = fmt.Fprintln(TestOutput, "error in running diff:", err) + switch assertionMode { + case AssertionModeUnification: + // in unification mode, we check if the expected and actual values are unifiable + // by compiling a cue script that unifies the two values + + actualBytes, err := protojson.MarshalOptions{Indent: " "}.Marshal(actual) + if err != nil { + return errors.Wrap(err, "proto json marshal") + } + + assertionScript := fmt.Sprintf("expected: %s\n#Actual: %s\nunified: #Actual & expected\n", expectedVal, actualBytes) + + runtime := cuecontext.New() + assertVal := runtime.CompileString(assertionScript) + if assertVal.Err() != nil { + return errors.Wrap(assertVal.Err(), "compile cue code") + } + + if _, err := assertVal.MarshalJSON(); err != nil { + return errors.Wrap(err, "marshal cue output") + } + // script compiles and marshals, so actual and expected are unifiable. + case AssertionModeDiff: + expectedString, err := canonicalYAML(&expected) + if err != nil { + return errors.Wrap(err, "serialize expected") + } + actualString, err := canonicalYAML(actual) + if err != nil { + return errors.Wrap(err, "serialize actual") + } + if expectedString != actualString { + err = printDiffs(expectedString, actualString) + if err != nil { + _, _ = fmt.Fprintln(TestOutput, "error in running diff:", err) + } + return fmt.Errorf("expected did not match actual") + } } - return fmt.Errorf("expected did not match actual") + return nil } diff --git a/internal/cuetools/tester_test.go b/internal/cuetools/tester_test.go index 66f64d5..a537624 100644 --- a/internal/cuetools/tester_test.go +++ b/internal/cuetools/tester_test.go @@ -85,6 +85,38 @@ FAIL incorrect: expected did not match actual assert.Equal(t, strings.TrimSpace(expected), strings.TrimSpace(buf.String())) } +func TestAssertionModes(t *testing.T) { + fn := chdirCueRoot(t) + defer fn() + buf, reset := getOutput() + defer reset() + tester, err := NewTester(TestConfig{ + Package: "./runtime", + TestPackage: "./assertionmodes/tests", + }) + require.NoError(t, err) + envDiff := ExternalDiffEnvVar + diffProgram := os.Getenv(envDiff) + if diffProgram != "" { + err = os.Unsetenv(envDiff) // we expect a specific diff format + require.NoError(t, err) + defer func() { _ = os.Setenv(envDiff, diffProgram) }() + } + err = tester.Run() + expected := ` +running test tags: incorrect_unification, incorrect_unification_extra, unification +> run test "incorrect_unification" +FAIL incorrect_unification: compile cue code: unified.desired.resources.main.resource.foo: conflicting values "baz" and "bar" +> run test "incorrect_unification_extra" +FAIL incorrect_unification_extra: compile cue code: unified.desired.resources.main.resource.extra: field not allowed +> run test "unification" +PASS unification +` + require.Error(t, err) + assert.Equal(t, "2 of 3 tests had errors", err.Error()) + assert.Equal(t, strings.TrimSpace(expected), strings.TrimSpace(buf.String())) +} + func TestTesterLegacyOptions(t *testing.T) { fn := chdirCueRoot(t) defer fn()