Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions cmd/kh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,20 @@ func main() {
return hosts.ActiveHost(flagHost, envHost)
}

// resolveOrgFlag returns the --org flag value if set, or empty string.
resolveOrgFlag := func() string {
if rootCmd != nil {
if f := rootCmd.PersistentFlags().Lookup("org"); f != nil {
return f.Value.String()
}
}
return ""
}

f := &cmdutil.Factory{
AppVersion: version.Version,
IOStreams: ios,
OrgID: resolveOrgFlag,
Config: func() (config.Config, error) {
return config.ReadConfig()
},
Expand All @@ -69,11 +80,12 @@ func main() {
entry, _ := hosts.HostEntry(activeHost)

return khhttp.NewClient(khhttp.ClientOptions{
Host: activeHost,
Token: resolved.Token,
Headers: entry.Headers,
IOStreams: ios,
AppVersion: version.Version,
Host: activeHost,
Token: resolved.Token,
Headers: entry.Headers,
OrgOverride: resolveOrgFlag(),
IOStreams: ios,
AppVersion: version.Version,
}), nil
},
}
Expand Down
1 change: 1 addition & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func NewRootCmd(f *cmdutil.Factory) *cobra.Command {
cmd.PersistentFlags().BoolP("yes", "y", false, "Skip confirmation prompts")
cmd.PersistentFlags().Bool("no-color", false, "Disable color output")
cmd.PersistentFlags().StringP("host", "H", "", "KeeperHub host (default: app.keeperhub.com)")
cmd.PersistentFlags().String("org", "", "Organization ID to use (overrides default from auth)")

cmd.AddCommand(action.NewActionCmd(f))
cmd.AddCommand(auth.NewAuthCmd(f))
Expand Down
20 changes: 20 additions & 0 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,26 @@ func TestNewRootCmdHasHostFlag(t *testing.T) {
assert.Equal(t, "string", flag.Value.Type())
}

func TestNewRootCmdHasOrgFlag(t *testing.T) {
f := newTestFactory()
root := cmd.NewRootCmd(f)
flag := root.PersistentFlags().Lookup("org")
require.NotNil(t, flag)
assert.Equal(t, "string", flag.Value.Type())
}

func TestRootCmdParsesOrgFlag(t *testing.T) {
f := newTestFactory()
root := cmd.NewRootCmd(f)
root.SetArgs([]string{"--org", "org_abc123", "--help"})
err := root.Execute()
assert.NoError(t, err)

orgVal, err := root.PersistentFlags().GetString("org")
require.NoError(t, err)
assert.Equal(t, "org_abc123", orgVal)
}

func TestHostFlagHasShorthandH(t *testing.T) {
f := newTestFactory()
root := cmd.NewRootCmd(f)
Expand Down
147 changes: 145 additions & 2 deletions cmd/serve/serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func TestBuildInputSchema(t *testing.T) {

props, ok := schema["properties"].(map[string]any)
require.True(t, ok, "schema must have a properties map")
assert.Len(t, props, 3, "expected 3 properties (2 required + 1 optional)")
assert.Len(t, props, 4, "expected 4 properties (2 required + 1 optional + organizationId)")

required, ok := schema["required"].([]string)
require.True(t, ok, "schema must have a required slice")
Expand All @@ -196,7 +196,7 @@ func TestBuildInputSchema(t *testing.T) {
assert.Equal(t, []string{"amount", "network"}, required)

// Verify property structure.
for _, name := range []string{"network", "amount", "memo"} {
for _, name := range []string{"network", "amount", "memo", "organizationId"} {
prop, exists := props[name]
assert.True(t, exists, "property %q must exist", name)
propMap, ok := prop.(map[string]any)
Expand Down Expand Up @@ -262,6 +262,149 @@ func TestMakeToolHandler_ExecutesAction(t *testing.T) {
assert.Contains(t, textContent.Text, "0xabc", "response body must be in TextContent")
}

// TestMakeToolHandler_ForwardsOrgHeader verifies that when the tool arguments
// include an organizationId field, the handler sets the X-Organization-Id
// header on the outgoing HTTP request.
func TestMakeToolHandler_ForwardsOrgHeader(t *testing.T) {
var gotOrgHeader string
var gotBody []byte

actionServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/mcp/schemas" {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]interface{}{"actions": map[string]interface{}{}})
return
}
gotOrgHeader = r.Header.Get("X-Organization-Id")
gotBody, _ = io.ReadAll(r.Body)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"ok":true}`))
}))
defer actionServer.Close()

ios, _, _, _ := iostreams.Test()
f := newServeFactory(actionServer, ios)

handler := serve.MakeToolHandler(f, "web3/transfer")

args := map[string]any{"network": "1", "organizationId": "org_xyz789"}
argsJSON, err := json.Marshal(args)
require.NoError(t, err)

req := &mcp.CallToolRequest{
Params: &mcp.CallToolParamsRaw{
Name: "web3_transfer",
Arguments: argsJSON,
},
}

ctx := context.Background()
result, err := handler(ctx, req)
require.NoError(t, err)
require.NotNil(t, result)

assert.Equal(t, "org_xyz789", gotOrgHeader, "handler must forward organizationId as X-Organization-Id header")

var sentBody map[string]any
require.NoError(t, json.Unmarshal(gotBody, &sentBody))
assert.NotContains(t, sentBody, "organizationId", "organizationId must not appear in the POST body")
assert.Contains(t, sentBody, "network", "other args must still be present in body")
}

// TestMakeToolHandler_ToolArgWinsOverOrgFlag verifies that when the tool arguments
// include an organizationId field, it takes precedence over the --org flag value
// set on the HTTP client.
func TestMakeToolHandler_ToolArgWinsOverOrgFlag(t *testing.T) {
var gotOrgHeader string

actionServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/mcp/schemas" {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]interface{}{"actions": map[string]interface{}{}})
return
}
gotOrgHeader = r.Header.Get("X-Organization-Id")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"ok":true}`))
}))
defer actionServer.Close()

