diff --git a/generator_test.go b/generator_test.go index cdad834..27d53de 100644 --- a/generator_test.go +++ b/generator_test.go @@ -148,6 +148,57 @@ func TestEnumUnknownRoundTrips(t *testing.T) { runCmd(t, tmp, bin) } +func TestExplicitEnumWireValuesRoundTrip(t *testing.T) { + root := repoRoot(t) + tmp := t.TempDir() + + schemaPath := filepath.Join(tmp, "enum_wire_values.ridl") + schemaText := `webrpc = v1 + +name = enum_wire +version = v1.0.0 +basepath = /rpc + +enum WalletType: string + - Ethereum = "ethereum" + - SmartWallet = "smart-wallet" + +struct Wallet + - type: WalletType + +service EnumWire + - Echo(Wallet) => (Wallet) +` + if err := os.WriteFile(schemaPath, []byte(schemaText), 0o644); err != nil { + t.Fatalf("write enum wire schema: %v", err) + } + + header := filepath.Join(tmp, "enumwire.gen.h") + impl := filepath.Join(tmp, "enumwire.gen.c") + generateC(t, root, schemaPath, header, impl, "enumwire") + + headerText, err := os.ReadFile(header) + if err != nil { + t.Fatalf("read generated header: %v", err) + } + headerSrc := string(headerText) + if !strings.Contains(headerSrc, `return "ethereum";`) { + t.Fatalf("generated header should use explicit enum wire values") + } + if !strings.Contains(headerSrc, `strcmp(value, "smart-wallet") == 0`) { + t.Fatalf("generated header should decode explicit enum wire values") + } + + testMain := filepath.Join(tmp, "enum_wire_values_test_main.c") + if err := os.WriteFile(testMain, []byte(enumWireValuesTestProgram), 0o644); err != nil { + t.Fatalf("write enum wire values test program: %v", err) + } + + bin := filepath.Join(tmp, "enum-wire-values-test") + runCmd(t, tmp, "cc", "-std=c99", "-Wall", "-Wextra", "enum_wire_values_test_main.c", "-o", bin) + runCmd(t, tmp, bin) +} + func TestGenerateFailsWhenEnumUsesReservedUnknownSentinel(t *testing.T) { root := repoRoot(t) tmp := t.TempDir() @@ -522,6 +573,33 @@ int main(void) { } ` +const enumWireValuesTestProgram = `#include +#include +#include + +#include "enumwire.gen.h" + +static void fail_msg(const char *msg) { + fprintf(stderr, "%s\n", msg); + exit(1); +} + +static void expect_true(int cond, const char *msg) { + if (!cond) { + fail_msg(msg); + } +} + +int main(void) { + enumwire_wallet_type wallet_type = ENUMWIRE_WALLET_TYPE_ETHEREUM; + + expect_true(strcmp(enumwire_wallet_type_to_string(ENUMWIRE_WALLET_TYPE_ETHEREUM), "ethereum") == 0, "explicit enum wire value mismatch"); + expect_true(enumwire_wallet_type_from_string("smart-wallet", &wallet_type) == 0, "explicit enum wire value should decode"); + expect_true(wallet_type == ENUMWIRE_WALLET_TYPE_SMART_WALLET, "decoded enum value mismatch"); + return 0; +} +` + const succinctTestProgram = `#include #include #include diff --git a/types.go.tmpl b/types.go.tmpl index d53da68..cb3d6ef 100644 --- a/types.go.tmpl +++ b/types.go.tmpl @@ -30,7 +30,11 @@ static inline const char *{{ template "cTypeName" dict "Prefix" $prefix "Name" $ switch (value) { case {{ printf "%s_%s_UNKNOWN" (toUpper $prefix) (toUpper (snakeCase $type.Name)) }}: return "UNKNOWN"; {{- range $_, $field := $type.Fields }} + {{- if eq $type.Type.Expr "string" }} + case {{ template "cEnumValue" dict "Prefix" $prefix "TypeName" $type.Name "FieldName" $field.Name }}: return "{{$field.Value}}"; + {{- else }} case {{ template "cEnumValue" dict "Prefix" $prefix "TypeName" $type.Name "FieldName" $field.Name }}: return "{{$field.Name}}"; + {{- end }} {{- end }} default: return "UNKNOWN"; } @@ -43,7 +47,11 @@ static inline int {{ template "cTypeName" dict "Prefix" $prefix "Name" $type.Nam return 0; } {{- range $_, $field := $type.Fields }} + {{- if eq $type.Type.Expr "string" }} + if (strcmp(value, "{{$field.Value}}") == 0) { + {{- else }} if (strcmp(value, "{{$field.Name}}") == 0) { + {{- end }} *out = {{ template "cEnumValue" dict "Prefix" $prefix "TypeName" $type.Name "FieldName" $field.Name }}; return 0; }