From d51145f2fb12dc89582166c6168452d3518fdb62 Mon Sep 17 00:00:00 2001 From: mdemauroy Date: Tue, 17 Feb 2026 11:50:45 +0100 Subject: [PATCH] feat: add mapping mode for multi-key Vault secrets Co-Authored-By: Claude Opus 4.6 --- config.go | 15 +++- config_test.go | 81 +++++++++++++++++++ env.go | 32 ++++++++ provider/config.go | 13 ++- provider/mapping.go | 40 ++++++++++ provider/mapping_test.go | 162 ++++++++++++++++++++++++++++++++++++++ provider/provider.go | 7 ++ provider/provider_test.go | 10 +++ provider/vault.go | 49 +++++++++--- 9 files changed, 391 insertions(+), 18 deletions(-) create mode 100644 provider/mapping.go create mode 100644 provider/mapping_test.go diff --git a/config.go b/config.go index a97f667..6ffb825 100644 --- a/config.go +++ b/config.go @@ -17,10 +17,11 @@ type ProjectConfig struct { } type EnvConfig struct { - Provider string `yaml:"provider"` - Source string `yaml:"source,omitempty"` // deprecated, use Provider - PathPrefix string `yaml:"path_prefix"` - Prefix string `yaml:"prefix"` + Provider string `yaml:"provider"` + Source string `yaml:"source,omitempty"` // deprecated, use Provider + PathPrefix string `yaml:"path_prefix"` + Prefix string `yaml:"prefix"` + Mapping map[string]provider.SecretMapping `yaml:"mapping,omitempty"` } func (e EnvConfig) GetProvider() string { @@ -35,6 +36,7 @@ func (e EnvConfig) ToProviderConfig() provider.EnvConfig { Provider: e.GetProvider(), PathPrefix: e.PathPrefix, Prefix: e.Prefix, + Mapping: e.Mapping, } } @@ -102,6 +104,11 @@ func (c ProjectConfig) Validate() error { if _, ok := c.Envs[c.DefaultEnv]; !ok { return fmt.Errorf("default_env %q not found in envs", c.DefaultEnv) } + for envName, envCfg := range c.Envs { + if len(envCfg.Mapping) > 0 && envCfg.PathPrefix != "" { + return fmt.Errorf("env %q: mapping and path_prefix are mutually exclusive", envName) + } + } return nil } diff --git a/config_test.go b/config_test.go index 1fc6aab..f4a19b8 100644 --- a/config_test.go +++ b/config_test.go @@ -128,6 +128,87 @@ envs: } } +func TestLoadProjectConfigWithMapping(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, ".envmap.yaml") + + content := ` +project: testapp +default_env: dev +envs: + dev: + provider: vault + mapping: + CDN_TOKEN: + path: shared/cdn + key: CDN_TOKEN + API_KEY: + path: myapp + key: API_SECRET_KEY +` + if err := os.WriteFile(cfgPath, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + cfg, err := LoadProjectConfig(cfgPath) + if err != nil { + t.Fatalf("LoadProjectConfig: %v", err) + } + + devCfg := cfg.Envs["dev"] + if len(devCfg.Mapping) != 2 { + t.Fatalf("len(Mapping) = %d, want 2", len(devCfg.Mapping)) + } + + cdnMapping := devCfg.Mapping["CDN_TOKEN"] + if cdnMapping.Path != "shared/cdn" { + t.Errorf("CDN_TOKEN.Path = %q, want %q", cdnMapping.Path, "shared/cdn") + } + if cdnMapping.Key != "CDN_TOKEN" { + t.Errorf("CDN_TOKEN.Key = %q, want %q", cdnMapping.Key, "CDN_TOKEN") + } + + apiMapping := devCfg.Mapping["API_KEY"] + if apiMapping.Path != "myapp" { + t.Errorf("API_KEY.Path = %q, want %q", apiMapping.Path, "myapp") + } + if apiMapping.Key != "API_SECRET_KEY" { + t.Errorf("API_KEY.Key = %q, want %q", apiMapping.Key, "API_SECRET_KEY") + } + + // Verify it propagates to provider config + providerCfg := devCfg.ToProviderConfig() + if len(providerCfg.Mapping) != 2 { + t.Fatalf("provider EnvConfig.Mapping = %d, want 2", len(providerCfg.Mapping)) + } +} + +func TestLoadProjectConfigMappingAndPathPrefixConflict(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, ".envmap.yaml") + + content := ` +project: testapp +default_env: dev +envs: + dev: + provider: vault + path_prefix: /some/prefix + mapping: + SOME_VAR: + path: some/path + key: SOME_KEY +` + if err := os.WriteFile(cfgPath, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + _, err := LoadProjectConfig(cfgPath) + if err == nil { + t.Fatal("expected error when both mapping and path_prefix are set") + } +} + func TestLoadProjectConfigValidation(t *testing.T) { tests := []struct { name string diff --git a/env.go b/env.go index f781caf..cb85bce 100644 --- a/env.go +++ b/env.go @@ -65,6 +65,15 @@ func CollectEnvWithMetadata(ctx context.Context, projectCfg ProjectConfig, globa if err != nil { return nil, err } + + if len(envCfg.Mapping) > 0 { + mkp, ok := p.(provider.MultiKeyProvider) + if !ok { + return nil, fmt.Errorf("provider %s does not support multi-key secret reading (required by mapping)", envCfg.GetProvider()) + } + return provider.CollectMappedSecrets(ctx, mkp, envCfg.Mapping) + } + return provider.ListOrDescribe(ctx, p, provider.ResolvedPrefix(envCfg.ToProviderConfig())) } @@ -77,6 +86,23 @@ func FetchSecret(ctx context.Context, projectCfg ProjectConfig, globalCfg Global if err != nil { return "", err } + + if sm, ok := envCfg.Mapping[key]; ok { + mkp, ok := p.(provider.MultiKeyProvider) + if !ok { + return "", fmt.Errorf("provider %s does not support multi-key secret reading (required by mapping)", envCfg.GetProvider()) + } + data, err := mkp.ReadSecret(ctx, sm.Path) + if err != nil { + return "", err + } + val, ok := data[sm.Key] + if !ok { + return "", fmt.Errorf("key %q not found in secret at path %q", sm.Key, sm.Path) + } + return val, nil + } + return p.Get(ctx, provider.ApplyPrefix(envCfg.ToProviderConfig(), key)) } @@ -85,6 +111,9 @@ func WriteSecret(ctx context.Context, projectCfg ProjectConfig, globalCfg Global if !ok { return fmt.Errorf("env %q not found in project config", envName) } + if len(envCfg.Mapping) > 0 { + return fmt.Errorf("env %q uses mapping mode; secrets are read-only and managed externally", envName) + } p, err := NewProvider(envName, envCfg, globalCfg) if err != nil { return err @@ -97,6 +126,9 @@ func DeleteSecret(ctx context.Context, projectCfg ProjectConfig, globalCfg Globa if !ok { return fmt.Errorf("env %q not found in project config", envName) } + if len(envCfg.Mapping) > 0 { + return fmt.Errorf("env %q uses mapping mode; secrets are read-only and managed externally", envName) + } p, err := NewProvider(envName, envCfg, globalCfg) if err != nil { return err diff --git a/provider/config.go b/provider/config.go index d98fb91..c2b8a80 100644 --- a/provider/config.go +++ b/provider/config.go @@ -2,11 +2,18 @@ package provider import "strings" +// SecretMapping maps an env var to a specific key within a multi-key Vault secret. +type SecretMapping struct { + Path string `yaml:"path"` + Key string `yaml:"key"` +} + // EnvConfig represents the environment-specific configuration from the project file. type EnvConfig struct { - Provider string `yaml:"provider"` - PathPrefix string `yaml:"path_prefix"` - Prefix string `yaml:"prefix"` + Provider string `yaml:"provider"` + PathPrefix string `yaml:"path_prefix"` + Prefix string `yaml:"prefix"` + Mapping map[string]SecretMapping `yaml:"mapping,omitempty"` } // ProviderConfig represents the provider configuration from the global config file. diff --git a/provider/mapping.go b/provider/mapping.go new file mode 100644 index 0000000..a6229ac --- /dev/null +++ b/provider/mapping.go @@ -0,0 +1,40 @@ +package provider + +import ( + "context" + "fmt" +) + +// CollectMappedSecrets fetches secrets using the mapping configuration. +// It groups entries by path to minimize API calls, then extracts the specific +// key for each env var from the returned map. +func CollectMappedSecrets(ctx context.Context, p MultiKeyProvider, mapping map[string]SecretMapping) (map[string]SecretRecord, error) { + // Group env vars by Vault path to deduplicate reads. + type entry struct { + envVar string + key string + } + byPath := make(map[string][]entry) + for envVar, sm := range mapping { + byPath[sm.Path] = append(byPath[sm.Path], entry{envVar: envVar, key: sm.Key}) + } + + out := make(map[string]SecretRecord, len(mapping)) + + for path, entries := range byPath { + data, err := p.ReadSecret(ctx, path) + if err != nil { + return nil, fmt.Errorf("read secret at %s: %w", path, err) + } + + for _, e := range entries { + val, ok := data[e.key] + if !ok { + return nil, fmt.Errorf("key %q not found in secret at path %q (env var %s)", e.key, path, e.envVar) + } + out[e.envVar] = SecretRecord{Value: val} + } + } + + return out, nil +} diff --git a/provider/mapping_test.go b/provider/mapping_test.go new file mode 100644 index 0000000..c7fd279 --- /dev/null +++ b/provider/mapping_test.go @@ -0,0 +1,162 @@ +package provider + +import ( + "context" + "fmt" + "testing" +) + +// mockMultiKeyProvider implements MultiKeyProvider for testing. +type mockMultiKeyProvider struct { + secrets map[string]map[string]string +} + +func (m *mockMultiKeyProvider) ReadSecret(_ context.Context, path string) (map[string]string, error) { + data, ok := m.secrets[path] + if !ok { + return nil, fmt.Errorf("secret %s not found", path) + } + return data, nil +} + +func TestCollectMappedSecrets(t *testing.T) { + mock := &mockMultiKeyProvider{ + secrets: map[string]map[string]string{ + "shared/database": { + "DB_USER": "dbuser", + "DB_PASSWORD": "dbpass", + }, + "shared/cdn": { + "CDN_TOKEN": "tok123", + "CDN_HEADER": "hdr456", + }, + "myapp": { + "API_SECRET_KEY": "sk-abc", + }, + }, + } + + mapping := map[string]SecretMapping{ + "DB_USER": {Path: "shared/database", Key: "DB_USER"}, + "DB_PASSWORD": {Path: "shared/database", Key: "DB_PASSWORD"}, + "CDN_TOKEN": {Path: "shared/cdn", Key: "CDN_TOKEN"}, + "API_KEY": {Path: "myapp", Key: "API_SECRET_KEY"}, // env var differs from vault key + } + + records, err := CollectMappedSecrets(context.Background(), mock, mapping) + if err != nil { + t.Fatalf("CollectMappedSecrets: %v", err) + } + + expected := map[string]string{ + "DB_USER": "dbuser", + "DB_PASSWORD": "dbpass", + "CDN_TOKEN": "tok123", + "API_KEY": "sk-abc", + } + + if len(records) != len(expected) { + t.Fatalf("got %d records, want %d", len(records), len(expected)) + } + + for k, want := range expected { + rec, ok := records[k] + if !ok { + t.Errorf("missing key %q in results", k) + continue + } + if rec.Value != want { + t.Errorf("records[%q] value mismatch", k) + } + } +} + +func TestCollectMappedSecrets_MissingKey(t *testing.T) { + mock := &mockMultiKeyProvider{ + secrets: map[string]map[string]string{ + "shared/cdn": { + "CDN_TOKEN": "tok123", + }, + }, + } + + mapping := map[string]SecretMapping{ + "MISSING_VAR": {Path: "shared/cdn", Key: "NONEXISTENT_KEY"}, + } + + _, err := CollectMappedSecrets(context.Background(), mock, mapping) + if err == nil { + t.Fatal("expected error for missing key") + } +} + +func TestCollectMappedSecrets_MissingPath(t *testing.T) { + mock := &mockMultiKeyProvider{ + secrets: map[string]map[string]string{}, + } + + mapping := map[string]SecretMapping{ + "SOME_VAR": {Path: "nonexistent/path", Key: "SOME_KEY"}, + } + + _, err := CollectMappedSecrets(context.Background(), mock, mapping) + if err == nil { + t.Fatal("expected error for missing path") + } +} + +func TestCollectMappedSecrets_Empty(t *testing.T) { + mock := &mockMultiKeyProvider{ + secrets: map[string]map[string]string{}, + } + + records, err := CollectMappedSecrets(context.Background(), mock, map[string]SecretMapping{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(records) != 0 { + t.Fatalf("expected empty map, got %d entries", len(records)) + } +} + +func TestCollectMappedSecrets_DeduplicatesPaths(t *testing.T) { + callCount := 0 + mock := &countingMultiKeyProvider{ + secrets: map[string]map[string]string{ + "shared/database": { + "USER": "u", + "PASS": "p", + }, + }, + callCount: &callCount, + } + + mapping := map[string]SecretMapping{ + "DB_USER": {Path: "shared/database", Key: "USER"}, + "DB_PASS": {Path: "shared/database", Key: "PASS"}, + } + + _, err := CollectMappedSecrets(context.Background(), mock, mapping) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if callCount != 1 { + t.Errorf("ReadSecret called %d times, want 1 (should deduplicate by path)", callCount) + } +} + +// countingMultiKeyProvider tracks how many times ReadSecret is called. +type countingMultiKeyProvider struct { + secrets map[string]map[string]string + callCount *int +} + +func (m *countingMultiKeyProvider) ReadSecret(_ context.Context, path string) (map[string]string, error) { + *m.callCount++ + data, ok := m.secrets[path] + if !ok { + return nil, fmt.Errorf("secret %s not found", path) + } + return data, nil +} diff --git a/provider/provider.go b/provider/provider.go index 33cbfdd..cfa97b9 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -28,6 +28,13 @@ type Provider interface { Set(ctx context.Context, name, value string) error } +// MultiKeyProvider can read all key-value pairs from a single secret path. +// Providers that store multiple keys per secret (e.g., Vault KV) implement this +// to support the mapping configuration. +type MultiKeyProvider interface { + ReadSecret(ctx context.Context, path string) (map[string]string, error) +} + // Factory creates a Provider from configuration. type Factory func(envCfg EnvConfig, providerCfg ProviderConfig) (Provider, error) diff --git a/provider/provider_test.go b/provider/provider_test.go index bb44dd5..4fbc603 100644 --- a/provider/provider_test.go +++ b/provider/provider_test.go @@ -136,6 +136,16 @@ func TestApplyTrimRoundtrip(t *testing.T) { } } +func TestVaultImplementsMultiKeyProvider(t *testing.T) { + // Verify at compile time that vaultProvider implements MultiKeyProvider. + // We can't construct a real vault client without a server, but we can + // verify the interface is satisfied via a type assertion on nil. + var p interface{} = (*vaultProvider)(nil) + if _, ok := p.(MultiKeyProvider); !ok { + t.Error("vaultProvider does not implement MultiKeyProvider") + } +} + // mockProvider for testing factory registration pattern type mockProvider struct{} diff --git a/provider/vault.go b/provider/vault.go index c8641ba..84bfe2a 100644 --- a/provider/vault.go +++ b/provider/vault.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "path" "strings" vault "github.com/hashicorp/vault/api" @@ -71,33 +72,33 @@ func newVault(envCfg EnvConfig, providerCfg ProviderConfig) (Provider, error) { func (p *vaultProvider) secretPath(name string) string { prefixed := ApplyPrefix(p.envCfg, name) - return fmt.Sprintf("%s/data/%s", p.mount, prefixed) + return path.Join(p.mount, "data", prefixed) } func (p *vaultProvider) Get(ctx context.Context, name string) (string, error) { - path := p.secretPath(name) - secret, err := p.client.Logical().ReadWithContext(ctx, path) + sPath := p.secretPath(name) + secret, err := p.client.Logical().ReadWithContext(ctx, sPath) if err != nil { - return "", fmt.Errorf("vault get %s: %w", path, err) + return "", fmt.Errorf("vault get %s: %w", sPath, err) } if secret == nil || secret.Data == nil { - return "", fmt.Errorf("secret %s not found in vault", path) + return "", fmt.Errorf("secret %s not found in vault", sPath) } data, ok := secret.Data["data"].(map[string]interface{}) if !ok { - return "", fmt.Errorf("vault secret %s has unexpected format", path) + return "", fmt.Errorf("vault secret %s has unexpected format", sPath) } value, ok := data["value"].(string) if !ok { - return "", fmt.Errorf("vault secret %s missing 'value' field", path) + return "", fmt.Errorf("vault secret %s missing 'value' field", sPath) } return value, nil } func (p *vaultProvider) List(ctx context.Context, prefix string) (map[string]string, error) { - listPath := fmt.Sprintf("%s/metadata/%s", p.mount, ensurePrefixSlash(prefix)) + listPath := path.Join(p.mount, "metadata", ensurePrefixSlash(prefix)) secret, err := p.client.Logical().ListWithContext(ctx, listPath) if err != nil { return nil, fmt.Errorf("vault list %s: %w", listPath, err) @@ -129,16 +130,42 @@ func (p *vaultProvider) List(ctx context.Context, prefix string) (map[string]str return out, nil } +func (p *vaultProvider) ReadSecret(ctx context.Context, secretPath string) (map[string]string, error) { + fullPath := path.Join(p.mount, "data", secretPath) + secret, err := p.client.Logical().ReadWithContext(ctx, fullPath) + if err != nil { + return nil, fmt.Errorf("vault read %s: %w", fullPath, err) + } + if secret == nil || secret.Data == nil { + return nil, fmt.Errorf("secret %s not found in vault", fullPath) + } + + data, ok := secret.Data["data"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("vault secret %s has unexpected format", fullPath) + } + + out := make(map[string]string, len(data)) + for k, v := range data { + if s, ok := v.(string); ok { + out[k] = s + } else { + out[k] = fmt.Sprintf("%v", v) + } + } + return out, nil +} + func (p *vaultProvider) Set(ctx context.Context, name, value string) error { - path := p.secretPath(name) + sPath := p.secretPath(name) data := map[string]interface{}{ "data": map[string]interface{}{ "value": value, }, } - _, err := p.client.Logical().WriteWithContext(ctx, path, data) + _, err := p.client.Logical().WriteWithContext(ctx, sPath, data) if err != nil { - return fmt.Errorf("vault put %s: %w", path, err) + return fmt.Errorf("vault put %s: %w", sPath, err) } return nil }