Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 9 additions & 29 deletions internal/cmd/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,41 +17,21 @@
// })
// }
//
// # getDefault*() Flag Helper Pattern
// # Flag Defaults with Environment Variable Overrides
//
// Each flag whose default value can be overridden by an environment variable has a
// corresponding getDefault*() helper function that follows this pattern:
// Flags whose defaults can be overridden by an environment variable use inline
// envutil.GetEnv* calls directly in the RegisterFlag block:
//
// func getDefaultXxx() T {
// return envutil.GetEnvT("MCP_GATEWAY_XXX", defaultXxx)
// }
//
// Current helpers and their environment variables:
// cmd.Flags().StringVar(&myDir, "my-dir", envutil.GetEnvString("MY_DIR_ENV", config.DefaultMyDir), "...")
//
// flags_logging.go getDefaultLogDir() → MCP_GATEWAY_LOG_DIR
// flags_logging.go getDefaultPayloadDir() → MCP_GATEWAY_PAYLOAD_DIR
// flags_logging.go getDefaultPayloadPathPrefix() → MCP_GATEWAY_PAYLOAD_PATH_PREFIX
// flags_logging.go getDefaultPayloadSizeThreshold() → MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD
// flags_difc.go getDefaultDIFCMode() → MCP_GATEWAY_GUARDS_MODE
// flags_difc.go getDefaultDIFCSinkServerIDs() → MCP_GATEWAY_GUARDS_SINK_SERVER_IDS
// flags_difc.go getDefaultGuardPolicyJSON() → MCP_GATEWAY_GUARD_POLICY_JSON
// flags_difc.go getDefaultAllowOnlyScopePublic() → MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC
// flags_difc.go getDefaultAllowOnlyOwner() → MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER
// flags_difc.go getDefaultAllowOnlyRepo() → MCP_GATEWAY_ALLOWONLY_SCOPE_REPO
// flags_difc.go getDefaultAllowOnlyMinIntegrity() → MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY
// flags_tracing.go getDefaultOTLPEndpoint() → OTEL_EXPORTER_OTLP_ENDPOINT
// flags_tracing.go getDefaultOTLPServiceName() → OTEL_SERVICE_NAME
// This keeps the env-var name co-located with the flag declaration.
//
// This pattern is intentionally kept in individual feature files because:
// - Each helper names the specific environment variable it reads, making the
// coupling between flag and env var explicit and discoverable.
// - The one-liner wrappers are trivial and unlikely to diverge.
// - Go's type system (string/int/bool) prevents a single generic helper without
// sacrificing readability.
// Exception: getDefaultDIFCMode() in flags_difc.go is kept as a named helper
// because it contains validation logic beyond a simple env lookup.
//
// When adding a new flag with an environment variable override:
// 1. Add a defaultXxx constant and a getDefaultXxx() function in the feature file.
// 2. Add the new helper to the table above.
// 1. Use envutil.GetEnv* directly in the RegisterFlag call.
// 2. Document the environment variable in AGENTS.md and README.md.
package cmd

import "github.com/spf13/cobra"
Expand Down
51 changes: 6 additions & 45 deletions internal/cmd/flags_difc.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@ import (
"github.com/spf13/cobra"
)

// DIFC flag defaults
const (
defaultAllowOnlyMinIntegrity = ""
)

