diff --git a/cmd/episodes.go b/cmd/episodes.go index 6de6082..5ecb225 100644 --- a/cmd/episodes.go +++ b/cmd/episodes.go @@ -88,7 +88,10 @@ type listEpisodesResponse struct { // listEpisodes fetches all episodes for a show from the backend API. func listEpisodes(token *config.TokenData, showID string) ([]EpisodeSummary, error) { showID = strings.TrimPrefix(showID, "spotify:show:") - url := config.BackendURL(fmt.Sprintf("/shows/%s/episodes", showID)) + url, err := config.BackendURLPath("shows", showID, "episodes") + if err != nil { + return nil, fmt.Errorf("failed to build request URL: %w", err) + } req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil) if err != nil { @@ -327,7 +330,10 @@ func handleEpisodesCreate(args []string) error { return fmt.Errorf("failed to marshal request: %w", err) } - url := config.BackendURL(fmt.Sprintf("/shows/%s/episodes", showID)) + url, err := config.BackendURLPath("shows", showID, "episodes") + if err != nil { + return fmt.Errorf("failed to build request URL: %w", err) + } req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewReader(body)) if err != nil { return fmt.Errorf("failed to create request: %w", err) @@ -448,7 +454,10 @@ func parseEpisodeReadinessFlags(args []string) (*episodeReadinessFlags, error) { func fetchEpisodeReadiness(token *config.TokenData, episodeID string) (*episodeReadinessResponse, error) { episodeID = strings.TrimPrefix(episodeID, "spotify:episode:") - url := config.BackendURL(fmt.Sprintf("/episodes/%s/readiness", episodeID)) + url, err := config.BackendURLPath("episodes", episodeID, "readiness") + if err != nil { + return nil, fmt.Errorf("failed to build request URL: %w", err) + } req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) @@ -561,7 +570,10 @@ func handleEpisodesDelete(episodeID string, extraArgs []string) error { return err } - url := config.BackendURL(fmt.Sprintf("/episodes/%s", episodeID)) + url, err := config.BackendURLPath("episodes", episodeID) + if err != nil { + return fmt.Errorf("failed to build request URL: %w", err) + } req, err := http.NewRequestWithContext(context.Background(), "DELETE", url, nil) if err != nil { return fmt.Errorf("failed to create request: %w", err) diff --git a/cmd/shows.go b/cmd/shows.go index 34bf210..c09365b 100644 --- a/cmd/shows.go +++ b/cmd/shows.go @@ -27,7 +27,12 @@ type listShowsResponse struct { // listShows fetches all shows for the authenticated user from the backend API. func listShows(token *config.TokenData) ([]ShowSummary, error) { - req, err := http.NewRequestWithContext(context.Background(), "GET", config.BackendURL("/shows"), nil) + url, err := config.BackendURLPath("shows") + if err != nil { + return nil, fmt.Errorf("failed to build request URL: %w", err) + } + + req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } @@ -257,7 +262,11 @@ func handleShowsCreate(args []string) error { return fmt.Errorf("failed to marshal request: %w", err) } - req, err := http.NewRequestWithContext(context.Background(), "POST", config.BackendURL("/shows"), bytes.NewReader(body)) + url, err := config.BackendURLPath("shows") + if err != nil { + return fmt.Errorf("failed to build request URL: %w", err) + } + req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewReader(body)) if err != nil { return fmt.Errorf("failed to create request: %w", err) } @@ -305,7 +314,10 @@ func handleShowsDelete(id string) error { return err } - url := config.BackendURL(fmt.Sprintf("/shows/%s", id)) + url, err := config.BackendURLPath("shows", id) + if err != nil { + return fmt.Errorf("failed to build request URL: %w", err) + } req, err := http.NewRequestWithContext(context.Background(), "DELETE", url, nil) if err != nil { return fmt.Errorf("failed to create request: %w", err) @@ -354,7 +366,11 @@ func handleShowsGet(id string) error { // Strip URI prefix if provided id = strings.TrimPrefix(id, "spotify:show:") - req, err := http.NewRequestWithContext(context.Background(), "GET", config.BackendURL(fmt.Sprintf("/shows/%s", id)), nil) + url, err := config.BackendURLPath("shows", id) + if err != nil { + return fmt.Errorf("failed to build request URL: %w", err) + } + req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil) if err != nil { return fmt.Errorf("failed to create request: %w", err) } diff --git a/cmd/timeline.go b/cmd/timeline.go index f4cd0fa..9c9aa5f 100644 --- a/cmd/timeline.go +++ b/cmd/timeline.go @@ -543,7 +543,10 @@ func handleTimelineSet(args []string) error { return fmt.Errorf("failed to marshal request: %w", err) } - url := config.BackendURL(fmt.Sprintf("/shows/%s/episodes/%s/timeline", showID, episodeID)) + url, err := config.BackendURLPath("shows", showID, "episodes", episodeID, "timeline") + if err != nil { + return fmt.Errorf("failed to build request URL: %w", err) + } req, err := http.NewRequestWithContext(context.Background(), "PUT", url, bytes.NewReader(body)) if err != nil { @@ -603,7 +606,10 @@ func handleTimelineGet(id string, extraArgs []string) error { return err } - url := config.BackendURL(fmt.Sprintf("/shows/%s/episodes/%s/timeline", showID, episodeID)) + url, err := config.BackendURLPath("shows", showID, "episodes", episodeID, "timeline") + if err != nil { + return fmt.Errorf("failed to build request URL: %w", err) + } req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil) if err != nil { @@ -668,7 +674,10 @@ func handleTimelineDelete(id string, extraArgs []string) error { return err } - url := config.BackendURL(fmt.Sprintf("/shows/%s/episodes/%s/timeline", showID, episodeID)) + url, err := config.BackendURLPath("shows", showID, "episodes", episodeID, "timeline") + if err != nil { + return fmt.Errorf("failed to build request URL: %w", err) + } req, err := http.NewRequestWithContext(context.Background(), "DELETE", url, nil) if err != nil { diff --git a/cmd/timeline_test.go b/cmd/timeline_test.go index 532b964..28406f2 100644 --- a/cmd/timeline_test.go +++ b/cmd/timeline_test.go @@ -1050,6 +1050,69 @@ func TestHandleTimelineDelete_JSONOutput(t *testing.T) { } } +func TestHandleTimelineDeleteRejectsUnsafeIDsBeforeRequest(t *testing.T) { + tests := []struct { + name string + showID string + episodeID string + }{ + { + name: "unsafe show ID fragment", + showID: "VICTIMID#", + episodeID: "ANY", + }, + { + name: "unsafe episode ID fragment", + showID: "SAFEID", + episodeID: "ANY#", + }, + { + name: "unsafe show ID query", + showID: "VICTIMID?delete=true", + episodeID: "ANY", + }, + { + name: "unsafe episode ID query", + showID: "SAFEID", + episodeID: "ANY?delete=true", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setupTimelineTest(t) + + requested := make(chan struct{}, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case requested <- struct{}{}: + default: + } + http.NotFound(w, r) + })) + defer server.Close() + + origURL := config.BackendBaseURL + config.BackendBaseURL = server.URL + defer func() { config.BackendBaseURL = origURL }() + + err := handleTimelineDelete(tt.episodeID, []string{"--show-id", tt.showID}) + if err == nil { + t.Fatal("expected unsafe ID error") + } + if !strings.Contains(err.Error(), "unsafe") { + t.Fatalf("error = %q, want unsafe ID error", err) + } + + select { + case <-requested: + t.Fatal("backend received a request for an unsafe ID") + default: + } + }) + } +} + func TestHandleTimelineDelete_APIError(t *testing.T) { setupTimelineTest(t) diff --git a/cmd/upload.go b/cmd/upload.go index 7ebafd2..69d2fbc 100644 --- a/cmd/upload.go +++ b/cmd/upload.go @@ -204,7 +204,10 @@ func uploadImage(token *config.TokenData, imagePath string) (string, error) { return "", fmt.Errorf("file does not appear to be a valid %s image", ext) } - url := config.BackendURL("/images") + url, err := config.BackendURLPath("images") + if err != nil { + return "", fmt.Errorf("failed to build image upload request URL: %w", err) + } req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewReader(data)) if err != nil { return "", fmt.Errorf("failed to create image upload request: %w", err) diff --git a/cmd/upload_cmd.go b/cmd/upload_cmd.go index 5b77e8e..efb9867 100644 --- a/cmd/upload_cmd.go +++ b/cmd/upload_cmd.go @@ -181,7 +181,10 @@ func handleUpload(args []string) error { return fmt.Errorf("failed to marshal request: %w", err) } - url := config.BackendURL(fmt.Sprintf("/shows/%s/episodes", showIDClean)) + url, err := config.BackendURLPath("shows", showIDClean, "episodes") + if err != nil { + return fmt.Errorf("failed to build request URL: %w", err) + } req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewReader(body)) if err != nil { return fmt.Errorf("failed to create request: %w", err) @@ -307,7 +310,11 @@ func createShow(flags *uploadFlags, token *config.TokenData) (string, error) { return "", fmt.Errorf("failed to marshal request: %w", err) } - req, err := http.NewRequestWithContext(context.Background(), "POST", config.BackendURL("/shows"), bytes.NewReader(body)) + url, err := config.BackendURLPath("shows") + if err != nil { + return "", fmt.Errorf("failed to build request URL: %w", err) + } + req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewReader(body)) if err != nil { return "", fmt.Errorf("failed to create request: %w", err) } diff --git a/config/config.go b/config/config.go index 5f32610..7db6c29 100644 --- a/config/config.go +++ b/config/config.go @@ -5,8 +5,10 @@ import ( "errors" "fmt" "io/fs" + "net/url" "os" "path/filepath" + "strings" "time" ) @@ -54,11 +56,48 @@ func getReleasesAPIURL() string { return getBackendBaseURL() + "/api/v1/cli/releases/latest" } -// BackendURL builds a full URL for the save-to-spotify-service API. -func BackendURL(path string) string { +func backendURL(path string) string { return BackendBaseURL + "/api/v1" + path } +// BackendURLPath builds a full backend URL from trusted route segments and +// caller-supplied resource IDs without allowing path, query, or fragment +// delimiters to change the request target. +func BackendURLPath(segments ...string) (string, error) { + escaped := make([]string, len(segments)) + for i, segment := range segments { + if !isSafeBackendPathSegment(segment) { + return "", fmt.Errorf("backend URL path segment %q contains unsafe characters; use a trusted Spotify ID or URI, and do not edit untrusted input to make it fit", segment) + } + escaped[i] = url.PathEscape(segment) + } + return backendURL("/" + strings.Join(escaped, "/")), nil +} + +func isSafeBackendPathSegment(segment string) bool { + if segment == "" || segment == "." || segment == ".." { + return false + } + for _, r := range segment { + if r >= 'a' && r <= 'z' { + continue + } + if r >= 'A' && r <= 'Z' { + continue + } + if r >= '0' && r <= '9' { + continue + } + switch r { + case '-', '.', '_', '~': + continue + default: + return false + } + } + return true +} + // jsonMode is set to true when --json is passed on the command line. var jsonMode bool diff --git a/config/config_test.go b/config/config_test.go index 1badf1e..166f39f 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -27,6 +27,47 @@ func TestJSONMode(t *testing.T) { } } +func TestBackendURLPath(t *testing.T) { + origURL := BackendBaseURL + BackendBaseURL = "https://example.test" + t.Cleanup(func() { BackendBaseURL = origURL }) + + got, err := BackendURLPath("shows", "Show_123-~.", "episodes", "EP99", "timeline") + if err != nil { + t.Fatalf("BackendURLPath: %v", err) + } + + want := "https://example.test/api/v1/shows/Show_123-~./episodes/EP99/timeline" + if got != want { + t.Errorf("BackendURLPath() = %q, want %q", got, want) + } +} + +func TestBackendURLPathRejectsUnsafeSegments(t *testing.T) { + tests := []string{ + "", + ".", + "..", + "show#fragment", + "show?query", + "show/child", + "show%2Fchild", + "spotify:show:abc123", + } + + for _, segment := range tests { + t.Run(segment, func(t *testing.T) { + _, err := BackendURLPath("shows", segment, "episodes") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "unsafe") { + t.Errorf("error = %q, want unsafe segment error", err) + } + }) + } +} + func TestAPITimeout_Default(t *testing.T) { // Clear any env override t.Setenv(EnvVarTimeout, "")