ios, _, _, _ := iostreams.Test()
client := khhttp.NewClient(khhttp.ClientOptions{
Host: actionServer.URL,
AppVersion: "1.0.0",
OrgOverride: "flag-org",
})
f := &cmdutil.Factory{
AppVersion: "1.0.0",
IOStreams: ios,
HTTPClient: func() (*khhttp.Client, error) { return client, nil },
Config: func() (config.Config, error) { return config.Config{DefaultHost: actionServer.URL}, nil },
BaseURL: func() string { return actionServer.URL },
}

handler := serve.MakeToolHandler(f, "web3/transfer")

args := map[string]any{"network": "1", "organizationId": "tool-arg-org"}
argsJSON, err := json.Marshal(args)
require.NoError(t, err)

req := &mcp.CallToolRequest{
Params: &mcp.CallToolParamsRaw{
Name: "web3_transfer",
Arguments: argsJSON,
},
}

ctx := context.Background()
result, err := handler(ctx, req)
require.NoError(t, err)
require.NotNil(t, result)

assert.Equal(t, "tool-arg-org", gotOrgHeader, "tool arg organizationId must win over --org flag")
}

// TestMakeToolHandler_NoOrgHeaderWhenAbsent verifies that when the tool arguments
// do not include organizationId, no X-Organization-Id header is sent.
func TestMakeToolHandler_NoOrgHeaderWhenAbsent(t *testing.T) {
var gotOrgHeader string

actionServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/mcp/schemas" {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]interface{}{"actions": map[string]interface{}{}})
return
}
gotOrgHeader = r.Header.Get("X-Organization-Id")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"ok":true}`))
}))
defer actionServer.Close()

ios, _, _, _ := iostreams.Test()
f := newServeFactory(actionServer, ios)

handler := serve.MakeToolHandler(f, "web3/transfer")

args := map[string]any{"network": "1"}
argsJSON, err := json.Marshal(args)
require.NoError(t, err)

req := &mcp.CallToolRequest{
Params: &mcp.CallToolParamsRaw{
Name: "web3_transfer",
Arguments: argsJSON,
},
}

ctx := context.Background()
result, err := handler(ctx, req)
require.NoError(t, err)
require.NotNil(t, result)

assert.Equal(t, "", gotOrgHeader, "handler must not send X-Organization-Id when organizationId is absent")
}

// TestToolsAreFromSchema_NoneHardcoded verifies that the number of registered
// tools equals the number of actions in the schema -- no hardcoded tools exist.
func TestToolsAreFromSchema_NoneHardcoded(t *testing.T) {
Expand Down
Loading
Loading