// DIFC flag variables
var (
difcMode string
Expand All @@ -32,12 +27,12 @@ var (
func init() {
RegisterFlag(func(cmd *cobra.Command) {
cmd.Flags().StringVar(&difcMode, "guards-mode", getDefaultDIFCMode(), "Guards enforcement mode: strict (deny violations), filter (remove denied tools), or propagate (auto-adjust agent labels on reads)")
cmd.Flags().StringVar(&difcSinkServerIDs, "guards-sink-server-ids", getDefaultDIFCSinkServerIDs(), "Comma-separated server IDs whose RPC JSONL logs should include agent secrecy/integrity tag snapshots")
cmd.Flags().StringVar(&guardPolicyJSON, "guard-policy-json", getDefaultGuardPolicyJSON(), "Guard policy JSON (e.g. {\"allow-only\":{\"repos\":\"public\",\"min-integrity\":\"none\"}})")
cmd.Flags().BoolVar(&allowOnlyPublic, "allowonly-scope-public", getDefaultAllowOnlyScopePublic(), "Use public AllowOnly scope")
cmd.Flags().StringVar(&allowOnlyOwner, "allowonly-scope-owner", getDefaultAllowOnlyOwner(), "AllowOnly owner scope value")
cmd.Flags().StringVar(&allowOnlyRepo, "allowonly-scope-repo", getDefaultAllowOnlyRepo(), "AllowOnly repo name (requires owner)")
cmd.Flags().StringVar(&allowOnlyMinInt, "allowonly-min-integrity", getDefaultAllowOnlyMinIntegrity(), "AllowOnly integrity: none|unapproved|approved|merged")
cmd.Flags().StringVar(&difcSinkServerIDs, "guards-sink-server-ids", envutil.GetEnvString("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS", ""), "Comma-separated server IDs whose RPC JSONL logs should include agent secrecy/integrity tag snapshots")
cmd.Flags().StringVar(&guardPolicyJSON, "guard-policy-json", envutil.GetEnvString("MCP_GATEWAY_GUARD_POLICY_JSON", ""), "Guard policy JSON (e.g. {\"allow-only\":{\"repos\":\"public\",\"min-integrity\":\"none\"}})")
cmd.Flags().BoolVar(&allowOnlyPublic, "allowonly-scope-public", envutil.GetEnvBool("MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC", false), "Use public AllowOnly scope")
cmd.Flags().StringVar(&allowOnlyOwner, "allowonly-scope-owner", envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER", ""), "AllowOnly owner scope value")
cmd.Flags().StringVar(&allowOnlyRepo, "allowonly-scope-repo", envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_SCOPE_REPO", ""), "AllowOnly repo name (requires owner)")
cmd.Flags().StringVar(&allowOnlyMinInt, "allowonly-min-integrity", envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY", ""), "AllowOnly integrity: none|unapproved|approved|merged")
})
}

Expand All @@ -55,40 +50,6 @@ func getDefaultDIFCMode() string {
return difc.ModeStrict
}

func getDefaultAllowOnlyScopePublic() bool {
return envutil.GetEnvBool("MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC", false)
}

// getDefaultDIFCSinkServerIDs returns the default DIFC sink server IDs string,
// checking MCP_GATEWAY_GUARDS_SINK_SERVER_IDS environment variable.
func getDefaultDIFCSinkServerIDs() string {
return envutil.GetEnvString("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS", "")
}

// getDefaultGuardPolicyJSON returns the default guard policy JSON string,
// checking MCP_GATEWAY_GUARD_POLICY_JSON environment variable.
func getDefaultGuardPolicyJSON() string {
return envutil.GetEnvString("MCP_GATEWAY_GUARD_POLICY_JSON", "")
}

// getDefaultAllowOnlyOwner returns the default AllowOnly owner scope,
// checking MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER environment variable.
func getDefaultAllowOnlyOwner() string {
return envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER", "")
}

// getDefaultAllowOnlyRepo returns the default AllowOnly repo name,
// checking MCP_GATEWAY_ALLOWONLY_SCOPE_REPO environment variable.
func getDefaultAllowOnlyRepo() string {
return envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_SCOPE_REPO", "")
}

// getDefaultAllowOnlyMinIntegrity returns the default AllowOnly minimum integrity level,
// checking MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY environment variable.
func getDefaultAllowOnlyMinIntegrity() string {
return envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY", defaultAllowOnlyMinIntegrity)
}

func parseDIFCSinkServerIDs(input string) ([]string, error) {
if strings.TrimSpace(input) == "" {
return nil, nil
Expand Down
54 changes: 0 additions & 54 deletions internal/cmd/flags_difc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,46 +147,6 @@ func TestValidDIFCModes(t *testing.T) {
require.Len(difc.ValidModes, 3, "should only have 3 valid modes")
}

func TestGetDefaultDIFCSinkServerIDs(t *testing.T) {
tests := []struct {
name string
envValue string
setEnv bool
expected string
}{
{
name: "no env var - returns empty string",
setEnv: false,
expected: "",
},
{
name: "env var set - returns value",
envValue: "safeoutputs,github",
setEnv: true,
expected: "safeoutputs,github",
},
{
name: "empty env var - returns empty string",
envValue: "",
setEnv: true,
expected: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setEnv {
t.Setenv("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS", tt.envValue)
} else {
t.Setenv("MCP_GATEWAY_GUARDS_SINK_SERVER_IDS", "")
}

result := getDefaultDIFCSinkServerIDs()
assert.Equal(t, tt.expected, result)
})
}
}

func TestParseDIFCSinkServerIDs(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -270,17 +230,3 @@ func TestBuildAllowOnlyPolicy(t *testing.T) {
require.Error(t, err)
})
}

func TestGetDefaultGuardPolicyInputs(t *testing.T) {
t.Setenv("MCP_GATEWAY_GUARD_POLICY_JSON", `{"allow-only":{"repos":"public","min-integrity":"none"}}`)
t.Setenv("MCP_GATEWAY_ALLOWONLY_SCOPE_PUBLIC", "1")
t.Setenv("MCP_GATEWAY_ALLOWONLY_SCOPE_OWNER", "lpcox")
t.Setenv("MCP_GATEWAY_ALLOWONLY_SCOPE_REPO", "gh-aw-mcpg")
t.Setenv("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY", "unapproved")

assert.NotEmpty(t, getDefaultGuardPolicyJSON())
assert.True(t, getDefaultAllowOnlyScopePublic())
assert.Equal(t, "lpcox", getDefaultAllowOnlyOwner())
assert.Equal(t, "gh-aw-mcpg", getDefaultAllowOnlyRepo())
assert.Equal(t, "unapproved", getDefaultAllowOnlyMinIntegrity())
}
37 changes: 4 additions & 33 deletions internal/cmd/flags_logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@ import (
"github.com/spf13/cobra"
)

// Logging flag defaults
const (
defaultPayloadPathPrefix = "" // Empty by default - use actual filesystem path
)

// Logging flag variables
var (
logDir string
Expand All @@ -23,33 +18,9 @@ var (

func init() {
RegisterFlag(func(cmd *cobra.Command) {
cmd.Flags().StringVar(&logDir, "log-dir", getDefaultLogDir(), "Directory for log files (falls back to stdout if directory cannot be created)")
cmd.Flags().StringVar(&payloadDir, "payload-dir", getDefaultPayloadDir(), "Directory for storing large payload files (segmented by session ID)")
cmd.Flags().StringVar(&payloadPathPrefix, "payload-path-prefix", getDefaultPayloadPathPrefix(), "Path prefix to use when returning payloadPath to clients (allows remapping host paths to client/agent container paths)")
cmd.Flags().IntVar(&payloadSizeThreshold, "payload-size-threshold", getDefaultPayloadSizeThreshold(), "Size threshold (in bytes) for storing payloads to disk. Payloads larger than this are stored, smaller ones returned inline")
cmd.Flags().StringVar(&logDir, "log-dir", envutil.GetEnvString("MCP_GATEWAY_LOG_DIR", config.DefaultLogDir), "Directory for log files (falls back to stdout if directory cannot be created)")
cmd.Flags().StringVar(&payloadDir, "payload-dir", envutil.GetEnvString("MCP_GATEWAY_PAYLOAD_DIR", config.DefaultPayloadDir), "Directory for storing large payload files (segmented by session ID)")
cmd.Flags().StringVar(&payloadPathPrefix, "payload-path-prefix", envutil.GetEnvString("MCP_GATEWAY_PAYLOAD_PATH_PREFIX", ""), "Path prefix to use when returning payloadPath to clients (allows remapping host paths to client/agent container paths)")
cmd.Flags().IntVar(&payloadSizeThreshold, "payload-size-threshold", envutil.GetEnvInt("MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD", config.DefaultPayloadSizeThreshold), "Size threshold (in bytes) for storing payloads to disk. Payloads larger than this are stored, smaller ones returned inline")
})
}

// getDefaultLogDir returns the default log directory, checking MCP_GATEWAY_LOG_DIR
// environment variable first, then falling back to the hardcoded default
func getDefaultLogDir() string {
return envutil.GetEnvString("MCP_GATEWAY_LOG_DIR", config.DefaultLogDir)
}

// getDefaultPayloadDir returns the default payload directory, checking MCP_GATEWAY_PAYLOAD_DIR
// environment variable first, then falling back to the hardcoded default
func getDefaultPayloadDir() string {
return envutil.GetEnvString("MCP_GATEWAY_PAYLOAD_DIR", config.DefaultPayloadDir)
}

// getDefaultPayloadPathPrefix returns the default payload path prefix, checking MCP_GATEWAY_PAYLOAD_PATH_PREFIX
// environment variable first, then falling back to the hardcoded default
func getDefaultPayloadPathPrefix() string {
return envutil.GetEnvString("MCP_GATEWAY_PAYLOAD_PATH_PREFIX", defaultPayloadPathPrefix)
}

// getDefaultPayloadSizeThreshold returns the default payload size threshold, checking
// MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD environment variable first, then falling back to the hardcoded default
func getDefaultPayloadSizeThreshold() int {
return envutil.GetEnvInt("MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD", config.DefaultPayloadSizeThreshold)
}
Loading
Loading