diff --git a/internal/cmd/flags.go b/internal/cmd/flags.go index afd97f79..f0209b1d 100644 --- a/internal/cmd/flags.go +++ b/internal/cmd/flags.go @@ -17,41 +17,21 @@ // }) // } // -// # getDefault*() Flag Helper Pattern +// # Flag Defaults with Environment Variable Overrides // -// Each flag whose default value can be overridden by an environment variable has a -// corresponding getDefault*() helper function that follows this pattern: +// Flags whose defaults can be overridden by an environment variable use inline +// envutil.GetEnv* calls directly in the RegisterFlag block: // -// func getDefaultXxx() T { -// return envutil.GetEnvT("MCP_GATEWAY_XXX", defaultXxx) -// } -// -// Current helpers and their environment variables: +// cmd.Flags().StringVar(&myDir, "my-dir", envutil.GetEnvString("MY_DIR_ENV", config.DefaultMyDir), "...") // -// flags_logging.go getDefaultLogDir() → MCP_GATEWAY_LOG_DIR -// flags_logging.go getDefaultPayloadDir() → MCP_GATEWAY_PAYLOAD_DIR -// flags_logging.go getDefaultPayloadPathPrefix() → MCP_GATEWAY_PAYLOAD_PATH_PREFIX -// flags_logging.go getDefaultPayloadSizeThreshold() → MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD -// flags_difc.go getDefaultDIFCMode() → MCP_GATEWAY_GUARDS_MODE -// flags_difc.go getDefaultDIFCSinkServerIDs() → MCP_GATEWAY_GUARDS_SINK_SERVER_IDS -// flags_difc.go getDefaultGuardPolicyJSON() → MCP_GATEWAY_GUARD_POLICY_JSON -// flags_difc.go getDefaultAllowOnlyScopePublic() → MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC -// flags_difc.go getDefaultAllowOnlyOwner() → MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER -// flags_difc.go getDefaultAllowOnlyRepo() → MCP_GATEWAY_ALLOWONLY_SCOPE_REPO -// flags_difc.go getDefaultAllowOnlyMinIntegrity() → MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY -// flags_tracing.go getDefaultOTLPEndpoint() → OTEL_EXPORTER_OTLP_ENDPOINT -// flags_tracing.go getDefaultOTLPServiceName() → OTEL_SERVICE_NAME +// This keeps the env-var name co-located with the flag declaration. // -// This pattern is intentionally kept in individual feature files because: -// - Each helper names the specific environment variable it reads, making the -// coupling between flag and env var explicit and discoverable. -// - The one-liner wrappers are trivial and unlikely to diverge. -// - Go's type system (string/int/bool) prevents a single generic helper without -// sacrificing readability. +// Exception: getDefaultDIFCMode() in flags_difc.go is kept as a named helper +// because it contains validation logic beyond a simple env lookup. // // When adding a new flag with an environment variable override: -// 1. Add a defaultXxx constant and a getDefaultXxx() function in the feature file. -// 2. Add the new helper to the table above. +// 1. Use envutil.GetEnv* directly in the RegisterFlag call. +// 2. Document the environment variable in AGENTS.md and README.md. package cmd import "github.com/spf13/cobra" diff --git a/internal/cmd/flags_difc.go b/internal/cmd/flags_difc.go index 99c696fc..26c03ed7 100644 --- a/internal/cmd/flags_difc.go +++ b/internal/cmd/flags_difc.go @@ -13,11 +13,6 @@ import ( "github.com/spf13/cobra" ) -// DIFC flag defaults -const ( - defaultAllowOnlyMinIntegrity = "" -) - // DIFC flag variables var ( difcMode string @@ -32,12 +27,12 @@ var ( func init() { RegisterFlag(func(cmd *cobra.Command) { cmd.Flags().StringVar(&difcMode, "guards-mode", getDefaultDIFCMode(), "Guards enforcement mode: strict (deny violations), filter (remove denied tools), or propagate (auto-adjust agent labels on reads)") - cmd.Flags().StringVar(&difcSinkServerIDs, "guards-sink-server-ids", getDefaultDIFCSinkServerIDs(), "Comma-separated server IDs whose RPC JSONL logs should include agent secrecy/integrity tag snapshots") - cmd.Flags().StringVar(&guardPolicyJSON, "guard-policy-json", getDefaultGuardPolicyJSON(), "Guard policy JSON (e.g. {\"allow-only\":{\"repos\":\"public\",\"min-integrity\":\"none\"}})") - cmd.Flags().BoolVar(&allowOnlyPublic, "allowonly-scope-public", getDefaultAllowOnlyScopePublic(), "Use public AllowOnly scope") - cmd.Flags().StringVar(&allowOnlyOwner, "allowonly-scope-owner", getDefaultAllowOnlyOwner(), "AllowOnly owner scope value") - cmd.Flags().StringVar(&allowOnlyRepo, "allowonly-scope-repo", getDefaultAllowOnlyRepo(), "AllowOnly repo name (requires owner)") - cmd.Flags().StringVar(&allowOnlyMinInt, "allowonly-min-integrity", getDefaultAllowOnlyMinIntegrity(), "AllowOnly integrity: none|unapproved|approved|merged") + cmd.Flags().StringVar(&difcSinkServerIDs, "guards-sink-server-ids", envutil.GetEnvString("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS", ""), "Comma-separated server IDs whose RPC JSONL logs should include agent secrecy/integrity tag snapshots") + cmd.Flags().StringVar(&guardPolicyJSON, "guard-policy-json", envutil.GetEnvString("MCP_GATEWAY_GUARD_POLICY_JSON", ""), "Guard policy JSON (e.g. {\"allow-only\":{\"repos\":\"public\",\"min-integrity\":\"none\"}})") + cmd.Flags().BoolVar(&allowOnlyPublic, "allowonly-scope-public", envutil.GetEnvBool("MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC", false), "Use public AllowOnly scope") + cmd.Flags().StringVar(&allowOnlyOwner, "allowonly-scope-owner", envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER", ""), "AllowOnly owner scope value") + cmd.Flags().StringVar(&allowOnlyRepo, "allowonly-scope-repo", envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_SCOPE_REPO", ""), "AllowOnly repo name (requires owner)") + cmd.Flags().StringVar(&allowOnlyMinInt, "allowonly-min-integrity", envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY", ""), "AllowOnly integrity: none|unapproved|approved|merged") }) } @@ -55,40 +50,6 @@ func getDefaultDIFCMode() string { return difc.ModeStrict } -func getDefaultAllowOnlyScopePublic() bool { - return envutil.GetEnvBool("MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC", false) -} - -// getDefaultDIFCSinkServerIDs returns the default DIFC sink server IDs string, -// checking MCP_GATEWAY_GUARDS_SINK_SERVER_IDS environment variable. -func getDefaultDIFCSinkServerIDs() string { - return envutil.GetEnvString("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS", "") -} - -// getDefaultGuardPolicyJSON returns the default guard policy JSON string, -// checking MCP_GATEWAY_GUARD_POLICY_JSON environment variable. -func getDefaultGuardPolicyJSON() string { - return envutil.GetEnvString("MCP_GATEWAY_GUARD_POLICY_JSON", "") -} - -// getDefaultAllowOnlyOwner returns the default AllowOnly owner scope, -// checking MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER environment variable. -func getDefaultAllowOnlyOwner() string { - return envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER", "") -} - -// getDefaultAllowOnlyRepo returns the default AllowOnly repo name, -// checking MCP_GATEWAY_ALLOWONLY_SCOPE_REPO environment variable. -func getDefaultAllowOnlyRepo() string { - return envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_SCOPE_REPO", "") -} - -// getDefaultAllowOnlyMinIntegrity returns the default AllowOnly minimum integrity level, -// checking MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY environment variable. -func getDefaultAllowOnlyMinIntegrity() string { - return envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY", defaultAllowOnlyMinIntegrity) -} - func parseDIFCSinkServerIDs(input string) ([]string, error) { if strings.TrimSpace(input) == "" { return nil, nil diff --git a/internal/cmd/flags_difc_test.go b/internal/cmd/flags_difc_test.go index d4684c5c..efb8c14a 100644 --- a/internal/cmd/flags_difc_test.go +++ b/internal/cmd/flags_difc_test.go @@ -147,46 +147,6 @@ func TestValidDIFCModes(t *testing.T) { require.Len(difc.ValidModes, 3, "should only have 3 valid modes") } -func TestGetDefaultDIFCSinkServerIDs(t *testing.T) { - tests := []struct { - name string - envValue string - setEnv bool - expected string - }{ - { - name: "no env var - returns empty string", - setEnv: false, - expected: "", - }, - { - name: "env var set - returns value", - envValue: "safeoutputs,github", - setEnv: true, - expected: "safeoutputs,github", - }, - { - name: "empty env var - returns empty string", - envValue: "", - setEnv: true, - expected: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.setEnv { - t.Setenv("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS", tt.envValue) - } else { - t.Setenv("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS", "") - } - - result := getDefaultDIFCSinkServerIDs() - assert.Equal(t, tt.expected, result) - }) - } -} - func TestParseDIFCSinkServerIDs(t *testing.T) { tests := []struct { name string @@ -270,17 +230,3 @@ func TestBuildAllowOnlyPolicy(t *testing.T) { require.Error(t, err) }) } - -func TestGetDefaultGuardPolicyInputs(t *testing.T) { - t.Setenv("MCP_GATEWAY_GUARD_POLICY_JSON", `{"allow-only":{"repos":"public","min-integrity":"none"}}`) - t.Setenv("MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC", "1") - t.Setenv("MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER", "lpcox") - t.Setenv("MCP_GATEWAY_ALLOWONLY_SCOPE_REPO", "gh-aw-mcpg") - t.Setenv("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY", "unapproved") - - assert.NotEmpty(t, getDefaultGuardPolicyJSON()) - assert.True(t, getDefaultAllowOnlyScopePublic()) - assert.Equal(t, "lpcox", getDefaultAllowOnlyOwner()) - assert.Equal(t, "gh-aw-mcpg", getDefaultAllowOnlyRepo()) - assert.Equal(t, "unapproved", getDefaultAllowOnlyMinIntegrity()) -} diff --git a/internal/cmd/flags_logging.go b/internal/cmd/flags_logging.go index dc767bdf..3e41485f 100644 --- a/internal/cmd/flags_logging.go +++ b/internal/cmd/flags_logging.go @@ -8,11 +8,6 @@ import ( "github.com/spf13/cobra" ) -// Logging flag defaults -const ( - defaultPayloadPathPrefix = "" // Empty by default - use actual filesystem path -) - // Logging flag variables var ( logDir string @@ -23,33 +18,9 @@ var ( func init() { RegisterFlag(func(cmd *cobra.Command) { - cmd.Flags().StringVar(&logDir, "log-dir", getDefaultLogDir(), "Directory for log files (falls back to stdout if directory cannot be created)") - cmd.Flags().StringVar(&payloadDir, "payload-dir", getDefaultPayloadDir(), "Directory for storing large payload files (segmented by session ID)") - cmd.Flags().StringVar(&payloadPathPrefix, "payload-path-prefix", getDefaultPayloadPathPrefix(), "Path prefix to use when returning payloadPath to clients (allows remapping host paths to client/agent container paths)") - cmd.Flags().IntVar(&payloadSizeThreshold, "payload-size-threshold", getDefaultPayloadSizeThreshold(), "Size threshold (in bytes) for storing payloads to disk. Payloads larger than this are stored, smaller ones returned inline") + cmd.Flags().StringVar(&logDir, "log-dir", envutil.GetEnvString("MCP_GATEWAY_LOG_DIR", config.DefaultLogDir), "Directory for log files (falls back to stdout if directory cannot be created)") + cmd.Flags().StringVar(&payloadDir, "payload-dir", envutil.GetEnvString("MCP_GATEWAY_PAYLOAD_DIR", config.DefaultPayloadDir), "Directory for storing large payload files (segmented by session ID)") + cmd.Flags().StringVar(&payloadPathPrefix, "payload-path-prefix", envutil.GetEnvString("MCP_GATEWAY_PAYLOAD_PATH_PREFIX", ""), "Path prefix to use when returning payloadPath to clients (allows remapping host paths to client/agent container paths)") + cmd.Flags().IntVar(&payloadSizeThreshold, "payload-size-threshold", envutil.GetEnvInt("MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD", config.DefaultPayloadSizeThreshold), "Size threshold (in bytes) for storing payloads to disk. Payloads larger than this are stored, smaller ones returned inline") }) } - -// getDefaultLogDir returns the default log directory, checking MCP_GATEWAY_LOG_DIR -// environment variable first, then falling back to the hardcoded default -func getDefaultLogDir() string { - return envutil.GetEnvString("MCP_GATEWAY_LOG_DIR", config.DefaultLogDir) -} - -// getDefaultPayloadDir returns the default payload directory, checking MCP_GATEWAY_PAYLOAD_DIR -// environment variable first, then falling back to the hardcoded default -func getDefaultPayloadDir() string { - return envutil.GetEnvString("MCP_GATEWAY_PAYLOAD_DIR", config.DefaultPayloadDir) -} - -// getDefaultPayloadPathPrefix returns the default payload path prefix, checking MCP_GATEWAY_PAYLOAD_PATH_PREFIX -// environment variable first, then falling back to the hardcoded default -func getDefaultPayloadPathPrefix() string { - return envutil.GetEnvString("MCP_GATEWAY_PAYLOAD_PATH_PREFIX", defaultPayloadPathPrefix) -} - -// getDefaultPayloadSizeThreshold returns the default payload size threshold, checking -// MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD environment variable first, then falling back to the hardcoded default -func getDefaultPayloadSizeThreshold() int { - return envutil.GetEnvInt("MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD", config.DefaultPayloadSizeThreshold) -} diff --git a/internal/cmd/flags_logging_test.go b/internal/cmd/flags_logging_test.go deleted file mode 100644 index 68210a67..00000000 --- a/internal/cmd/flags_logging_test.go +++ /dev/null @@ -1,206 +0,0 @@ -package cmd - -import ( - "testing" - - "github.com/github/gh-aw-mcpg/internal/config" - "github.com/stretchr/testify/assert" -) - -func TestGetDefaultLogDir(t *testing.T) { - tests := []struct { - name string - envValue string - setEnv bool - expected string - }{ - { - name: "no env var - returns default", - setEnv: false, - expected: config.DefaultLogDir, - }, - { - name: "env var set - returns custom path", - envValue: "/custom/log/dir", - setEnv: true, - expected: "/custom/log/dir", - }, - { - name: "empty env var - returns default", - envValue: "", - setEnv: true, - expected: config.DefaultLogDir, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.setEnv { - t.Setenv("MCP_GATEWAY_LOG_DIR", tt.envValue) - } else { - t.Setenv("MCP_GATEWAY_LOG_DIR", "") - } - - result := getDefaultLogDir() - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestGetDefaultPayloadDir(t *testing.T) { - tests := []struct { - name string - envValue string - setEnv bool - expected string - }{ - { - name: "no env var - returns default", - setEnv: false, - expected: config.DefaultPayloadDir, - }, - { - name: "env var set - returns custom path", - envValue: "/custom/payload/dir", - setEnv: true, - expected: "/custom/payload/dir", - }, - { - name: "empty env var - returns default", - envValue: "", - setEnv: true, - expected: config.DefaultPayloadDir, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.setEnv { - t.Setenv("MCP_GATEWAY_PAYLOAD_DIR", tt.envValue) - } else { - t.Setenv("MCP_GATEWAY_PAYLOAD_DIR", "") - } - - result := getDefaultPayloadDir() - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestGetDefaultPayloadSizeThreshold(t *testing.T) { - tests := []struct { - name string - envValue string - setEnv bool - expected int - }{ - { - name: "no env var - returns default", - setEnv: false, - expected: config.DefaultPayloadSizeThreshold, - }, - { - name: "valid env var", - envValue: "2048", - setEnv: true, - expected: 2048, - }, - { - name: "very large threshold", - envValue: "10240", - setEnv: true, - expected: 10240, - }, - { - name: "small threshold", - envValue: "512", - setEnv: true, - expected: 512, - }, - { - name: "invalid value - non-numeric", - envValue: "invalid", - setEnv: true, - expected: config.DefaultPayloadSizeThreshold, - }, - { - name: "invalid value - negative", - envValue: "-100", - setEnv: true, - expected: config.DefaultPayloadSizeThreshold, - }, - { - name: "invalid value - zero", - envValue: "0", - setEnv: true, - expected: config.DefaultPayloadSizeThreshold, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.setEnv { - t.Setenv("MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD", tt.envValue) - } else { - t.Setenv("MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD", "") - } - - result := getDefaultPayloadSizeThreshold() - assert.Equal(t, tt.expected, result, "Threshold should match expected value") - }) - } -} - -func TestPayloadSizeThresholdFlagDefault(t *testing.T) { - t.Setenv("MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD", "") - - result := getDefaultPayloadSizeThreshold() - assert.Equal(t, 524288, result, "Default should be 524288 bytes (512KB)") -} - -func TestPayloadSizeThresholdEnvVar(t *testing.T) { - t.Setenv("MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD", "4096") - - result := getDefaultPayloadSizeThreshold() - assert.Equal(t, 4096, result, "Environment variable should override default") -} - -func TestGetDefaultPayloadPathPrefix(t *testing.T) { - tests := []struct { - name string - envValue string - setEnv bool - expected string - }{ - { - name: "no env var - returns default", - setEnv: false, - expected: defaultPayloadPathPrefix, - }, - { - name: "env var set - returns custom path", - envValue: "/workspace/payloads", - setEnv: true, - expected: "/workspace/payloads", - }, - { - name: "empty env var - returns default", - envValue: "", - setEnv: true, - expected: defaultPayloadPathPrefix, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.setEnv { - t.Setenv("MCP_GATEWAY_PAYLOAD_PATH_PREFIX", tt.envValue) - } else { - t.Setenv("MCP_GATEWAY_PAYLOAD_PATH_PREFIX", "") - } - - result := getDefaultPayloadPathPrefix() - assert.Equal(t, tt.expected, result) - }) - } -} diff --git a/internal/cmd/flags_tracing.go b/internal/cmd/flags_tracing.go index 69270fbe..bd90b8db 100644 --- a/internal/cmd/flags_tracing.go +++ b/internal/cmd/flags_tracing.go @@ -17,23 +17,11 @@ var ( func init() { RegisterFlag(func(cmd *cobra.Command) { - cmd.Flags().StringVar(&otlpEndpoint, "otlp-endpoint", getDefaultOTLPEndpoint(), + cmd.Flags().StringVar(&otlpEndpoint, "otlp-endpoint", envutil.GetEnvString("OTEL_EXPORTER_OTLP_ENDPOINT", ""), "OTLP HTTP endpoint for trace export (e.g. http://localhost:4318). Defaults from OTEL_EXPORTER_OTLP_ENDPOINT when set. Tracing is disabled when empty.") - cmd.Flags().StringVar(&otlpServiceName, "otlp-service-name", getDefaultOTLPServiceName(), + cmd.Flags().StringVar(&otlpServiceName, "otlp-service-name", envutil.GetEnvString("OTEL_SERVICE_NAME", config.DefaultTracingServiceName), "Service name reported in traces. Defaults from OTEL_SERVICE_NAME when set.") cmd.Flags().Float64Var(&otlpSampleRate, "otlp-sample-rate", config.DefaultTracingSampleRate, "Fraction of traces to sample and export (0.0–1.0). Default 1.0 samples everything.") }) } - -// getDefaultOTLPEndpoint returns the OTLP endpoint, checking OTEL_EXPORTER_OTLP_ENDPOINT -// environment variable first, then falling back to empty (disabled). -func getDefaultOTLPEndpoint() string { - return envutil.GetEnvString("OTEL_EXPORTER_OTLP_ENDPOINT", "") -} - -// getDefaultOTLPServiceName returns the OTLP service name, checking OTEL_SERVICE_NAME -// environment variable first, then falling back to the default. -func getDefaultOTLPServiceName() string { - return envutil.GetEnvString("OTEL_SERVICE_NAME", config.DefaultTracingServiceName) -} diff --git a/internal/cmd/flags_tracing_test.go b/internal/cmd/flags_tracing_test.go index 03253088..a504dc5f 100644 --- a/internal/cmd/flags_tracing_test.go +++ b/internal/cmd/flags_tracing_test.go @@ -7,111 +7,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestGetDefaultOTLPEndpoint(t *testing.T) { - tests := []struct { - name string - envValue string - setEnv bool - expected string - }{ - { - name: "no env var - returns empty string (tracing disabled)", - setEnv: false, - expected: "", - }, - { - name: "env var set to HTTP endpoint", - envValue: "http://localhost:4318", - setEnv: true, - expected: "http://localhost:4318", - }, - { - name: "env var set to HTTPS endpoint", - envValue: "https://otel.example.com:4318", - setEnv: true, - expected: "https://otel.example.com:4318", - }, - { - name: "empty env var - returns empty string", - envValue: "", - setEnv: true, - expected: "", - }, - { - name: "env var with trailing slash", - envValue: "http://localhost:4318/", - setEnv: true, - expected: "http://localhost:4318/", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.setEnv { - t.Setenv("OTEL_EXPORTER_OTLP_ENDPOINT", tt.envValue) - } else { - t.Setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "") - } - - result := getDefaultOTLPEndpoint() - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestGetDefaultOTLPServiceName(t *testing.T) { - tests := []struct { - name string - envValue string - setEnv bool - expected string - }{ - { - name: "no env var - returns default service name", - setEnv: false, - expected: config.DefaultTracingServiceName, - }, - { - name: "env var set to custom service name", - envValue: "my-custom-service", - setEnv: true, - expected: "my-custom-service", - }, - { - name: "env var set to explicit mcp-gateway", - envValue: "mcp-gateway", - setEnv: true, - expected: "mcp-gateway", - }, - { - name: "empty env var - returns default", - envValue: "", - setEnv: true, - expected: config.DefaultTracingServiceName, - }, - { - name: "env var with spaces in service name", - envValue: "my service", - setEnv: true, - expected: "my service", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.setEnv { - t.Setenv("OTEL_SERVICE_NAME", tt.envValue) - } else { - t.Setenv("OTEL_SERVICE_NAME", "") - } - - result := getDefaultOTLPServiceName() - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestGetDefaultOTLPServiceName_DefaultIsCorrect(t *testing.T) { +func TestDefaultTracingServiceName_IsCorrect(t *testing.T) { // Verify the default constant value hasn't changed unexpectedly. // "mcp-gateway" is the canonical service name used in OTLP traces. assert.Equal(t, "mcp-gateway", config.DefaultTracingServiceName, diff --git a/internal/cmd/proxy.go b/internal/cmd/proxy.go index 1e11c2a6..354074a8 100644 --- a/internal/cmd/proxy.go +++ b/internal/cmd/proxy.go @@ -105,16 +105,16 @@ Local usage: cmd.Flags().StringVar(&proxyPolicy, "policy", os.Getenv("MCP_GATEWAY_GUARD_POLICY_JSON"), "Guard policy JSON") cmd.Flags().StringVar(&proxyToken, "github-token", "", "Fallback GitHub API token (default: forwards client Authorization header)") cmd.Flags().StringVarP(&proxyListen, "listen", "l", "127.0.0.1:8080", "Proxy listen address") - cmd.Flags().StringVar(&proxyLogDir, "log-dir", getDefaultLogDir(), "Log file directory") + cmd.Flags().StringVar(&proxyLogDir, "log-dir", envutil.GetEnvString("MCP_GATEWAY_LOG_DIR", config.DefaultLogDir), "Log file directory") cmd.Flags().StringVar(&proxyDIFCMode, "guards-mode", "filter", "DIFC enforcement mode: strict, filter, propagate") cmd.Flags().StringVar(&proxyAPIURL, "github-api-url", "", "Upstream GitHub API URL (default: auto-derived from GITHUB_API_URL or GITHUB_SERVER_URL, falls back to https://api.github.com)") cmd.Flags().BoolVar(&proxyTLS, "tls", false, "Enable HTTPS with auto-generated self-signed certificates") cmd.Flags().StringVar(&proxyTLSDir, "tls-dir", "", "Directory for TLS certificates (default: /proxy-tls)") cmd.Flags().StringSliceVar(&proxyTrustedBots, "trusted-bots", nil, "Additional trusted bot usernames (comma-separated, extends built-in list)") cmd.Flags().StringSliceVar(&proxyTrustedUsers, "trusted-users", nil, "User logins that receive approved integrity (comma-separated)") - cmd.Flags().StringVar(&proxyOTLPEndpoint, "otlp-endpoint", getDefaultOTLPEndpoint(), + cmd.Flags().StringVar(&proxyOTLPEndpoint, "otlp-endpoint", envutil.GetEnvString("OTEL_EXPORTER_OTLP_ENDPOINT", ""), "OTLP HTTP endpoint for trace export (e.g. http://localhost:4318). Tracing is disabled when empty.") - cmd.Flags().StringVar(&proxyOTLPService, "otlp-service-name", getDefaultOTLPServiceName(), + cmd.Flags().StringVar(&proxyOTLPService, "otlp-service-name", envutil.GetEnvString("OTEL_SERVICE_NAME", config.DefaultTracingServiceName), "Service name reported in traces.") cmd.Flags().Float64Var(&proxyOTLPSampleRate, "otlp-sample-rate", config.DefaultTracingSampleRate, "Fraction of traces to sample and export (0.0–1.0).") diff --git a/internal/cmd/proxy_test.go b/internal/cmd/proxy_test.go index 4e6a4dbf..7a837941 100644 --- a/internal/cmd/proxy_test.go +++ b/internal/cmd/proxy_test.go @@ -324,7 +324,7 @@ func TestNewProxyCmd_TrustedBotsAndUsersDefaultNil(t *testing.T) { assert.Empty(t, users, "--trusted-users should default to empty/nil") } -// TestNewProxyCmd_LogDirDefault verifies --log-dir uses getDefaultLogDir() as default. +// TestNewProxyCmd_LogDirDefault verifies --log-dir uses the default log directory. func TestNewProxyCmd_LogDirDefault(t *testing.T) { t.Setenv("MCP_GATEWAY_LOG_DIR", "") @@ -333,8 +333,8 @@ func TestNewProxyCmd_LogDirDefault(t *testing.T) { val, err := cmd.Flags().GetString("log-dir") require.NoError(t, err) - assert.Equal(t, getDefaultLogDir(), val, - "--log-dir should use getDefaultLogDir() as its default value") + assert.Equal(t, config.DefaultLogDir, val, + "--log-dir should default to config.DefaultLogDir when MCP_GATEWAY_LOG_DIR is unset") } // TestNewProxyCmd_ListenFlag verifies --listen, -l shorthand and default value. diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 51f9101a..96557761 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -18,6 +18,7 @@ import ( "github.com/github/gh-aw-mcpg/internal/auth" "github.com/github/gh-aw-mcpg/internal/config" "github.com/github/gh-aw-mcpg/internal/difc" + "github.com/github/gh-aw-mcpg/internal/envutil" "github.com/github/gh-aw-mcpg/internal/logger" "github.com/github/gh-aw-mcpg/internal/logger/sanitize" "github.com/github/gh-aw-mcpg/internal/server" @@ -241,11 +242,11 @@ func run(cmd *cobra.Command, args []string) error { cfg.Gateway.PayloadDir = payloadDir } - // Apply payload path prefix flag (if different from default, it was explicitly set) + // Apply payload path prefix: CLI flag takes priority, then env-derived non-empty value. if cmd.Flags().Changed("payload-path-prefix") { cfg.Gateway.PayloadPathPrefix = payloadPathPrefix - } else if payloadPathPrefix != "" && payloadPathPrefix != defaultPayloadPathPrefix { - // Environment variable was set + } else if payloadPathPrefix != "" { + // envutil.GetEnvString returned a non-empty value from MCP_GATEWAY_PAYLOAD_PATH_PREFIX cfg.Gateway.PayloadPathPrefix = payloadPathPrefix } @@ -447,7 +448,7 @@ func resolveGuardPolicyOverride(cmd *cobra.Command) (*config.GuardPolicy, string if hasScopePublic || hasScopeOwner || hasScopeRepo || hasMinIntegrity { policy, err := config.BuildAllowOnlyPolicy( - getDefaultAllowOnlyScopePublic(), + envutil.GetEnvBool("MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC", false), os.Getenv("MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER"), os.Getenv("MCP_GATEWAY_ALLOWONLY_SCOPE_REPO"), os.Getenv("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY"),