diff --git a/cmd/kh/main.go b/cmd/kh/main.go index 4a95b18..954618e 100644 --- a/cmd/kh/main.go +++ b/cmd/kh/main.go @@ -43,9 +43,20 @@ func main() { return hosts.ActiveHost(flagHost, envHost) } + // resolveOrgFlag returns the --org flag value if set, or empty string. + resolveOrgFlag := func() string { + if rootCmd != nil { + if f := rootCmd.PersistentFlags().Lookup("org"); f != nil { + return f.Value.String() + } + } + return "" + } + f := &cmdutil.Factory{ AppVersion: version.Version, IOStreams: ios, + OrgID: resolveOrgFlag, Config: func() (config.Config, error) { return config.ReadConfig() }, @@ -69,11 +80,12 @@ func main() { entry, _ := hosts.HostEntry(activeHost) return khhttp.NewClient(khhttp.ClientOptions{ - Host: activeHost, - Token: resolved.Token, - Headers: entry.Headers, - IOStreams: ios, - AppVersion: version.Version, + Host: activeHost, + Token: resolved.Token, + Headers: entry.Headers, + OrgOverride: resolveOrgFlag(), + IOStreams: ios, + AppVersion: version.Version, }), nil }, } diff --git a/cmd/root.go b/cmd/root.go index 3853220..0a8e1c7 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -39,6 +39,7 @@ func NewRootCmd(f *cmdutil.Factory) *cobra.Command { cmd.PersistentFlags().BoolP("yes", "y", false, "Skip confirmation prompts") cmd.PersistentFlags().Bool("no-color", false, "Disable color output") cmd.PersistentFlags().StringP("host", "H", "", "KeeperHub host (default: app.keeperhub.com)") + cmd.PersistentFlags().String("org", "", "Organization ID to use (overrides default from auth)") cmd.AddCommand(action.NewActionCmd(f)) cmd.AddCommand(auth.NewAuthCmd(f)) diff --git a/cmd/root_test.go b/cmd/root_test.go index f91c44b..a7c11b2 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -84,6 +84,26 @@ func TestNewRootCmdHasHostFlag(t *testing.T) { assert.Equal(t, "string", flag.Value.Type()) } +func TestNewRootCmdHasOrgFlag(t *testing.T) { + f := newTestFactory() + root := cmd.NewRootCmd(f) + flag := root.PersistentFlags().Lookup("org") + require.NotNil(t, flag) + assert.Equal(t, "string", flag.Value.Type()) +} + +func TestRootCmdParsesOrgFlag(t *testing.T) { + f := newTestFactory() + root := cmd.NewRootCmd(f) + root.SetArgs([]string{"--org", "org_abc123", "--help"}) + err := root.Execute() + assert.NoError(t, err) + + orgVal, err := root.PersistentFlags().GetString("org") + require.NoError(t, err) + assert.Equal(t, "org_abc123", orgVal) +} + func TestHostFlagHasShorthandH(t *testing.T) { f := newTestFactory() root := cmd.NewRootCmd(f) diff --git a/cmd/serve/serve_test.go b/cmd/serve/serve_test.go index a182749..6676dfd 100644 --- a/cmd/serve/serve_test.go +++ b/cmd/serve/serve_test.go @@ -185,7 +185,7 @@ func TestBuildInputSchema(t *testing.T) { props, ok := schema["properties"].(map[string]any) require.True(t, ok, "schema must have a properties map") - assert.Len(t, props, 3, "expected 3 properties (2 required + 1 optional)") + assert.Len(t, props, 4, "expected 4 properties (2 required + 1 optional + organizationId)") required, ok := schema["required"].([]string) require.True(t, ok, "schema must have a required slice") @@ -196,7 +196,7 @@ func TestBuildInputSchema(t *testing.T) { assert.Equal(t, []string{"amount", "network"}, required) // Verify property structure. - for _, name := range []string{"network", "amount", "memo"} { + for _, name := range []string{"network", "amount", "memo", "organizationId"} { prop, exists := props[name] assert.True(t, exists, "property %q must exist", name) propMap, ok := prop.(map[string]any) @@ -262,6 +262,149 @@ func TestMakeToolHandler_ExecutesAction(t *testing.T) { assert.Contains(t, textContent.Text, "0xabc", "response body must be in TextContent") } +// TestMakeToolHandler_ForwardsOrgHeader verifies that when the tool arguments +// include an organizationId field, the handler sets the X-Organization-Id +// header on the outgoing HTTP request. +func TestMakeToolHandler_ForwardsOrgHeader(t *testing.T) { + var gotOrgHeader string + var gotBody []byte + + actionServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/mcp/schemas" { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{"actions": map[string]interface{}{}}) + return + } + gotOrgHeader = r.Header.Get("X-Organization-Id") + gotBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer actionServer.Close() + + ios, _, _, _ := iostreams.Test() + f := newServeFactory(actionServer, ios) + + handler := serve.MakeToolHandler(f, "web3/transfer") + + args := map[string]any{"network": "1", "organizationId": "org_xyz789"} + argsJSON, err := json.Marshal(args) + require.NoError(t, err) + + req := &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Name: "web3_transfer", + Arguments: argsJSON, + }, + } + + ctx := context.Background() + result, err := handler(ctx, req) + require.NoError(t, err) + require.NotNil(t, result) + + assert.Equal(t, "org_xyz789", gotOrgHeader, "handler must forward organizationId as X-Organization-Id header") + + var sentBody map[string]any + require.NoError(t, json.Unmarshal(gotBody, &sentBody)) + assert.NotContains(t, sentBody, "organizationId", "organizationId must not appear in the POST body") + assert.Contains(t, sentBody, "network", "other args must still be present in body") +} + +// TestMakeToolHandler_ToolArgWinsOverOrgFlag verifies that when the tool arguments +// include an organizationId field, it takes precedence over the --org flag value +// set on the HTTP client. +func TestMakeToolHandler_ToolArgWinsOverOrgFlag(t *testing.T) { + var gotOrgHeader string + + actionServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/mcp/schemas" { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{"actions": map[string]interface{}{}}) + return + } + gotOrgHeader = r.Header.Get("X-Organization-Id") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer actionServer.Close() + + ios, _, _, _ := iostreams.Test() + client := khhttp.NewClient(khhttp.ClientOptions{ + Host: actionServer.URL, + AppVersion: "1.0.0", + OrgOverride: "flag-org", + }) + f := &cmdutil.Factory{ + AppVersion: "1.0.0", + IOStreams: ios, + HTTPClient: func() (*khhttp.Client, error) { return client, nil }, + Config: func() (config.Config, error) { return config.Config{DefaultHost: actionServer.URL}, nil }, + BaseURL: func() string { return actionServer.URL }, + } + + handler := serve.MakeToolHandler(f, "web3/transfer") + + args := map[string]any{"network": "1", "organizationId": "tool-arg-org"} + argsJSON, err := json.Marshal(args) + require.NoError(t, err) + + req := &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Name: "web3_transfer", + Arguments: argsJSON, + }, + } + + ctx := context.Background() + result, err := handler(ctx, req) + require.NoError(t, err) + require.NotNil(t, result) + + assert.Equal(t, "tool-arg-org", gotOrgHeader, "tool arg organizationId must win over --org flag") +} + +// TestMakeToolHandler_NoOrgHeaderWhenAbsent verifies that when the tool arguments +// do not include organizationId, no X-Organization-Id header is sent. +func TestMakeToolHandler_NoOrgHeaderWhenAbsent(t *testing.T) { + var gotOrgHeader string + + actionServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/mcp/schemas" { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{"actions": map[string]interface{}{}}) + return + } + gotOrgHeader = r.Header.Get("X-Organization-Id") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer actionServer.Close() + + ios, _, _, _ := iostreams.Test() + f := newServeFactory(actionServer, ios) + + handler := serve.MakeToolHandler(f, "web3/transfer") + + args := map[string]any{"network": "1"} + argsJSON, err := json.Marshal(args) + require.NoError(t, err) + + req := &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Name: "web3_transfer", + Arguments: argsJSON, + }, + } + + ctx := context.Background() + result, err := handler(ctx, req) + require.NoError(t, err) + require.NotNil(t, result) + + assert.Equal(t, "", gotOrgHeader, "handler must not send X-Organization-Id when organizationId is absent") +} + // TestToolsAreFromSchema_NoneHardcoded verifies that the number of registered // tools equals the number of actions in the schema -- no hardcoded tools exist. func TestToolsAreFromSchema_NoneHardcoded(t *testing.T) { diff --git a/cmd/serve/tools.go b/cmd/serve/tools.go index 572149d..558ef64 100644 --- a/cmd/serve/tools.go +++ b/cmd/serve/tools.go @@ -35,6 +35,11 @@ func BuildInputSchema(action ActionSchema) map[string]any { } } + properties["organizationId"] = map[string]any{ + "type": "string", + "description": "Organization ID to use (overrides default from auth)", + } + return map[string]any{ "type": "object", "properties": properties, @@ -77,6 +82,12 @@ func MakeToolHandler(f *cmdutil.Factory, actionType string) mcp.ToolHandler { return nil, fmt.Errorf("unmarshaling arguments: %w", err) } } + if args == nil { + args = make(map[string]any) + } + + orgID := getStringArg(args, "organizationId") + delete(args, "organizationId") bodyBytes, err := json.Marshal(args) if err != nil { @@ -95,6 +106,10 @@ func MakeToolHandler(f *cmdutil.Factory, actionType string) mcp.ToolHandler { } httpReq.Header.Set("Content-Type", "application/json") + if orgID != "" { + httpReq.Header.Set("X-Organization-Id", orgID) + } + resp, err := client.Do(httpReq) if err != nil { return nil, fmt.Errorf("executing request: %w", err) @@ -177,6 +192,10 @@ func makeStaticHandler( httpReq.Header.Set("Content-Type", "application/json") } + if orgID := getStringArg(args, "organizationId"); orgID != "" { + httpReq.Header.Set("X-Organization-Id", orgID) + } + resp, err := client.Do(httpReq) if err != nil { return nil, fmt.Errorf("executing request: %w", err) @@ -205,6 +224,23 @@ func makeStaticHandler( } } +// orgProperty is the JSON Schema fragment for the optional organizationId field +// shared by all MCP tools. +var orgProperty = map[string]any{ + "type": "string", + "description": "Organization ID to use (overrides default from auth)", +} + +// withOrgField returns a copy of the input schema with the organizationId +// optional property added to its properties map. +func withOrgField(schema map[string]any) map[string]any { + props, ok := schema["properties"].(map[string]any) + if ok { + props["organizationId"] = orgProperty + } + return schema +} + // registerStaticTools registers workflow management and execution tools that // call KeeperHub API endpoints directly (not via /api/execute/). func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { @@ -212,7 +248,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { server.AddTool(&mcp.Tool{ Name: "workflow_list", Description: "List all workflows in the current organization", - InputSchema: map[string]any{ + InputSchema: withOrgField(map[string]any{ "type": "object", "properties": map[string]any{ "limit": map[string]any{ @@ -220,7 +256,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { "description": "Maximum number of workflows to return", }, }, - }, + }), }, makeStaticHandler(f, http.MethodGet, func(args map[string]any, baseURL string) string { u := baseURL + "/api/workflows" if limit := getStringArg(args, "limit"); limit != "" { @@ -233,7 +269,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { server.AddTool(&mcp.Tool{ Name: "workflow_get", Description: "Get a workflow by ID, including its nodes and edges", - InputSchema: map[string]any{ + InputSchema: withOrgField(map[string]any{ "type": "object", "properties": map[string]any{ "workflow_id": map[string]any{ @@ -242,7 +278,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { }, }, "required": []string{"workflow_id"}, - }, + }), }, makeStaticHandler(f, http.MethodGet, func(args map[string]any, baseURL string) string { return baseURL + "/api/workflows/" + getStringArg(args, "workflow_id") }, nil)) @@ -251,7 +287,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { server.AddTool(&mcp.Tool{ Name: "workflow_create", Description: "Create a new workflow", - InputSchema: map[string]any{ + InputSchema: withOrgField(map[string]any{ "type": "object", "properties": map[string]any{ "name": map[string]any{ @@ -272,7 +308,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { }, }, "required": []string{"name"}, - }, + }), }, makeStaticHandler(f, http.MethodPost, func(args map[string]any, baseURL string) string { return baseURL + "/api/workflows/create" }, func(args map[string]any) ([]byte, error) { @@ -303,7 +339,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { server.AddTool(&mcp.Tool{ Name: "workflow_update", Description: "Update an existing workflow", - InputSchema: map[string]any{ + InputSchema: withOrgField(map[string]any{ "type": "object", "properties": map[string]any{ "workflow_id": map[string]any{ @@ -328,7 +364,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { }, }, "required": []string{"workflow_id"}, - }, + }), }, makeStaticHandler(f, http.MethodPatch, func(args map[string]any, baseURL string) string { return baseURL + "/api/workflows/" + getStringArg(args, "workflow_id") }, func(args map[string]any) ([]byte, error) { @@ -360,7 +396,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { server.AddTool(&mcp.Tool{ Name: "workflow_delete", Description: "Delete a workflow by ID. Use force=true to delete workflows that have execution history.", - InputSchema: map[string]any{ + InputSchema: withOrgField(map[string]any{ "type": "object", "properties": map[string]any{ "workflow_id": map[string]any{ @@ -373,7 +409,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { }, }, "required": []string{"workflow_id"}, - }, + }), }, makeStaticHandler(f, http.MethodDelete, func(args map[string]any, baseURL string) string { u := baseURL + "/api/workflows/" + getStringArg(args, "workflow_id") if force, ok := args["force"]; ok && force == true { @@ -386,7 +422,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { server.AddTool(&mcp.Tool{ Name: "workflow_execute", Description: "Execute a workflow by ID", - InputSchema: map[string]any{ + InputSchema: withOrgField(map[string]any{ "type": "object", "properties": map[string]any{ "workflow_id": map[string]any{ @@ -399,7 +435,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { }, }, "required": []string{"workflow_id"}, - }, + }), }, makeStaticHandler(f, http.MethodPost, func(args map[string]any, baseURL string) string { return baseURL + "/api/workflow/" + getStringArg(args, "workflow_id") + "/execute" }, func(args map[string]any) ([]byte, error) { @@ -417,7 +453,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { server.AddTool(&mcp.Tool{ Name: "execution_status", Description: "Get the status of a workflow execution", - InputSchema: map[string]any{ + InputSchema: withOrgField(map[string]any{ "type": "object", "properties": map[string]any{ "execution_id": map[string]any{ @@ -426,7 +462,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { }, }, "required": []string{"execution_id"}, - }, + }), }, makeStaticHandler(f, http.MethodGet, func(args map[string]any, baseURL string) string { return baseURL + "/api/workflows/executions/" + getStringArg(args, "execution_id") + "/status" }, nil)) @@ -435,7 +471,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { server.AddTool(&mcp.Tool{ Name: "execution_logs", Description: "Get the logs for a workflow execution", - InputSchema: map[string]any{ + InputSchema: withOrgField(map[string]any{ "type": "object", "properties": map[string]any{ "execution_id": map[string]any{ @@ -444,7 +480,7 @@ func registerStaticTools(server *mcp.Server, f *cmdutil.Factory) { }, }, "required": []string{"execution_id"}, - }, + }), }, makeStaticHandler(f, http.MethodGet, func(args map[string]any, baseURL string) string { return baseURL + "/api/workflows/executions/" + getStringArg(args, "execution_id") + "/logs" }, nil)) diff --git a/internal/http/client.go b/internal/http/client.go index 4b83a4f..357fbaf 100644 --- a/internal/http/client.go +++ b/internal/http/client.go @@ -22,6 +22,10 @@ type ClientOptions struct { // (e.g. Cloudflare Access headers loaded from hosts.yml). Headers map[string]string + // OrgOverride is the organization ID from the --org flag. + // When non-empty, X-Organization-Id is sent on every request. + OrgOverride string + // IOStreams provides the ErrOut writer for version warnings. IOStreams *iostreams.IOStreams @@ -32,11 +36,12 @@ type ClientOptions struct { // Client is a retryable HTTP client that injects version and auth headers // on every outgoing request. type Client struct { - inner *retryablehttp.Client - appVersion string - token string - headers map[string]string - ios *iostreams.IOStreams + inner *retryablehttp.Client + appVersion string + token string + headers map[string]string + orgOverride string + ios *iostreams.IOStreams } // NewClient creates a Client wrapping hashicorp/go-retryablehttp with @@ -55,11 +60,12 @@ func NewClient(opts ClientOptions) *Client { } return &Client{ - inner: rc, - appVersion: opts.AppVersion, - token: opts.Token, - headers: opts.Headers, - ios: opts.IOStreams, + inner: rc, + appVersion: opts.AppVersion, + token: opts.Token, + headers: opts.Headers, + orgOverride: opts.OrgOverride, + ios: opts.IOStreams, } } @@ -73,6 +79,10 @@ func (c *Client) Do(req *retryablehttp.Request) (*http.Response, error) { req.Header.Set("Authorization", "Bearer "+c.token) } + if c.orgOverride != "" && req.Header.Get("X-Organization-Id") == "" { + req.Header.Set("X-Organization-Id", c.orgOverride) + } + for k, v := range c.headers { req.Header.Set(k, v) } diff --git a/internal/http/client_test.go b/internal/http/client_test.go index 917aa3c..8d2a798 100644 --- a/internal/http/client_test.go +++ b/internal/http/client_test.go @@ -92,6 +92,58 @@ func TestClientSetsAuthorizationHeader(t *testing.T) { assert.Equal(t, "Bearer my-secret-token", gotAuth) } +func TestClientSetsOrgOverrideHeader(t *testing.T) { + var gotOrgHeader string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotOrgHeader = r.Header.Get("X-Organization-Id") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ios, _, _, _ := iostreams.Test() + client := khhttp.NewClient(khhttp.ClientOptions{ + Host: srv.URL, + AppVersion: "1.0.0", + OrgOverride: "org_abc123", + IOStreams: ios, + }) + + req, err := retryablehttp.NewRequest(http.MethodGet, srv.URL+"/test", nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, "org_abc123", gotOrgHeader) +} + +func TestClientNoOrgHeaderWhenOrgOverrideEmpty(t *testing.T) { + var gotOrgHeader string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotOrgHeader = r.Header.Get("X-Organization-Id") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ios, _, _, _ := iostreams.Test() + client := khhttp.NewClient(khhttp.ClientOptions{ + Host: srv.URL, + AppVersion: "1.0.0", + OrgOverride: "", + IOStreams: ios, + }) + + req, err := retryablehttp.NewRequest(http.MethodGet, srv.URL+"/test", nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, "", gotOrgHeader) +} + func TestClientNoAuthorizationHeaderWhenTokenEmpty(t *testing.T) { var gotAuth string srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -213,6 +265,33 @@ func TestCheckVersionNoWarningForDevVersion(t *testing.T) { assert.Empty(t, errOut.String()) } +func TestClientOrgOverride_DoesNotOverwriteExistingHeader(t *testing.T) { + var gotOrgHeader string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotOrgHeader = r.Header.Get("X-Organization-Id") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ios, _, _, _ := iostreams.Test() + client := khhttp.NewClient(khhttp.ClientOptions{ + Host: srv.URL, + AppVersion: "1.0.0", + OrgOverride: "flag-org", + IOStreams: ios, + }) + + req, err := retryablehttp.NewRequest(http.MethodGet, srv.URL+"/test", nil) + require.NoError(t, err) + req.Header.Set("X-Organization-Id", "tool-arg-org") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, "tool-arg-org", gotOrgHeader, "per-request X-Organization-Id must not be overwritten by OrgOverride") +} + func TestSemverLessThan(t *testing.T) { tests := []struct { current string diff --git a/pkg/cmdutil/factory.go b/pkg/cmdutil/factory.go index 039f930..ccbcea3 100644 --- a/pkg/cmdutil/factory.go +++ b/pkg/cmdutil/factory.go @@ -25,6 +25,10 @@ type Factory struct { // ResolveHost when a cobra.Command is not available (e.g. MCP serve mode). BaseURL func() string + // OrgID returns the organization ID override from the --org flag. + // Returns an empty string when no override is set. + OrgID func() string + // IOStreams provides the standard input/output streams. IOStreams *iostreams.IOStreams }