From 9aea8f8e1f50f0d0af2f5b9087b0eed98a9d04f4 Mon Sep 17 00:00:00 2001 From: h1177h <2928932863@qq.com> Date: Fri, 8 May 2026 20:41:20 +0800 Subject: [PATCH 1/2] feat: add marketplace registry server --- cmd/anyclaw-registry/main.go | 93 ++++ cmd/anyclaw-registry/main_test.go | 79 +++ pkg/marketregistry/seed.go | 108 ++++ pkg/marketregistry/server.go | 550 ++++++++++++++++++ pkg/marketregistry/server_test.go | 370 ++++++++++++ pkg/marketregistry/storage.go | 169 ++++++ pkg/marketregistry/store.go | 899 ++++++++++++++++++++++++++++++ pkg/marketregistry/types.go | 194 +++++++ 8 files changed, 2462 insertions(+) create mode 100644 cmd/anyclaw-registry/main.go create mode 100644 cmd/anyclaw-registry/main_test.go create mode 100644 pkg/marketregistry/seed.go create mode 100644 pkg/marketregistry/server.go create mode 100644 pkg/marketregistry/server_test.go create mode 100644 pkg/marketregistry/storage.go create mode 100644 pkg/marketregistry/store.go create mode 100644 pkg/marketregistry/types.go diff --git a/cmd/anyclaw-registry/main.go b/cmd/anyclaw-registry/main.go new file mode 100644 index 00000000..cc252a91 --- /dev/null +++ b/cmd/anyclaw-registry/main.go @@ -0,0 +1,93 @@ +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "syscall" + + "github.com/1024XEngineer/anyclaw/pkg/marketregistry" +) + +func main() { + if err := run(os.Args[1:]); err != nil { + log.Fatal(err) + } +} + +func run(args []string) error { + if len(args) == 0 { + args = []string{"serve"} + } + switch args[0] { + case "serve": + return serve(args[1:]) + case "-h", "--help", "help": + printUsage() + return nil + default: + return fmt.Errorf("unknown command %q", args[0]) + } +} + +func serve(args []string) error { + fs := flag.NewFlagSet("serve", flag.ContinueOnError) + addr := fs.String("addr", ":8791", "HTTP listen address") + dataDir := fs.String("data-dir", ".anyclaw-registry", "registry data directory") + dbDriver := fs.String("db-driver", "sqlite", "database/sql driver name") + dbDSN := fs.String("db-dsn", "", "database DSN; defaults to /registry.db for sqlite") + adminToken := fs.String("admin-token", os.Getenv("ANYCLAW_REGISTRY_ADMIN_TOKEN"), "admin bearer token; defaults to ANYCLAW_REGISTRY_ADMIN_TOKEN") + requireAdminToken := fs.Bool("require-admin-token", envBool("ANYCLAW_REGISTRY_REQUIRE_ADMIN_TOKEN", true), "fail startup when admin token is empty") + seed := fs.Bool("seed", true, "seed fixture artifacts when the registry is empty") + if err := fs.Parse(args); err != nil { + return err + } + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + server, err := marketregistry.NewServer(ctx, marketregistry.ServerConfig{ + Addr: *addr, + DataDir: *dataDir, + DBDriver: *dbDriver, + DBDSN: *dbDSN, + AdminToken: *adminToken, + RequireAdminToken: *requireAdminToken, + Seed: *seed, + }) + if err != nil { + return err + } + defer server.Close() + + log.Printf("anyclaw registry listening on %s, data_dir=%s", *addr, *dataDir) + err = server.StartWithContext(ctx) + if errors.Is(err, http.ErrServerClosed) { + return nil + } + return err +} + +func printUsage() { + fmt.Println("Usage: anyclaw-registry serve [--addr :8791] [--data-dir .anyclaw-registry] [--db-driver sqlite] [--db-dsn path-or-url] [--admin-token token] [--require-admin-token=true] [--seed=true]") +} + +func envBool(name string, fallback bool) bool { + value := os.Getenv(name) + if value == "" { + return fallback + } + switch value { + case "1", "true", "TRUE", "True", "yes", "YES", "on", "ON": + return true + case "0", "false", "FALSE", "False", "no", "NO", "off", "OFF": + return false + default: + return fallback + } +} diff --git a/cmd/anyclaw-registry/main_test.go b/cmd/anyclaw-registry/main_test.go new file mode 100644 index 00000000..04d2f51a --- /dev/null +++ b/cmd/anyclaw-registry/main_test.go @@ -0,0 +1,79 @@ +package main + +import ( + "bytes" + "io" + "os" + "strings" + "testing" +) + +func TestRunDefaultsToServeAndRequiresAdminToken(t *testing.T) { + t.Setenv("ANYCLAW_REGISTRY_ADMIN_TOKEN", "") + t.Setenv("ANYCLAW_REGISTRY_REQUIRE_ADMIN_TOKEN", "") + + err := run([]string{"serve", "--data-dir", t.TempDir(), "--seed=false"}) + if err == nil || !strings.Contains(err.Error(), "admin token is required") { + t.Fatalf("expected missing admin token error, got %v", err) + } +} + +func TestRunAllowsExplicitLocalRegistryWithoutAdminToken(t *testing.T) { + t.Setenv("ANYCLAW_REGISTRY_ADMIN_TOKEN", "") + t.Setenv("ANYCLAW_REGISTRY_REQUIRE_ADMIN_TOKEN", "") + + err := run([]string{"serve", "--addr", "127.0.0.1:bad", "--data-dir", t.TempDir(), "--seed=false", "--require-admin-token=false"}) + if err == nil || strings.Contains(err.Error(), "admin token is required") { + t.Fatalf("expected listen error after explicit insecure opt-out, got %v", err) + } +} + +func TestRunHelpAndUnknownCommand(t *testing.T) { + out := captureStdout(t, func() { + if err := run([]string{"help"}); err != nil { + t.Fatalf("help returned error: %v", err) + } + }) + if !strings.Contains(out, "anyclaw-registry serve") { + t.Fatalf("help output = %q", out) + } + if err := run([]string{"nope"}); err == nil || !strings.Contains(err.Error(), "unknown command") { + t.Fatalf("expected unknown command error, got %v", err) + } +} + +func TestEnvBoolParsesKnownValuesAndFallback(t *testing.T) { + t.Setenv("BOOL_VALUE", "yes") + if !envBool("BOOL_VALUE", false) { + t.Fatal("expected yes to parse true") + } + t.Setenv("BOOL_VALUE", "OFF") + if envBool("BOOL_VALUE", true) { + t.Fatal("expected OFF to parse false") + } + t.Setenv("BOOL_VALUE", "not-bool") + if !envBool("BOOL_VALUE", true) { + t.Fatal("expected invalid value to use fallback") + } +} + +func captureStdout(t *testing.T, fn func()) string { + t.Helper() + original := os.Stdout + read, write, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + os.Stdout = write + defer func() { os.Stdout = original }() + + fn() + if err := write.Close(); err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + if _, err := io.Copy(&buf, read); err != nil { + t.Fatal(err) + } + return buf.String() +} diff --git a/pkg/marketregistry/seed.go b/pkg/marketregistry/seed.go new file mode 100644 index 00000000..8b8c0fdd --- /dev/null +++ b/pkg/marketregistry/seed.go @@ -0,0 +1,108 @@ +package marketregistry + +import ( + "context" + "time" +) + +func SeedIfEmpty(ctx context.Context, store *Store, storage *LocalStorage) error { + count, err := store.CountArtifacts(ctx) + if err != nil { + return err + } + if count > 0 { + return nil + } + return SeedFixtures(ctx, store, storage) +} + +func SeedFixtures(ctx context.Context, store *Store, storage *LocalStorage) error { + now := time.Now().UTC().Format(time.RFC3339) + fixtures := []Artifact{ + { + ID: "cloud.agent.code-reviewer", + Kind: ArtifactKindAgent, + Name: "Cloud Code Reviewer", + Summary: "Reviews local changes and highlights concrete risks before merge.", + DescriptionMD: "A marketplace fixture agent for review-oriented workflows.", + Version: "1.0.0", + LatestVersion: "1.0.0", + Source: defaultRegistrySourceID, + Publisher: "AnyClaw Labs", + RiskLevel: "medium", + TrustLevel: "verified", + Permissions: []string{"fs.read", "git.read"}, + Compatibility: Compatibility{AnyClawMin: "0.1.0", OS: []string{"windows", "linux", "darwin"}, Arch: []string{"amd64", "arm64"}}, + Tags: []string{"agent", "review", "quality"}, + HitSignals: []string{"code review", "风险检查", "pull request"}, + Score: 0.96, + UpdatedAt: now, + ManifestSummary: map[string]string{"entry": "agent/profile.json"}, + }, + { + ID: "cloud.skill.release-notes", + Kind: ArtifactKindSkill, + Name: "Release Notes Writer", + Summary: "Turns git history and issue notes into compact release notes.", + DescriptionMD: "A marketplace fixture skill for writing release notes from project context.", + Version: "1.0.0", + LatestVersion: "1.0.0", + Source: defaultRegistrySourceID, + Publisher: "AnyClaw Labs", + RiskLevel: "low", + TrustLevel: "verified", + Permissions: []string{"fs.read", "git.read"}, + Compatibility: Compatibility{AnyClawMin: "0.1.0", OS: []string{"windows", "linux", "darwin"}, Arch: []string{"amd64", "arm64"}}, + Tags: []string{"skill", "release", "writing"}, + HitSignals: []string{"release notes", "changelog", "发布说明"}, + Score: 0.94, + UpdatedAt: now, + ManifestSummary: map[string]string{"entry": "skill/SKILL.md"}, + }, + { + ID: "cloud.cli.repo-health", + Kind: ArtifactKindCLI, + Name: "Repo Health CLI", + Summary: "Runs a lightweight repository health check command.", + DescriptionMD: "A marketplace fixture CLI package for command binding tests.", + Version: "1.0.0", + LatestVersion: "1.0.0", + Source: defaultRegistrySourceID, + Publisher: "AnyClaw Labs", + RiskLevel: "medium", + TrustLevel: "verified", + Permissions: []string{"process.exec", "fs.read"}, + Compatibility: Compatibility{AnyClawMin: "0.1.0", OS: []string{"windows", "linux", "darwin"}, Arch: []string{"amd64", "arm64"}}, + Tags: []string{"cli", "health", "repository"}, + HitSignals: []string{"repo health", "诊断", "cli"}, + Score: 0.91, + UpdatedAt: now, + ManifestSummary: map[string]string{"command": "anyclaw-repo-health"}, + }, + } + + for _, artifact := range fixtures { + version := ArtifactVersion{ + ArtifactID: artifact.ID, + Version: artifact.LatestVersion, + ReleasedAt: now, + ChangelogMD: "Initial registry fixture.", + Compatibility: artifact.Compatibility, + Permissions: artifact.Permissions, + PermissionsDiff: artifact.Permissions, + } + info, err := storage.EnsurePackage(artifact, version) + if err != nil { + return err + } + version.SizeBytes = info.SizeBytes + version.ChecksumSHA256 = info.ChecksumSHA256 + version.StorageKey = info.StorageKey + artifact.SizeBytes = info.SizeBytes + artifact.ChecksumSHA256 = info.ChecksumSHA256 + if err := store.UpsertArtifact(ctx, artifact, []ArtifactVersion{version}); err != nil { + return err + } + } + return nil +} diff --git a/pkg/marketregistry/server.go b/pkg/marketregistry/server.go new file mode 100644 index 00000000..79f57532 --- /dev/null +++ b/pkg/marketregistry/server.go @@ -0,0 +1,550 @@ +package marketregistry + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" +) + +type ServerConfig struct { + Addr string + DataDir string + DBDriver string + DBDSN string + AdminToken string + RequireAdminToken bool + Seed bool + MaxRequestBodyBytes int64 +} + +type Server struct { + mux *http.ServeMux + store *Store + storage PackageStorage + addr string + adminToken string + maxRequestBodyBytes int64 +} + +func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { + if cfg.RequireAdminToken && strings.TrimSpace(cfg.AdminToken) == "" { + return nil, fmt.Errorf("admin token is required") + } + store, err := OpenStoreWithConfig(ctx, StoreConfig{DataDir: cfg.DataDir, Driver: cfg.DBDriver, DSN: cfg.DBDSN}) + if err != nil { + return nil, err + } + storage, err := NewLocalStorage(cfg.DataDir) + if err != nil { + _ = store.Close() + return nil, err + } + if cfg.Seed { + if err := SeedIfEmpty(ctx, store, storage); err != nil { + _ = store.Close() + return nil, err + } + } + s := &Server{ + mux: http.NewServeMux(), + store: store, + storage: storage, + addr: cfg.Addr, + adminToken: strings.TrimSpace(cfg.AdminToken), + maxRequestBodyBytes: cfg.MaxRequestBodyBytes, + } + if s.addr == "" { + s.addr = ":8791" + } + if s.maxRequestBodyBytes <= 0 { + s.maxRequestBodyBytes = 2 << 20 + } + s.registerRoutes() + return s, nil +} + +func (s *Server) Close() error { + if s == nil { + return nil + } + return s.store.Close() +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.mux.ServeHTTP(w, r) +} + +func (s *Server) StartWithContext(ctx context.Context) error { + server := &http.Server{ + Addr: s.addr, + Handler: s, + ReadHeaderTimeout: 5 * time.Second, + ReadTimeout: 30 * time.Second, + WriteTimeout: 5 * time.Minute, + IdleTimeout: 60 * time.Second, + } + go func() { + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = server.Shutdown(shutdownCtx) + }() + return server.ListenAndServe() +} + +func (s *Server) registerRoutes() { + s.mux.HandleFunc("GET /v1/health", s.handleHealth) + s.mux.HandleFunc("GET /v1/control-plane", s.handleControlPlane) + s.mux.HandleFunc("GET /v1/sources", s.handleSources) + s.mux.HandleFunc("GET /v1/artifacts", s.handleListArtifacts) + s.mux.HandleFunc("GET /v1/artifacts/{id}", s.handleArtifactDetail) + s.mux.HandleFunc("GET /v1/artifacts/{id}/versions", s.handleArtifactVersions) + s.mux.HandleFunc("POST /v1/artifacts/{id}/resolve", s.handleResolveArtifact) + s.mux.HandleFunc("GET /v1/download/{artifact_id}/{version}", s.handleDownload) + s.mux.HandleFunc("POST /v1/search", s.handleSearch) + s.mux.HandleFunc("POST /v1/publish", s.handlePublish) + s.mux.HandleFunc("DELETE /v1/admin/artifacts/{id}", s.handleDeleteArtifact) + s.mux.HandleFunc("GET /v1/admin/tokens", s.handlePublisherTokens) + s.mux.HandleFunc("POST /v1/admin/tokens", s.handleCreatePublisherToken) + s.mux.HandleFunc("POST /v1/admin/tokens/{id}/revoke", s.handleRevokePublisherToken) + s.mux.HandleFunc("POST /v1/artifacts/{id}/quarantine", s.handleQuarantine) + s.mux.HandleFunc("POST /v1/artifacts/{id}/unquarantine", s.handleUnquarantine) + s.mux.HandleFunc("GET /v1/admin/audit", s.handleAdminAudit) + s.mux.HandleFunc("GET /v1/admin/downloads", s.handleAdminDownloads) +} + +func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { + writeData(w, http.StatusOK, map[string]string{"status": "ok", "service": "anyclaw-registry"}, 0) +} + +func (s *Server) handleControlPlane(w http.ResponseWriter, r *http.Request) { + writeData(w, http.StatusOK, map[string]any{ + "protocol_version": defaultProtocolVersion, + "artifact_kinds": []ArtifactKind{ArtifactKindAgent, ArtifactKindSkill, ArtifactKindCLI}, + "routes": []string{ + "GET /v1/sources", + "GET /v1/artifacts", + "GET /v1/artifacts/{id}", + "GET /v1/artifacts/{id}/versions", + "POST /v1/artifacts/{id}/resolve", + "GET /v1/download/{artifact_id}/{version}", + "POST /v1/search", + "POST /v1/publish", + "DELETE /v1/admin/artifacts/{id}", + "GET /v1/admin/tokens", + "POST /v1/artifacts/{id}/quarantine", + "POST /v1/artifacts/{id}/unquarantine", + "POST /v1/admin/tokens/{id}/revoke", + "GET /v1/admin/audit", + "GET /v1/admin/downloads", + }, + "admin_auth": "bearer", + }, 0) +} + +func (s *Server) handleSources(w http.ResponseWriter, r *http.Request) { + writeData(w, http.StatusOK, map[string]any{ + "items": []map[string]any{ + { + "id": defaultRegistrySourceID, + "name": "AnyClaw Cloud", + "trust_level": "verified", + "kinds": []ArtifactKind{ArtifactKindAgent, ArtifactKindSkill, ArtifactKindCLI}, + }, + }, + }, 1) +} + +func (s *Server) handleListArtifacts(w http.ResponseWriter, r *http.Request) { + filter := filterFromQuery(r) + result, err := s.store.List(r.Context(), filter) + if err != nil { + writeError(w, http.StatusInternalServerError, "list_failed", "failed to list artifacts", err.Error()) + return + } + writeData(w, http.StatusOK, result, len(result.Items)) +} + +func (s *Server) handleArtifactDetail(w http.ResponseWriter, r *http.Request) { + artifact, err := s.store.Get(r.Context(), r.PathValue("id")) + if err != nil { + writeStoreError(w, err) + return + } + writeData(w, http.StatusOK, artifact, 1) +} + +func (s *Server) handleArtifactVersions(w http.ResponseWriter, r *http.Request) { + versions, err := s.store.Versions(r.Context(), r.PathValue("id")) + if err != nil { + writeStoreError(w, err) + return + } + writeData(w, http.StatusOK, VersionListResult{Items: versions, Total: len(versions)}, len(versions)) +} + +func (s *Server) handleResolveArtifact(w http.ResponseWriter, r *http.Request) { + var req ResolveRequest + if !s.decodeJSON(w, r, &req) { + return + } + artifact, version, err := s.store.Resolve(r.Context(), r.PathValue("id"), req) + if err != nil { + writeStoreError(w, err) + return + } + resolved := ResolvedArtifact{ + ArtifactID: artifact.ID, + Version: version.Version, + DownloadURL: absoluteURL(r, "/v1/download/"+url.PathEscape(artifact.ID)+"/"+url.PathEscape(version.Version)), + ChecksumSHA256: version.ChecksumSHA256, + Signature: version.Signature, + SizeBytes: version.SizeBytes, + ManifestURL: absoluteURL(r, "/v1/artifacts/"+url.PathEscape(artifact.ID)), + Compatibility: version.Compatibility, + Dependencies: artifact.Dependencies, + RiskLevel: artifact.RiskLevel, + TrustLevel: artifact.TrustLevel, + Permissions: version.Permissions, + Kind: artifact.Kind, + Name: artifact.Name, + } + writeData(w, http.StatusOK, resolved, 1) +} + +func (s *Server) handleDownload(w http.ResponseWriter, r *http.Request) { + artifactID := r.PathValue("artifact_id") + versionID := r.PathValue("version") + if _, err := s.store.Quarantine(r.Context(), artifactID); err == nil { + writeStoreError(w, ErrArtifactUnavailable) + return + } + version, err := s.store.Version(r.Context(), artifactID, versionID) + if err != nil { + writeStoreError(w, err) + return + } + if version.StorageKey == "" { + writeError(w, http.StatusNotFound, "package_not_found", "package not found", "") + return + } + file, err := s.storage.Open(version.StorageKey) + if err != nil { + writeError(w, http.StatusNotFound, "package_not_found", "package not found", err.Error()) + return + } + defer file.Close() + _ = s.store.RecordDownload(r.Context(), artifactID, versionID, r.RemoteAddr, r.UserAgent()) + + w.Header().Set("Content-Type", "application/zip") + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s-%s.zip"`, artifactID, versionID)) + w.Header().Set("X-Checksum-SHA256", version.ChecksumSHA256) + if version.Signature != "" { + w.Header().Set("X-Artifact-Signature", version.Signature) + } + w.Header().Set("X-Artifact-ID", artifactID) + w.Header().Set("X-Artifact-Version", versionID) + w.Header().Set("Content-Length", strconv.FormatInt(version.SizeBytes, 10)) + w.WriteHeader(http.StatusOK) + _, _ = io.Copy(w, file) +} + +func (s *Server) handleSearch(w http.ResponseWriter, r *http.Request) { + var filter SearchFilter + if !s.decodeJSON(w, r, &filter) { + return + } + result, err := s.store.List(r.Context(), filter) + if err != nil { + writeError(w, http.StatusInternalServerError, "search_failed", "failed to search artifacts", err.Error()) + return + } + writeData(w, http.StatusOK, result, len(result.Items)) +} + +func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request) { + publisherID, ok := s.authorizePublisher(r) + if !ok { + writeError(w, http.StatusUnauthorized, "unauthorized", "publisher token required", "") + return + } + var req PublishRequest + if !s.decodeJSON(w, r, &req) { + return + } + if strings.TrimSpace(req.Artifact.Publisher) == "" { + req.Artifact.Publisher = publisherID + } + if len(req.Versions) == 0 && strings.TrimSpace(req.Artifact.LatestVersion) != "" { + req.Versions = []ArtifactVersion{{ArtifactID: req.Artifact.ID, Version: req.Artifact.LatestVersion}} + } + for i := range req.Versions { + if req.Versions[i].ArtifactID == "" { + req.Versions[i].ArtifactID = req.Artifact.ID + } + info, err := s.storage.EnsurePackage(req.Artifact, req.Versions[i]) + if err != nil { + writeError(w, http.StatusBadRequest, "package_failed", "failed to prepare package", err.Error()) + return + } + req.Versions[i].StorageKey = info.StorageKey + if req.Versions[i].SizeBytes == 0 { + req.Versions[i].SizeBytes = info.SizeBytes + } + if req.Versions[i].ChecksumSHA256 == "" { + req.Versions[i].ChecksumSHA256 = info.ChecksumSHA256 + } + } + if err := s.store.UpsertArtifact(r.Context(), req.Artifact, req.Versions); err != nil { + writeError(w, http.StatusBadRequest, "publish_failed", "failed to publish artifact", err.Error()) + return + } + _ = s.store.AppendAudit(r.Context(), RegistryAuditEvent{Event: "artifact.published", Artifact: req.Artifact.ID, Detail: map[string]any{"publisher_id": publisherID, "versions": len(req.Versions)}}) + writeData(w, http.StatusOK, req.Artifact, 1) +} + +func (s *Server) handleCreatePublisherToken(w http.ResponseWriter, r *http.Request) { + if !s.authorizeAdmin(r) { + writeError(w, http.StatusUnauthorized, "unauthorized", "admin token required", "") + return + } + var req struct { + PublisherID string `json:"publisher_id"` + } + if !s.decodeJSON(w, r, &req) { + return + } + token, err := s.store.CreatePublisherToken(r.Context(), req.PublisherID) + if err != nil { + writeError(w, http.StatusBadRequest, "token_failed", "failed to create publisher token", err.Error()) + return + } + writeData(w, http.StatusOK, token, 1) +} + +func (s *Server) handlePublisherTokens(w http.ResponseWriter, r *http.Request) { + if !s.authorizeAdmin(r) { + writeError(w, http.StatusUnauthorized, "unauthorized", "admin token required", "") + return + } + result, err := s.store.PublisherTokens(r.Context(), parseInt(r.URL.Query().Get("limit"), 100)) + if err != nil { + writeError(w, http.StatusInternalServerError, "tokens_failed", "failed to list publisher tokens", err.Error()) + return + } + writeData(w, http.StatusOK, result, len(result.Items)) +} + +func (s *Server) handleRevokePublisherToken(w http.ResponseWriter, r *http.Request) { + if !s.authorizeAdmin(r) { + writeError(w, http.StatusUnauthorized, "unauthorized", "admin token required", "") + return + } + record, err := s.store.RevokePublisherToken(r.Context(), r.PathValue("id")) + if err != nil { + writeStoreError(w, err) + return + } + writeData(w, http.StatusOK, record, 1) +} + +func (s *Server) handleDeleteArtifact(w http.ResponseWriter, r *http.Request) { + if !s.authorizeAdmin(r) { + writeError(w, http.StatusUnauthorized, "unauthorized", "admin token required", "") + return + } + record, err := s.store.DeleteArtifact(r.Context(), r.PathValue("id")) + if err != nil { + writeStoreError(w, err) + return + } + writeData(w, http.StatusOK, record, 1) +} + +func (s *Server) handleQuarantine(w http.ResponseWriter, r *http.Request) { + if !s.authorizeAdmin(r) { + writeError(w, http.StatusUnauthorized, "unauthorized", "admin token required", "") + return + } + var req struct { + Reason string `json:"reason"` + } + if !s.decodeJSON(w, r, &req) { + return + } + record, err := s.store.SetQuarantine(r.Context(), r.PathValue("id"), req.Reason) + if err != nil { + writeError(w, http.StatusBadRequest, "quarantine_failed", "failed to quarantine artifact", err.Error()) + return + } + writeData(w, http.StatusOK, record, 1) +} + +func (s *Server) handleUnquarantine(w http.ResponseWriter, r *http.Request) { + if !s.authorizeAdmin(r) { + writeError(w, http.StatusUnauthorized, "unauthorized", "admin token required", "") + return + } + if err := s.store.ClearQuarantine(r.Context(), r.PathValue("id")); err != nil { + writeStoreError(w, err) + return + } + writeData(w, http.StatusOK, map[string]string{"status": "unquarantined"}, 1) +} + +func (s *Server) handleAdminAudit(w http.ResponseWriter, r *http.Request) { + if !s.authorizeAdmin(r) { + writeError(w, http.StatusUnauthorized, "unauthorized", "admin token required", "") + return + } + result, err := s.store.AuditEvents(r.Context(), parseInt(r.URL.Query().Get("limit"), 100)) + if err != nil { + writeError(w, http.StatusInternalServerError, "audit_failed", "failed to list audit events", err.Error()) + return + } + writeData(w, http.StatusOK, result, len(result.Items)) +} + +func (s *Server) handleAdminDownloads(w http.ResponseWriter, r *http.Request) { + if !s.authorizeAdmin(r) { + writeError(w, http.StatusUnauthorized, "unauthorized", "admin token required", "") + return + } + result, err := s.store.DownloadStats(r.Context(), parseInt(r.URL.Query().Get("limit"), 100)) + if err != nil { + writeError(w, http.StatusInternalServerError, "downloads_failed", "failed to list download stats", err.Error()) + return + } + writeData(w, http.StatusOK, result, len(result.Items)) +} + +func (s *Server) decodeJSON(w http.ResponseWriter, r *http.Request, dst any) bool { + r.Body = http.MaxBytesReader(w, r.Body, s.maxRequestBodyBytes) + dec := json.NewDecoder(r.Body) + if err := dec.Decode(dst); err != nil { + writeError(w, http.StatusBadRequest, "invalid_json", "invalid request body", err.Error()) + return false + } + if err := dec.Decode(&struct{}{}); err != io.EOF { + writeError(w, http.StatusBadRequest, "invalid_json", "invalid request body", "") + return false + } + return true +} + +func filterFromQuery(r *http.Request) SearchFilter { + q := r.URL.Query() + return SearchFilter{ + Kind: ArtifactKind(strings.TrimSpace(q.Get("kind"))), + Source: strings.TrimSpace(q.Get("source")), + Query: strings.TrimSpace(q.Get("q")), + Risk: strings.TrimSpace(q.Get("risk")), + Trust: strings.TrimSpace(q.Get("trust")), + Tag: strings.TrimSpace(q.Get("tag")), + Permission: strings.TrimSpace(q.Get("permission")), + Publisher: strings.TrimSpace(q.Get("publisher")), + OS: strings.TrimSpace(q.Get("os")), + Arch: strings.TrimSpace(q.Get("arch")), + Sort: strings.TrimSpace(q.Get("sort")), + Limit: parseInt(q.Get("limit"), 50), + Offset: parseInt(q.Get("offset"), 0), + } +} + +func parseInt(value string, fallback int) int { + if value == "" { + return fallback + } + n, err := strconv.Atoi(value) + if err != nil { + return fallback + } + return n +} + +func (s *Server) authorizeAdmin(r *http.Request) bool { + if s == nil || s.adminToken == "" { + return true + } + return bearerToken(r) == s.adminToken +} + +func (s *Server) authorizePublisher(r *http.Request) (string, bool) { + token := bearerToken(r) + if token == "" { + return "", false + } + publisherID, ok, err := s.store.ValidatePublisherToken(r.Context(), token) + if err != nil || !ok { + return "", false + } + return publisherID, true +} + +func bearerToken(r *http.Request) string { + auth := strings.TrimSpace(r.Header.Get("Authorization")) + const prefix = "Bearer " + if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) { + return "" + } + return strings.TrimSpace(auth[len(prefix):]) +} + +func writeStoreError(w http.ResponseWriter, err error) { + switch { + case errors.Is(err, ErrArtifactNotFound), errors.Is(err, ErrVersionNotFound): + writeError(w, http.StatusNotFound, "not_found", err.Error(), "") + case errors.Is(err, ErrNoCompatibleVersion): + writeError(w, http.StatusConflict, "no_compatible_version", err.Error(), "") + case errors.Is(err, ErrArtifactUnavailable): + writeError(w, http.StatusGone, "artifact_unavailable", err.Error(), "") + default: + writeError(w, http.StatusInternalServerError, "registry_error", "registry error", err.Error()) + } +} + +func writeData(w http.ResponseWriter, status int, data any, count int) { + writeJSON(w, status, map[string]any{ + "data": data, + "meta": ResponseMeta{ + ProtocolVersion: defaultProtocolVersion, + Count: count, + }, + }) +} + +func writeError(w http.ResponseWriter, status int, code, message, detail string) { + var payload ErrorResponse + payload.Error.Code = code + payload.Error.Message = message + payload.Error.Detail = detail + writeJSON(w, status, payload) +} + +func writeJSON(w http.ResponseWriter, status int, value any) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(value) +} + +func absoluteURL(r *http.Request, path string) string { + scheme := r.Header.Get("X-Forwarded-Proto") + if scheme == "" { + scheme = "http" + if r.TLS != nil { + scheme = "https" + } + } + host := r.Host + if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" { + host = forwardedHost + } + return scheme + "://" + host + path +} diff --git a/pkg/marketregistry/server_test.go b/pkg/marketregistry/server_test.go new file mode 100644 index 00000000..3e0cfe48 --- /dev/null +++ b/pkg/marketregistry/server_test.go @@ -0,0 +1,370 @@ +package marketregistry + +import ( + "archive/zip" + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestServerSeededCatalogRoutes(t *testing.T) { + server := newTestServer(t) + + var list struct { + Data ListResult `json:"data"` + } + doJSON(t, server, http.MethodGet, "/v1/artifacts", nil, http.StatusOK, &list) + if list.Data.Total != 3 { + t.Fatalf("expected 3 seeded artifacts, got %d", list.Data.Total) + } + + byKind := map[ArtifactKind]bool{} + for _, item := range list.Data.Items { + byKind[item.Kind] = true + } + for _, kind := range []ArtifactKind{ArtifactKindAgent, ArtifactKindSkill, ArtifactKindCLI} { + if !byKind[kind] { + t.Fatalf("expected seeded kind %s in catalog", kind) + } + } + + var detail struct { + Data Artifact `json:"data"` + } + doJSON(t, server, http.MethodGet, "/v1/artifacts/cloud.skill.release-notes", nil, http.StatusOK, &detail) + if detail.Data.ID != "cloud.skill.release-notes" || detail.Data.Kind != ArtifactKindSkill { + t.Fatalf("unexpected detail artifact: %#v", detail.Data) + } + + var versions struct { + Data VersionListResult `json:"data"` + } + doJSON(t, server, http.MethodGet, "/v1/artifacts/cloud.skill.release-notes/versions", nil, http.StatusOK, &versions) + if versions.Data.Total != 1 { + t.Fatalf("expected one version, got %d", versions.Data.Total) + } + if versions.Data.Items[0].ChecksumSHA256 == "" || versions.Data.Items[0].SizeBytes == 0 { + t.Fatalf("expected seeded version checksum and size: %#v", versions.Data.Items[0]) + } +} + +func TestServerResolveAndDownload(t *testing.T) { + server := newTestServer(t) + + var resolved struct { + Data ResolvedArtifact `json:"data"` + } + doJSON(t, server, http.MethodPost, "/v1/artifacts/cloud.cli.repo-health/resolve", strings.NewReader(`{}`), http.StatusOK, &resolved) + if resolved.Data.ArtifactID != "cloud.cli.repo-health" { + t.Fatalf("unexpected resolved artifact: %#v", resolved.Data) + } + if resolved.Data.DownloadURL == "" || resolved.Data.ChecksumSHA256 == "" || resolved.Data.SizeBytes == 0 { + t.Fatalf("resolve response missing download metadata: %#v", resolved.Data) + } + + req := httptest.NewRequest(http.MethodGet, "/v1/download/cloud.cli.repo-health/1.0.0", nil) + rec := httptest.NewRecorder() + server.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("download status = %d, body = %s", rec.Code, rec.Body.String()) + } + if got := rec.Header().Get("X-Checksum-SHA256"); got != resolved.Data.ChecksumSHA256 { + t.Fatalf("download checksum header = %q, want %q", got, resolved.Data.ChecksumSHA256) + } + sum := sha256.Sum256(rec.Body.Bytes()) + if got := hex.EncodeToString(sum[:]); got != resolved.Data.ChecksumSHA256 { + t.Fatalf("download body checksum = %q, want %q", got, resolved.Data.ChecksumSHA256) + } + assertZipContains(t, rec.Body.Bytes(), "anyclaw.artifact.json") +} + +func TestServerSearchAndErrors(t *testing.T) { + server := newTestServer(t) + + var search struct { + Data ListResult `json:"data"` + } + doJSON(t, server, http.MethodPost, "/v1/search", strings.NewReader(`{"kind":"skill","q":"release"}`), http.StatusOK, &search) + if search.Data.Total != 1 || search.Data.Items[0].ID != "cloud.skill.release-notes" { + t.Fatalf("unexpected search result: %#v", search.Data) + } + + var missing ErrorResponse + doJSON(t, server, http.MethodGet, "/v1/artifacts/does-not-exist", nil, http.StatusNotFound, &missing) + if missing.Error.Code != "not_found" { + t.Fatalf("unexpected error response: %#v", missing) + } +} + +func TestServerSearchFiltersCombine(t *testing.T) { + server := newTestServer(t) + + var list struct { + Data ListResult `json:"data"` + } + doJSON(t, server, http.MethodGet, "/v1/artifacts?kind=skill&q=release&risk=low&trust=verified&tag=writing&publisher=AnyClaw%20Labs&permission=fs.read&os=windows&arch=amd64&sort=name", nil, http.StatusOK, &list) + if list.Data.Total != 1 || list.Data.Items[0].ID != "cloud.skill.release-notes" { + t.Fatalf("unexpected combined filter result: %#v", list.Data) + } + + var empty struct { + Data ListResult `json:"data"` + } + doJSON(t, server, http.MethodGet, "/v1/artifacts?kind=skill&q=release&risk=high&trust=verified&tag=writing&publisher=AnyClaw%20Labs", nil, http.StatusOK, &empty) + if empty.Data.Total != 0 { + t.Fatalf("expected no results when one filter mismatches, got %#v", empty.Data) + } +} + +func TestServerAdminTokenPublishQuarantineAndStats(t *testing.T) { + server, err := NewServer(context.Background(), ServerConfig{ + DataDir: t.TempDir(), + Seed: true, + AdminToken: "admin-secret", + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = server.Close() }) + + var unauthorized ErrorResponse + doJSON(t, server, http.MethodGet, "/v1/admin/audit", nil, http.StatusUnauthorized, &unauthorized) + + var token struct { + Data PublisherToken `json:"data"` + } + doJSONWithAuth(t, server, http.MethodPost, "/v1/admin/tokens", strings.NewReader(`{"publisher_id":"AnyClaw Labs"}`), "admin-secret", http.StatusOK, &token) + if token.Data.Token == "" { + t.Fatalf("expected one-time publisher token: %#v", token.Data) + } + var tokens struct { + Data PublisherTokenList `json:"data"` + } + doJSONWithAuth(t, server, http.MethodGet, "/v1/admin/tokens", nil, "admin-secret", http.StatusOK, &tokens) + if tokens.Data.Total != 1 || tokens.Data.Items[0].Token != "" || tokens.Data.Items[0].ID != token.Data.ID { + t.Fatalf("unexpected publisher token list: %#v", tokens.Data) + } + + publishBody := `{"artifact":{"id":"cloud.skill.test-publish","kind":"skill","name":"Published Skill","summary":"Published from test","latest_version":"1.0.0","risk_level":"low","trust_level":"verified","permissions":["fs.read"],"compatibility":{"os":["windows"]},"tags":["publish"]},"versions":[{"version":"1.0.0","signature":"sig-test"}]}` + var published struct { + Data Artifact `json:"data"` + } + doJSONWithAuth(t, server, http.MethodPost, "/v1/publish", strings.NewReader(publishBody), token.Data.Token, http.StatusOK, &published) + if published.Data.ID != "cloud.skill.test-publish" { + t.Fatalf("unexpected published artifact: %#v", published.Data) + } + + var resolved struct { + Data ResolvedArtifact `json:"data"` + } + doJSON(t, server, http.MethodPost, "/v1/artifacts/cloud.skill.test-publish/resolve", strings.NewReader(`{}`), http.StatusOK, &resolved) + if resolved.Data.Signature != "sig-test" { + t.Fatalf("expected signature in resolve response: %#v", resolved.Data) + } + + req := httptest.NewRequest(http.MethodGet, "/v1/download/cloud.skill.test-publish/1.0.0", nil) + rec := httptest.NewRecorder() + server.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("download status = %d body=%s", rec.Code, rec.Body.String()) + } + if rec.Header().Get("X-Artifact-Signature") != "sig-test" { + t.Fatalf("expected signature header, got %q", rec.Header().Get("X-Artifact-Signature")) + } + + var downloads struct { + Data DownloadStatsResult `json:"data"` + } + doJSONWithAuth(t, server, http.MethodGet, "/v1/admin/downloads", nil, "admin-secret", http.StatusOK, &downloads) + if downloads.Data.Total == 0 { + t.Fatal("expected download stats") + } + + var quarantine struct { + Data QuarantineRecord `json:"data"` + } + doJSONWithAuth(t, server, http.MethodPost, "/v1/artifacts/cloud.skill.test-publish/quarantine", strings.NewReader(`{"reason":"bad package"}`), "admin-secret", http.StatusOK, &quarantine) + if quarantine.Data.Reason != "bad package" { + t.Fatalf("unexpected quarantine: %#v", quarantine.Data) + } + doJSON(t, server, http.MethodPost, "/v1/artifacts/cloud.skill.test-publish/resolve", strings.NewReader(`{}`), http.StatusGone, &ErrorResponse{}) + doJSONWithAuth(t, server, http.MethodPost, "/v1/artifacts/cloud.skill.test-publish/unquarantine", strings.NewReader(`{}`), "admin-secret", http.StatusOK, &struct{}{}) + doJSON(t, server, http.MethodPost, "/v1/artifacts/cloud.skill.test-publish/resolve", strings.NewReader(`{}`), http.StatusOK, &resolved) + + var audit struct { + Data RegistryAuditList `json:"data"` + } + doJSONWithAuth(t, server, http.MethodGet, "/v1/admin/audit", nil, "admin-secret", http.StatusOK, &audit) + if audit.Data.Total == 0 { + t.Fatal("expected audit events") + } +} + +func TestServerRequireAdminToken(t *testing.T) { + _, err := NewServer(context.Background(), ServerConfig{ + DataDir: t.TempDir(), + RequireAdminToken: true, + }) + if err == nil || !strings.Contains(err.Error(), "admin token is required") { + t.Fatalf("expected missing admin token error, got %v", err) + } + + server, err := NewServer(context.Background(), ServerConfig{ + DataDir: t.TempDir(), + AdminToken: "admin-secret", + RequireAdminToken: true, + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = server.Close() }) + + var unauthorized ErrorResponse + doJSON(t, server, http.MethodGet, "/v1/admin/audit", nil, http.StatusUnauthorized, &unauthorized) +} + +func TestServerRevokePublisherToken(t *testing.T) { + server, err := NewServer(context.Background(), ServerConfig{ + DataDir: t.TempDir(), + AdminToken: "admin-secret", + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = server.Close() }) + + var token struct { + Data PublisherToken `json:"data"` + } + doJSONWithAuth(t, server, http.MethodPost, "/v1/admin/tokens", strings.NewReader(`{"publisher_id":"AnyClaw Labs"}`), "admin-secret", http.StatusOK, &token) + + var revoked struct { + Data PublisherTokenRevocation `json:"data"` + } + doJSONWithAuth(t, server, http.MethodPost, "/v1/admin/tokens/"+token.Data.ID+"/revoke", nil, "admin-secret", http.StatusOK, &revoked) + if revoked.Data.ID != token.Data.ID || revoked.Data.RevokedAt == "" { + t.Fatalf("unexpected revocation: %#v", revoked.Data) + } + + publishBody := `{"artifact":{"id":"cloud.skill.revoked-token","kind":"skill","name":"Revoked Token Skill","summary":"Should not publish","latest_version":"1.0.0","risk_level":"low","trust_level":"verified","permissions":["fs.read"],"compatibility":{"os":["windows"]}},"versions":[{"version":"1.0.0"}]}` + var unauthorized ErrorResponse + doJSONWithAuth(t, server, http.MethodPost, "/v1/publish", strings.NewReader(publishBody), token.Data.Token, http.StatusUnauthorized, &unauthorized) +} + +func TestServerAdminDeleteArtifact(t *testing.T) { + server, err := NewServer(context.Background(), ServerConfig{ + DataDir: t.TempDir(), + Seed: true, + AdminToken: "admin-secret", + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = server.Close() }) + + var unauthorized ErrorResponse + doJSON(t, server, http.MethodDelete, "/v1/admin/artifacts/cloud.skill.release-notes", nil, http.StatusUnauthorized, &unauthorized) + + var deleted struct { + Data ArtifactDeletion `json:"data"` + } + doJSONWithAuth(t, server, http.MethodDelete, "/v1/admin/artifacts/cloud.skill.release-notes", nil, "admin-secret", http.StatusOK, &deleted) + if deleted.Data.ArtifactID != "cloud.skill.release-notes" || deleted.Data.DeletedAt == "" { + t.Fatalf("unexpected deletion response: %#v", deleted.Data) + } + + var missing ErrorResponse + doJSON(t, server, http.MethodGet, "/v1/artifacts/cloud.skill.release-notes", nil, http.StatusNotFound, &missing) + + var audit struct { + Data RegistryAuditList `json:"data"` + } + doJSONWithAuth(t, server, http.MethodGet, "/v1/admin/audit", nil, "admin-secret", http.StatusOK, &audit) + found := false + for _, event := range audit.Data.Items { + if event.Event == "artifact.deleted" && event.Artifact == "cloud.skill.release-notes" { + found = true + break + } + } + if !found { + t.Fatalf("expected artifact.deleted audit event, got %#v", audit.Data.Items) + } +} + +func newTestServer(t *testing.T) *Server { + t.Helper() + server, err := NewServer(context.Background(), ServerConfig{ + DataDir: t.TempDir(), + Seed: true, + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := server.Close(); err != nil { + t.Fatal(err) + } + }) + return server +} + +func doJSON(t *testing.T, handler http.Handler, method, path string, body io.Reader, wantStatus int, dst any) { + t.Helper() + req := httptest.NewRequest(method, path, body) + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + if rec.Code != wantStatus { + t.Fatalf("%s %s status = %d, want %d, body = %s", method, path, rec.Code, wantStatus, rec.Body.String()) + } + if dst != nil { + if err := json.NewDecoder(rec.Body).Decode(dst); err != nil { + t.Fatalf("decode response: %v", err) + } + } +} + +func doJSONWithAuth(t *testing.T, handler http.Handler, method, path string, body io.Reader, token string, wantStatus int, dst any) { + t.Helper() + req := httptest.NewRequest(method, path, body) + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + if rec.Code != wantStatus { + t.Fatalf("%s %s status = %d, want %d, body = %s", method, path, rec.Code, wantStatus, rec.Body.String()) + } + if dst != nil { + if err := json.NewDecoder(rec.Body).Decode(dst); err != nil { + t.Fatalf("decode response: %v", err) + } + } +} + +func assertZipContains(t *testing.T, data []byte, name string) { + t.Helper() + reader, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + if err != nil { + t.Fatalf("open zip: %v", err) + } + for _, file := range reader.File { + if file.Name == name { + return + } + } + t.Fatalf("zip did not contain %s", name) +} diff --git a/pkg/marketregistry/storage.go b/pkg/marketregistry/storage.go new file mode 100644 index 00000000..d4e4289f --- /dev/null +++ b/pkg/marketregistry/storage.go @@ -0,0 +1,169 @@ +package marketregistry + +import ( + "archive/zip" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +type PackageStorage interface { + EnsurePackage(artifact Artifact, version ArtifactVersion) (PackageInfo, error) + Open(storageKey string) (*os.File, error) +} + +type LocalStorage struct { + packagesDir string +} + +type PackageInfo struct { + StorageKey string + Path string + SizeBytes int64 + ChecksumSHA256 string +} + +func NewLocalStorage(dataDir string) (*LocalStorage, error) { + if strings.TrimSpace(dataDir) == "" { + dataDir = ".anyclaw-registry" + } + storage := &LocalStorage{packagesDir: filepath.Join(dataDir, "packages")} + if err := os.MkdirAll(storage.packagesDir, 0o755); err != nil { + return nil, err + } + return storage, nil +} + +func (s *LocalStorage) EnsurePackage(artifact Artifact, version ArtifactVersion) (PackageInfo, error) { + if s == nil { + return PackageInfo{}, fmt.Errorf("storage is not configured") + } + storageKey := filepath.ToSlash(filepath.Join(artifact.ID, version.Version, "artifact.zip")) + path := filepath.Join(s.packagesDir, artifact.ID, version.Version, "artifact.zip") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return PackageInfo{}, err + } + if _, err := os.Stat(path); errorsIsNotExist(err) { + if err := writePackage(path, artifact, version); err != nil { + return PackageInfo{}, err + } + } else if err != nil { + return PackageInfo{}, err + } + info, err := checksumFile(path) + if err != nil { + return PackageInfo{}, err + } + info.StorageKey = storageKey + info.Path = path + return info, nil +} + +func (s *LocalStorage) Open(storageKey string) (*os.File, error) { + if s == nil { + return nil, fmt.Errorf("storage is not configured") + } + clean := filepath.Clean(filepath.FromSlash(storageKey)) + path := filepath.Join(s.packagesDir, clean) + base, err := filepath.Abs(s.packagesDir) + if err != nil { + return nil, err + } + target, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if target != base && !strings.HasPrefix(target, base+string(os.PathSeparator)) { + return nil, fmt.Errorf("invalid storage key") + } + return os.Open(target) +} + +func writePackage(path string, artifact Artifact, version ArtifactVersion) error { + file, err := os.Create(path) + if err != nil { + return err + } + defer file.Close() + + zw := zip.NewWriter(file) + defer zw.Close() + + manifest := artifact + manifest.Version = version.Version + manifest.SizeBytes = 0 + manifest.ChecksumSHA256 = "" + if err := writeZipJSON(zw, "anyclaw.artifact.json", manifest); err != nil { + return err + } + if err := writeZipText(zw, "README.md", fmt.Sprintf("# %s\n\n%s\n", artifact.Name, artifact.Summary)); err != nil { + return err + } + + switch artifact.Kind { + case ArtifactKindAgent: + return writeZipJSON(zw, "agent/profile.json", map[string]any{ + "id": artifact.ID, + "name": artifact.Name, + "description": artifact.Summary, + }) + case ArtifactKindSkill: + return writeZipText(zw, "skill/SKILL.md", fmt.Sprintf("# %s\n\n%s\n", artifact.Name, artifact.DescriptionMD)) + case ArtifactKindCLI: + return writeZipJSON(zw, "cli/command.json", map[string]any{ + "id": artifact.ID, + "name": artifact.Name, + "command": artifact.ManifestSummary["command"], + }) + default: + return nil + } +} + +func writeZipJSON(zw *zip.Writer, name string, value any) error { + data, err := json.MarshalIndent(value, "", " ") + if err != nil { + return err + } + return writeZipBytes(zw, name, data) +} + +func writeZipText(zw *zip.Writer, name, text string) error { + return writeZipBytes(zw, name, []byte(text)) +} + +func writeZipBytes(zw *zip.Writer, name string, data []byte) error { + w, err := zw.Create(name) + if err != nil { + return err + } + _, err = w.Write(data) + return err +} + +func checksumFile(path string) (PackageInfo, error) { + file, err := os.Open(path) + if err != nil { + return PackageInfo{}, err + } + defer file.Close() + + hash := sha256.New() + size, err := io.Copy(hash, file) + if err != nil { + return PackageInfo{}, err + } + return PackageInfo{ + SizeBytes: size, + ChecksumSHA256: hex.EncodeToString(hash.Sum(nil)), + }, nil +} + +func errorsIsNotExist(err error) bool { + return err != nil && os.IsNotExist(err) +} diff --git a/pkg/marketregistry/store.go b/pkg/marketregistry/store.go new file mode 100644 index 00000000..736067d2 --- /dev/null +++ b/pkg/marketregistry/store.go @@ -0,0 +1,899 @@ +package marketregistry + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "database/sql" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "time" + + _ "modernc.org/sqlite" +) + +var ( + ErrArtifactNotFound = errors.New("artifact not found") + ErrVersionNotFound = errors.New("artifact version not found") + ErrNoCompatibleVersion = errors.New("no compatible artifact version found") + ErrInvalidArtifactKind = errors.New("invalid artifact kind") + ErrArtifactUnavailable = errors.New("artifact unavailable") + defaultProtocolVersion = "1.0" + defaultRegistrySourceID = "anyclaw-cloud" +) + +type Store struct { + db *sql.DB +} + +func OpenStore(ctx context.Context, dataDir string) (*Store, error) { + return OpenStoreWithConfig(ctx, StoreConfig{DataDir: dataDir}) +} + +func OpenStoreWithConfig(ctx context.Context, cfg StoreConfig) (*Store, error) { + dataDir := cfg.DataDir + if strings.TrimSpace(dataDir) == "" { + dataDir = ".anyclaw-registry" + } + if err := os.MkdirAll(dataDir, 0o755); err != nil { + return nil, err + } + if err := os.MkdirAll(filepath.Join(dataDir, "audit"), 0o755); err != nil { + return nil, err + } + driver := strings.TrimSpace(cfg.Driver) + if driver == "" { + driver = "sqlite" + } + dsn := strings.TrimSpace(cfg.DSN) + if dsn == "" { + if driver != "sqlite" { + return nil, fmt.Errorf("registry db dsn is required for driver %s", driver) + } + dsn = filepath.Join(dataDir, "registry.db") + } + db, err := sql.Open(driver, dsn) + if err != nil { + return nil, err + } + store := &Store{db: db} + if err := store.migrate(ctx); err != nil { + _ = db.Close() + return nil, err + } + return store, nil +} + +func (s *Store) Close() error { + if s == nil || s.db == nil { + return nil + } + return s.db.Close() +} + +func (s *Store) migrate(ctx context.Context) error { + statements := []string{ + `CREATE TABLE IF NOT EXISTS artifacts ( + id TEXT PRIMARY KEY, + kind TEXT NOT NULL, + name TEXT NOT NULL, + summary TEXT NOT NULL, + description_md TEXT NOT NULL DEFAULT '', + latest_version TEXT NOT NULL, + source TEXT NOT NULL, + publisher TEXT NOT NULL, + risk_level TEXT NOT NULL, + trust_level TEXT NOT NULL, + permissions_json TEXT NOT NULL, + compatibility_json TEXT NOT NULL, + dependencies_json TEXT NOT NULL, + icon_url TEXT NOT NULL DEFAULT '', + tags_json TEXT NOT NULL, + hit_signals_json TEXT NOT NULL, + score REAL NOT NULL DEFAULT 0, + updated_at TEXT NOT NULL, + manifest_summary_json TEXT NOT NULL + )`, + `CREATE TABLE IF NOT EXISTS artifact_versions ( + artifact_id TEXT NOT NULL, + version TEXT NOT NULL, + released_at TEXT NOT NULL, + changelog_md TEXT NOT NULL DEFAULT '', + compatibility_json TEXT NOT NULL, + permissions_json TEXT NOT NULL, + permissions_diff_json TEXT NOT NULL, + size_bytes INTEGER NOT NULL DEFAULT 0, + checksum_sha256 TEXT NOT NULL DEFAULT '', + signature TEXT NOT NULL DEFAULT '', + storage_key TEXT NOT NULL DEFAULT '', + deprecated INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY (artifact_id, version), + FOREIGN KEY (artifact_id) REFERENCES artifacts(id) ON DELETE CASCADE + )`, + `CREATE TABLE IF NOT EXISTS publishers ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + trust_level TEXT NOT NULL, + created_at TEXT NOT NULL + )`, + `CREATE TABLE IF NOT EXISTS tokens ( + id TEXT PRIMARY KEY, + publisher_id TEXT NOT NULL, + token_hash TEXT NOT NULL, + created_at TEXT NOT NULL, + revoked_at TEXT NOT NULL DEFAULT '' + )`, + `CREATE TABLE IF NOT EXISTS downloads ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + artifact_id TEXT NOT NULL, + version TEXT NOT NULL, + remote_addr TEXT NOT NULL, + user_agent TEXT NOT NULL, + created_at TEXT NOT NULL + )`, + `CREATE TABLE IF NOT EXISTS quarantine ( + artifact_id TEXT PRIMARY KEY, + reason TEXT NOT NULL, + created_at TEXT NOT NULL + )`, + `CREATE TABLE IF NOT EXISTS audit_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + event_type TEXT NOT NULL, + artifact_id TEXT NOT NULL DEFAULT '', + version TEXT NOT NULL DEFAULT '', + detail_json TEXT NOT NULL, + created_at TEXT NOT NULL + )`, + } + for _, stmt := range statements { + if _, err := s.db.ExecContext(ctx, stmt); err != nil { + return err + } + } + _, _ = s.db.ExecContext(ctx, `ALTER TABLE artifact_versions ADD COLUMN signature TEXT NOT NULL DEFAULT ''`) + return nil +} + +func (s *Store) CountArtifacts(ctx context.Context) (int, error) { + var count int + err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM artifacts`).Scan(&count) + return count, err +} + +func (s *Store) DeleteArtifact(ctx context.Context, artifactID string) (ArtifactDeletion, error) { + artifactID = strings.TrimSpace(artifactID) + if artifactID == "" { + return ArtifactDeletion{}, fmt.Errorf("artifact_id is required") + } + if _, err := s.Get(ctx, artifactID); err != nil { + return ArtifactDeletion{}, err + } + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return ArtifactDeletion{}, err + } + defer func() { + if err != nil { + _ = tx.Rollback() + } + }() + if _, err = tx.ExecContext(ctx, `DELETE FROM artifact_versions WHERE artifact_id = ?`, artifactID); err != nil { + return ArtifactDeletion{}, err + } + if _, err = tx.ExecContext(ctx, `DELETE FROM quarantine WHERE artifact_id = ?`, artifactID); err != nil { + return ArtifactDeletion{}, err + } + result, err := tx.ExecContext(ctx, `DELETE FROM artifacts WHERE id = ?`, artifactID) + if err != nil { + return ArtifactDeletion{}, err + } + rows, err := result.RowsAffected() + if err != nil { + return ArtifactDeletion{}, err + } + if rows == 0 { + return ArtifactDeletion{}, ErrArtifactNotFound + } + record := ArtifactDeletion{ + ArtifactID: artifactID, + DeletedAt: time.Now().UTC().Format(time.RFC3339), + } + detail, err := encodeJSON(map[string]any{}) + if err != nil { + return ArtifactDeletion{}, err + } + if _, err = tx.ExecContext(ctx, `INSERT INTO audit_events (event_type, artifact_id, version, detail_json, created_at) VALUES (?, ?, ?, ?, ?)`, + "artifact.deleted", artifactID, "", detail, record.DeletedAt); err != nil { + return ArtifactDeletion{}, err + } + if err = tx.Commit(); err != nil { + return ArtifactDeletion{}, err + } + return record, nil +} + +func (s *Store) UpsertArtifact(ctx context.Context, artifact Artifact, versions []ArtifactVersion) error { + if err := validateArtifact(artifact); err != nil { + return err + } + if artifact.Source == "" { + artifact.Source = defaultRegistrySourceID + } + if artifact.UpdatedAt == "" { + artifact.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + } + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { + if err != nil { + _ = tx.Rollback() + } + }() + + permissions, err := encodeJSON(artifact.Permissions) + if err != nil { + return err + } + compatibility, err := encodeJSON(artifact.Compatibility) + if err != nil { + return err + } + dependencies, err := encodeJSON(artifact.Dependencies) + if err != nil { + return err + } + tags, err := encodeJSON(artifact.Tags) + if err != nil { + return err + } + hitSignals, err := encodeJSON(artifact.HitSignals) + if err != nil { + return err + } + manifestSummary, err := encodeJSON(artifact.ManifestSummary) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, `INSERT INTO artifacts ( + id, kind, name, summary, description_md, latest_version, source, publisher, + risk_level, trust_level, permissions_json, compatibility_json, + dependencies_json, icon_url, tags_json, hit_signals_json, score, + updated_at, manifest_summary_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + kind = excluded.kind, + name = excluded.name, + summary = excluded.summary, + description_md = excluded.description_md, + latest_version = excluded.latest_version, + source = excluded.source, + publisher = excluded.publisher, + risk_level = excluded.risk_level, + trust_level = excluded.trust_level, + permissions_json = excluded.permissions_json, + compatibility_json = excluded.compatibility_json, + dependencies_json = excluded.dependencies_json, + icon_url = excluded.icon_url, + tags_json = excluded.tags_json, + hit_signals_json = excluded.hit_signals_json, + score = excluded.score, + updated_at = excluded.updated_at, + manifest_summary_json = excluded.manifest_summary_json`, + artifact.ID, artifact.Kind, artifact.Name, artifact.Summary, artifact.DescriptionMD, + artifact.LatestVersion, artifact.Source, artifact.Publisher, artifact.RiskLevel, + artifact.TrustLevel, permissions, compatibility, dependencies, artifact.IconURL, + tags, hitSignals, artifact.Score, artifact.UpdatedAt, manifestSummary) + if err != nil { + return err + } + + for _, version := range versions { + if version.ArtifactID == "" { + version.ArtifactID = artifact.ID + } + if version.ReleasedAt == "" { + version.ReleasedAt = artifact.UpdatedAt + } + if version.Compatibility.AnyClawMin == "" && len(version.Compatibility.OS) == 0 && len(version.Compatibility.Arch) == 0 { + version.Compatibility = artifact.Compatibility + } + if len(version.Permissions) == 0 { + version.Permissions = artifact.Permissions + } + compatibility, err := encodeJSON(version.Compatibility) + if err != nil { + return err + } + permissions, err := encodeJSON(version.Permissions) + if err != nil { + return err + } + permissionsDiff, err := encodeJSON(version.PermissionsDiff) + if err != nil { + return err + } + deprecated := 0 + if version.Deprecated { + deprecated = 1 + } + _, err = tx.ExecContext(ctx, `INSERT INTO artifact_versions ( + artifact_id, version, released_at, changelog_md, compatibility_json, + permissions_json, permissions_diff_json, size_bytes, checksum_sha256, + signature, storage_key, deprecated + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(artifact_id, version) DO UPDATE SET + released_at = excluded.released_at, + changelog_md = excluded.changelog_md, + compatibility_json = excluded.compatibility_json, + permissions_json = excluded.permissions_json, + permissions_diff_json = excluded.permissions_diff_json, + size_bytes = excluded.size_bytes, + checksum_sha256 = excluded.checksum_sha256, + signature = excluded.signature, + storage_key = excluded.storage_key, + deprecated = excluded.deprecated`, + version.ArtifactID, version.Version, version.ReleasedAt, version.ChangelogMD, + compatibility, permissions, permissionsDiff, version.SizeBytes, + version.ChecksumSHA256, version.Signature, version.StorageKey, deprecated) + if err != nil { + return err + } + } + err = tx.Commit() + return err +} + +func (s *Store) List(ctx context.Context, filter SearchFilter) (ListResult, error) { + if filter.Limit <= 0 { + filter.Limit = 50 + } + if filter.Offset < 0 { + filter.Offset = 0 + } + + rows, err := s.db.QueryContext(ctx, `SELECT + id, kind, name, summary, description_md, latest_version, source, publisher, + risk_level, trust_level, permissions_json, compatibility_json, dependencies_json, + icon_url, tags_json, hit_signals_json, score, updated_at, manifest_summary_json + FROM artifacts`) + if err != nil { + return ListResult{}, err + } + defer rows.Close() + + var all []Artifact + for rows.Next() { + artifact, err := scanArtifact(rows) + if err != nil { + return ListResult{}, err + } + if matchesArtifact(artifact, filter) { + all = append(all, artifact) + } + } + if err := rows.Err(); err != nil { + return ListResult{}, err + } + + sortArtifacts(all, filter.Sort) + + total := len(all) + if filter.Offset >= len(all) { + all = nil + } else { + all = all[filter.Offset:] + if len(all) > filter.Limit { + all = all[:filter.Limit] + } + } + return ListResult{ + Items: all, + Total: total, + Limit: filter.Limit, + Offset: filter.Offset, + }, nil +} + +func (s *Store) Get(ctx context.Context, id string) (Artifact, error) { + row := s.db.QueryRowContext(ctx, `SELECT + id, kind, name, summary, description_md, latest_version, source, publisher, + risk_level, trust_level, permissions_json, compatibility_json, dependencies_json, + icon_url, tags_json, hit_signals_json, score, updated_at, manifest_summary_json + FROM artifacts WHERE id = ?`, id) + artifact, err := scanArtifact(row) + if errors.Is(err, sql.ErrNoRows) { + return Artifact{}, ErrArtifactNotFound + } + return artifact, err +} + +func (s *Store) Versions(ctx context.Context, artifactID string) ([]ArtifactVersion, error) { + if _, err := s.Get(ctx, artifactID); err != nil { + return nil, err + } + rows, err := s.db.QueryContext(ctx, `SELECT + artifact_id, version, released_at, changelog_md, compatibility_json, + permissions_json, permissions_diff_json, size_bytes, checksum_sha256, + signature, storage_key, deprecated + FROM artifact_versions WHERE artifact_id = ? ORDER BY released_at DESC, version DESC`, artifactID) + if err != nil { + return nil, err + } + defer rows.Close() + + var versions []ArtifactVersion + for rows.Next() { + version, err := scanVersion(rows) + if err != nil { + return nil, err + } + versions = append(versions, version) + } + return versions, rows.Err() +} + +func (s *Store) Version(ctx context.Context, artifactID, version string) (ArtifactVersion, error) { + row := s.db.QueryRowContext(ctx, `SELECT + artifact_id, version, released_at, changelog_md, compatibility_json, + permissions_json, permissions_diff_json, size_bytes, checksum_sha256, + signature, storage_key, deprecated + FROM artifact_versions WHERE artifact_id = ? AND version = ?`, artifactID, version) + item, err := scanVersion(row) + if errors.Is(err, sql.ErrNoRows) { + return ArtifactVersion{}, ErrVersionNotFound + } + return item, err +} + +func (s *Store) Resolve(ctx context.Context, artifactID string, req ResolveRequest) (Artifact, ArtifactVersion, error) { + if _, err := s.Quarantine(ctx, artifactID); err == nil { + return Artifact{}, ArtifactVersion{}, ErrArtifactUnavailable + } + artifact, err := s.Get(ctx, artifactID) + if err != nil { + return Artifact{}, ArtifactVersion{}, err + } + versions, err := s.Versions(ctx, artifactID) + if err != nil { + return Artifact{}, ArtifactVersion{}, err + } + + want := strings.TrimSpace(req.VersionConstraint) + foundRequestedVersion := false + for _, version := range versions { + if want != "" && version.Version != want { + continue + } + if want != "" { + foundRequestedVersion = true + } + if !compatibleWithClient(version.Compatibility, req) { + continue + } + return artifact, version, nil + } + if want != "" && !foundRequestedVersion { + return Artifact{}, ArtifactVersion{}, ErrVersionNotFound + } + return Artifact{}, ArtifactVersion{}, ErrNoCompatibleVersion +} + +func (s *Store) RecordDownload(ctx context.Context, artifactID, version, remoteAddr, userAgent string) error { + _, err := s.db.ExecContext(ctx, `INSERT INTO downloads ( + artifact_id, version, remote_addr, user_agent, created_at + ) VALUES (?, ?, ?, ?, ?)`, artifactID, version, remoteAddr, userAgent, time.Now().UTC().Format(time.RFC3339)) + return err +} + +func (s *Store) DownloadStats(ctx context.Context, limit int) (DownloadStatsResult, error) { + if limit <= 0 { + limit = 100 + } + rows, err := s.db.QueryContext(ctx, `SELECT artifact_id, version, COUNT(*), MAX(created_at) + FROM downloads GROUP BY artifact_id, version ORDER BY COUNT(*) DESC, MAX(created_at) DESC LIMIT ?`, limit) + if err != nil { + return DownloadStatsResult{}, err + } + defer rows.Close() + var items []DownloadStat + for rows.Next() { + var item DownloadStat + if err := rows.Scan(&item.ArtifactID, &item.Version, &item.Count, &item.LastAt); err != nil { + return DownloadStatsResult{}, err + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return DownloadStatsResult{}, err + } + return DownloadStatsResult{Items: items, Total: len(items)}, nil +} + +func (s *Store) Quarantine(ctx context.Context, artifactID string) (QuarantineRecord, error) { + var record QuarantineRecord + err := s.db.QueryRowContext(ctx, `SELECT artifact_id, reason, created_at FROM quarantine WHERE artifact_id = ?`, strings.TrimSpace(artifactID)). + Scan(&record.ArtifactID, &record.Reason, &record.CreatedAt) + if errors.Is(err, sql.ErrNoRows) { + return QuarantineRecord{}, ErrArtifactNotFound + } + return record, err +} + +func (s *Store) SetQuarantine(ctx context.Context, artifactID, reason string) (QuarantineRecord, error) { + record := QuarantineRecord{ + ArtifactID: strings.TrimSpace(artifactID), + Reason: strings.TrimSpace(reason), + CreatedAt: time.Now().UTC().Format(time.RFC3339), + } + if record.ArtifactID == "" { + return QuarantineRecord{}, fmt.Errorf("artifact_id is required") + } + if record.Reason == "" { + record.Reason = "quarantined by administrator" + } + if _, err := s.db.ExecContext(ctx, `INSERT INTO quarantine (artifact_id, reason, created_at) + VALUES (?, ?, ?) ON CONFLICT(artifact_id) DO UPDATE SET reason = excluded.reason, created_at = excluded.created_at`, + record.ArtifactID, record.Reason, record.CreatedAt); err != nil { + return QuarantineRecord{}, err + } + _ = s.AppendAudit(ctx, RegistryAuditEvent{Event: "artifact.quarantined", Artifact: record.ArtifactID, Detail: map[string]any{"reason": record.Reason}}) + return record, nil +} + +func (s *Store) ClearQuarantine(ctx context.Context, artifactID string) error { + artifactID = strings.TrimSpace(artifactID) + if artifactID == "" { + return fmt.Errorf("artifact_id is required") + } + result, err := s.db.ExecContext(ctx, `DELETE FROM quarantine WHERE artifact_id = ?`, artifactID) + if err != nil { + return err + } + if rows, _ := result.RowsAffected(); rows == 0 { + return ErrArtifactNotFound + } + _ = s.AppendAudit(ctx, RegistryAuditEvent{Event: "artifact.unquarantined", Artifact: artifactID}) + return nil +} + +func (s *Store) CreatePublisherToken(ctx context.Context, publisherID string) (PublisherToken, error) { + publisherID = strings.TrimSpace(publisherID) + if publisherID == "" { + return PublisherToken{}, fmt.Errorf("publisher_id is required") + } + raw := make([]byte, 24) + if _, err := rand.Read(raw); err != nil { + return PublisherToken{}, err + } + token := "acr_" + hex.EncodeToString(raw) + now := time.Now().UTC().Format(time.RFC3339) + hash := sha256.Sum256([]byte(token)) + item := PublisherToken{ + ID: "token-" + time.Now().UTC().Format("20060102150405.000000000"), + PublisherID: publisherID, + Token: token, + CreatedAt: now, + } + _, err := s.db.ExecContext(ctx, `INSERT INTO tokens (id, publisher_id, token_hash, created_at) VALUES (?, ?, ?, ?)`, + item.ID, item.PublisherID, hex.EncodeToString(hash[:]), item.CreatedAt) + if err != nil { + return PublisherToken{}, err + } + _ = s.AppendAudit(ctx, RegistryAuditEvent{Event: "publisher_token.created", Detail: map[string]any{"publisher_id": publisherID, "token_id": item.ID}}) + return item, nil +} + +func (s *Store) ValidatePublisherToken(ctx context.Context, token string) (string, bool, error) { + hash := sha256.Sum256([]byte(strings.TrimSpace(token))) + var publisherID string + err := s.db.QueryRowContext(ctx, `SELECT publisher_id FROM tokens WHERE token_hash = ? AND revoked_at = ''`, hex.EncodeToString(hash[:])).Scan(&publisherID) + if errors.Is(err, sql.ErrNoRows) { + return "", false, nil + } + if err != nil { + return "", false, err + } + return publisherID, true, nil +} + +func (s *Store) PublisherTokens(ctx context.Context, limit int) (PublisherTokenList, error) { + if limit <= 0 { + limit = 100 + } + rows, err := s.db.QueryContext(ctx, `SELECT id, publisher_id, created_at, revoked_at FROM tokens ORDER BY created_at DESC LIMIT ?`, limit) + if err != nil { + return PublisherTokenList{}, err + } + defer rows.Close() + var items []PublisherToken + for rows.Next() { + var item PublisherToken + if err := rows.Scan(&item.ID, &item.PublisherID, &item.CreatedAt, &item.RevokedAt); err != nil { + return PublisherTokenList{}, err + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return PublisherTokenList{}, err + } + return PublisherTokenList{Items: items, Total: len(items)}, nil +} + +func (s *Store) RevokePublisherToken(ctx context.Context, tokenID string) (PublisherTokenRevocation, error) { + tokenID = strings.TrimSpace(tokenID) + if tokenID == "" { + return PublisherTokenRevocation{}, fmt.Errorf("token id is required") + } + now := time.Now().UTC().Format(time.RFC3339) + result, err := s.db.ExecContext(ctx, `UPDATE tokens SET revoked_at = ? WHERE id = ? AND revoked_at = ''`, now, tokenID) + if err != nil { + return PublisherTokenRevocation{}, err + } + affected, err := result.RowsAffected() + if err != nil { + return PublisherTokenRevocation{}, err + } + if affected == 0 { + return PublisherTokenRevocation{}, ErrArtifactNotFound + } + var publisherID string + if err := s.db.QueryRowContext(ctx, `SELECT publisher_id FROM tokens WHERE id = ?`, tokenID).Scan(&publisherID); err != nil { + return PublisherTokenRevocation{}, err + } + record := PublisherTokenRevocation{ID: tokenID, PublisherID: publisherID, RevokedAt: now} + _ = s.AppendAudit(ctx, RegistryAuditEvent{Event: "publisher_token.revoked", Detail: map[string]any{"publisher_id": publisherID, "token_id": tokenID}}) + return record, nil +} + +func (s *Store) AppendAudit(ctx context.Context, event RegistryAuditEvent) error { + if event.Created == "" { + event.Created = time.Now().UTC().Format(time.RFC3339) + } + detail, err := encodeJSON(event.Detail) + if err != nil { + return err + } + _, err = s.db.ExecContext(ctx, `INSERT INTO audit_events (event_type, artifact_id, version, detail_json, created_at) VALUES (?, ?, ?, ?, ?)`, + event.Event, event.Artifact, event.Version, detail, event.Created) + return err +} + +func (s *Store) AuditEvents(ctx context.Context, limit int) (RegistryAuditList, error) { + if limit <= 0 { + limit = 100 + } + rows, err := s.db.QueryContext(ctx, `SELECT id, event_type, artifact_id, version, detail_json, created_at FROM audit_events ORDER BY id DESC LIMIT ?`, limit) + if err != nil { + return RegistryAuditList{}, err + } + defer rows.Close() + var items []RegistryAuditEvent + for rows.Next() { + var item RegistryAuditEvent + var detail string + if err := rows.Scan(&item.ID, &item.Event, &item.Artifact, &item.Version, &detail, &item.Created); err != nil { + return RegistryAuditList{}, err + } + if err := decodeJSON(detail, &item.Detail); err != nil { + return RegistryAuditList{}, err + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return RegistryAuditList{}, err + } + return RegistryAuditList{Items: items, Total: len(items)}, nil +} + +type scanner interface { + Scan(dest ...any) error +} + +func scanArtifact(row scanner) (Artifact, error) { + var artifact Artifact + var permissions, compatibility, dependencies, tags, hitSignals, manifestSummary string + err := row.Scan( + &artifact.ID, &artifact.Kind, &artifact.Name, &artifact.Summary, + &artifact.DescriptionMD, &artifact.LatestVersion, &artifact.Source, + &artifact.Publisher, &artifact.RiskLevel, &artifact.TrustLevel, + &permissions, &compatibility, &dependencies, &artifact.IconURL, + &tags, &hitSignals, &artifact.Score, &artifact.UpdatedAt, + &manifestSummary, + ) + if err != nil { + return Artifact{}, err + } + artifact.Version = artifact.LatestVersion + if err := decodeJSON(permissions, &artifact.Permissions); err != nil { + return Artifact{}, err + } + if err := decodeJSON(compatibility, &artifact.Compatibility); err != nil { + return Artifact{}, err + } + if err := decodeJSON(dependencies, &artifact.Dependencies); err != nil { + return Artifact{}, err + } + if err := decodeJSON(tags, &artifact.Tags); err != nil { + return Artifact{}, err + } + if err := decodeJSON(hitSignals, &artifact.HitSignals); err != nil { + return Artifact{}, err + } + if err := decodeJSON(manifestSummary, &artifact.ManifestSummary); err != nil { + return Artifact{}, err + } + return artifact, nil +} + +func scanVersion(row scanner) (ArtifactVersion, error) { + var version ArtifactVersion + var compatibility, permissions, permissionsDiff string + var deprecated int + err := row.Scan( + &version.ArtifactID, &version.Version, &version.ReleasedAt, + &version.ChangelogMD, &compatibility, &permissions, &permissionsDiff, + &version.SizeBytes, &version.ChecksumSHA256, &version.Signature, + &version.StorageKey, &deprecated, + ) + if err != nil { + return ArtifactVersion{}, err + } + version.Deprecated = deprecated != 0 + if err := decodeJSON(compatibility, &version.Compatibility); err != nil { + return ArtifactVersion{}, err + } + if err := decodeJSON(permissions, &version.Permissions); err != nil { + return ArtifactVersion{}, err + } + if err := decodeJSON(permissionsDiff, &version.PermissionsDiff); err != nil { + return ArtifactVersion{}, err + } + return version, nil +} + +func matchesArtifact(artifact Artifact, filter SearchFilter) bool { + if filter.Kind != "" && artifact.Kind != filter.Kind { + return false + } + if filter.Source != "" && !strings.EqualFold(artifact.Source, filter.Source) { + return false + } + if filter.Risk != "" && !strings.EqualFold(artifact.RiskLevel, filter.Risk) { + return false + } + if filter.Trust != "" && !strings.EqualFold(artifact.TrustLevel, filter.Trust) { + return false + } + if filter.Tag != "" && !containsFold(artifact.Tags, filter.Tag) { + return false + } + if filter.Permission != "" && !containsFold(artifact.Permissions, filter.Permission) { + return false + } + if filter.Publisher != "" && !strings.Contains(strings.ToLower(artifact.Publisher), strings.ToLower(strings.TrimSpace(filter.Publisher))) { + return false + } + if filter.OS != "" && len(artifact.Compatibility.OS) > 0 && !containsFold(artifact.Compatibility.OS, filter.OS) { + return false + } + if filter.Arch != "" && len(artifact.Compatibility.Arch) > 0 && !containsFold(artifact.Compatibility.Arch, filter.Arch) { + return false + } + query := strings.ToLower(strings.TrimSpace(filter.Query)) + if query == "" { + return true + } + fields := []string{ + artifact.ID, + artifact.Name, + artifact.Summary, + artifact.DescriptionMD, + artifact.Publisher, + strings.Join(artifact.Tags, " "), + strings.Join(artifact.HitSignals, " "), + } + return strings.Contains(strings.ToLower(strings.Join(fields, " ")), query) +} + +func sortArtifacts(items []Artifact, mode string) { + sortMode := strings.ToLower(strings.TrimSpace(mode)) + sort.SliceStable(items, func(i, j int) bool { + switch sortMode { + case "updated", "updated_desc": + if items[i].UpdatedAt == items[j].UpdatedAt { + return fallbackArtifactLess(items[i], items[j]) + } + return items[i].UpdatedAt > items[j].UpdatedAt + case "name", "name_asc": + left := strings.ToLower(items[i].Name) + right := strings.ToLower(items[j].Name) + if left == right { + return fallbackArtifactLess(items[i], items[j]) + } + return left < right + default: + if items[i].Score == items[j].Score { + if items[i].UpdatedAt == items[j].UpdatedAt { + return fallbackArtifactLess(items[i], items[j]) + } + return items[i].UpdatedAt > items[j].UpdatedAt + } + return items[i].Score > items[j].Score + } + }) +} + +func fallbackArtifactLess(left, right Artifact) bool { + if left.Kind != right.Kind { + return left.Kind < right.Kind + } + return strings.ToLower(left.ID) < strings.ToLower(right.ID) +} + +func containsFold(values []string, want string) bool { + want = strings.ToLower(strings.TrimSpace(want)) + if want == "" { + return true + } + for _, value := range values { + if strings.EqualFold(strings.TrimSpace(value), want) { + return true + } + } + return false +} + +func compatibleWithClient(compatibility Compatibility, req ResolveRequest) bool { + if req.ClientEnv.OS != "" && len(compatibility.OS) > 0 && !containsString(compatibility.OS, req.ClientEnv.OS) { + return false + } + if req.ClientEnv.Arch != "" && len(compatibility.Arch) > 0 && !containsString(compatibility.Arch, req.ClientEnv.Arch) { + return false + } + return true +} + +func validateArtifact(artifact Artifact) error { + if artifact.ID == "" || artifact.Name == "" || artifact.LatestVersion == "" { + return fmt.Errorf("artifact id, name, and latest_version are required") + } + switch artifact.Kind { + case ArtifactKindAgent, ArtifactKindSkill, ArtifactKindCLI: + return nil + default: + return ErrInvalidArtifactKind + } +} + +func encodeJSON(v any) (string, error) { + data, err := json.Marshal(v) + if err != nil { + return "", err + } + return string(data), nil +} + +func decodeJSON(data string, dst any) error { + if strings.TrimSpace(data) == "" { + data = "null" + } + return json.Unmarshal([]byte(data), dst) +} + +func containsString(items []string, item string) bool { + for _, current := range items { + if current == item { + return true + } + } + return false +} diff --git a/pkg/marketregistry/types.go b/pkg/marketregistry/types.go new file mode 100644 index 00000000..4626b939 --- /dev/null +++ b/pkg/marketregistry/types.go @@ -0,0 +1,194 @@ +package marketregistry + +type ArtifactKind string + +const ( + ArtifactKindAgent ArtifactKind = "agent" + ArtifactKindSkill ArtifactKind = "skill" + ArtifactKindCLI ArtifactKind = "cli" +) + +type Artifact struct { + ID string `json:"id"` + Kind ArtifactKind `json:"kind"` + Name string `json:"name"` + Summary string `json:"summary"` + DescriptionMD string `json:"description_md,omitempty"` + Version string `json:"version"` + LatestVersion string `json:"latest_version"` + Source string `json:"source"` + Publisher string `json:"publisher"` + RiskLevel string `json:"risk_level"` + TrustLevel string `json:"trust_level"` + Permissions []string `json:"permissions"` + Compatibility Compatibility `json:"compatibility"` + Dependencies []Dependency `json:"dependencies,omitempty"` + SizeBytes int64 `json:"size_bytes,omitempty"` + ChecksumSHA256 string `json:"checksum_sha256,omitempty"` + IconURL string `json:"icon_url,omitempty"` + Tags []string `json:"tags,omitempty"` + HitSignals []string `json:"hit_signals,omitempty"` + Score float64 `json:"score,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` + ManifestSummary map[string]string `json:"manifest_summary,omitempty"` +} + +type Compatibility struct { + AnyClawMin string `json:"anyclaw_min,omitempty"` + OS []string `json:"os,omitempty"` + Arch []string `json:"arch,omitempty"` +} + +type Dependency struct { + ID string `json:"id"` + VersionRange string `json:"version_range,omitempty"` +} + +type ArtifactVersion struct { + ArtifactID string `json:"artifact_id,omitempty"` + Version string `json:"version"` + ReleasedAt string `json:"released_at,omitempty"` + ChangelogMD string `json:"changelog_md,omitempty"` + Compatibility Compatibility `json:"compatibility,omitempty"` + Permissions []string `json:"permissions,omitempty"` + PermissionsDiff []string `json:"permissions_diff,omitempty"` + SizeBytes int64 `json:"size_bytes,omitempty"` + ChecksumSHA256 string `json:"checksum_sha256,omitempty"` + Signature string `json:"signature,omitempty"` + StorageKey string `json:"-"` + Deprecated bool `json:"deprecated,omitempty"` +} + +type ResolveRequest struct { + VersionConstraint string `json:"version_constraint,omitempty"` + ClientEnv struct { + AnyClawVersion string `json:"anyclaw_version,omitempty"` + OS string `json:"os,omitempty"` + Arch string `json:"arch,omitempty"` + } `json:"client_env,omitempty"` +} + +type ResolvedArtifact struct { + ArtifactID string `json:"artifact_id"` + Version string `json:"version"` + DownloadURL string `json:"download_url"` + ChecksumSHA256 string `json:"checksum_sha256"` + Signature string `json:"signature,omitempty"` + SizeBytes int64 `json:"size_bytes"` + ManifestURL string `json:"manifest_url,omitempty"` + Compatibility Compatibility `json:"compatibility"` + Dependencies []Dependency `json:"dependencies,omitempty"` + RiskLevel string `json:"risk_level"` + TrustLevel string `json:"trust_level"` + Permissions []string `json:"permissions"` + Kind ArtifactKind `json:"kind"` + Name string `json:"name"` +} + +type SearchFilter struct { + Kind ArtifactKind `json:"kind,omitempty"` + Source string `json:"source,omitempty"` + Query string `json:"q,omitempty"` + Risk string `json:"risk,omitempty"` + Trust string `json:"trust,omitempty"` + Tag string `json:"tag,omitempty"` + Permission string `json:"permission,omitempty"` + Publisher string `json:"publisher,omitempty"` + OS string `json:"os,omitempty"` + Arch string `json:"arch,omitempty"` + Sort string `json:"sort,omitempty"` + Limit int `json:"limit,omitempty"` + Offset int `json:"offset,omitempty"` +} + +type ListResult struct { + Items []Artifact `json:"items"` + Total int `json:"total"` + Limit int `json:"limit"` + Offset int `json:"offset"` +} + +type VersionListResult struct { + Items []ArtifactVersion `json:"items"` + Total int `json:"total"` +} + +type ResponseMeta struct { + ProtocolVersion string `json:"protocol_version,omitempty"` + Count int `json:"count,omitempty"` +} + +type ErrorResponse struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + Detail string `json:"detail,omitempty"` + } `json:"error"` +} + +type StoreConfig struct { + DataDir string + Driver string + DSN string +} + +type RegistryAuditEvent struct { + ID int64 `json:"id,omitempty"` + Event string `json:"event_type"` + Artifact string `json:"artifact_id,omitempty"` + Version string `json:"version,omitempty"` + Detail map[string]any `json:"detail,omitempty"` + Created string `json:"created_at,omitempty"` +} + +type RegistryAuditList struct { + Items []RegistryAuditEvent `json:"items"` + Total int `json:"total"` +} + +type DownloadStat struct { + ArtifactID string `json:"artifact_id"` + Version string `json:"version"` + Count int `json:"count"` + LastAt string `json:"last_at,omitempty"` +} + +type DownloadStatsResult struct { + Items []DownloadStat `json:"items"` + Total int `json:"total"` +} + +type QuarantineRecord struct { + ArtifactID string `json:"artifact_id"` + Reason string `json:"reason"` + CreatedAt string `json:"created_at"` +} + +type ArtifactDeletion struct { + ArtifactID string `json:"artifact_id"` + DeletedAt string `json:"deleted_at"` +} + +type PublisherToken struct { + ID string `json:"id"` + PublisherID string `json:"publisher_id"` + Token string `json:"token,omitempty"` + CreatedAt string `json:"created_at"` + RevokedAt string `json:"revoked_at,omitempty"` +} + +type PublisherTokenRevocation struct { + ID string `json:"id"` + PublisherID string `json:"publisher_id"` + RevokedAt string `json:"revoked_at"` +} + +type PublisherTokenList struct { + Items []PublisherToken `json:"items"` + Total int `json:"total"` +} + +type PublishRequest struct { + Artifact Artifact `json:"artifact"` + Versions []ArtifactVersion `json:"versions,omitempty"` +} From 17b3c9eeee400e30274b2cbd43aeb2f7cd7af272 Mon Sep 17 00:00:00 2001 From: h1177h <2928932863@qq.com> Date: Fri, 8 May 2026 20:55:39 +0800 Subject: [PATCH 2/2] fix: harden marketplace registry auth and storage --- cmd/anyclaw-registry/main.go | 31 +++++----------------- cmd/anyclaw-registry/main_test.go | 26 ------------------- pkg/marketregistry/server.go | 8 ++++-- pkg/marketregistry/server_test.go | 41 ++++++++++++++++++++++++------ pkg/marketregistry/storage.go | 23 +++++++++++++++-- pkg/marketregistry/storage_test.go | 33 ++++++++++++++++++++++++ 6 files changed, 100 insertions(+), 62 deletions(-) create mode 100644 pkg/marketregistry/storage_test.go diff --git a/cmd/anyclaw-registry/main.go b/cmd/anyclaw-registry/main.go index cc252a91..3ab117bb 100644 --- a/cmd/anyclaw-registry/main.go +++ b/cmd/anyclaw-registry/main.go @@ -42,7 +42,6 @@ func serve(args []string) error { dbDriver := fs.String("db-driver", "sqlite", "database/sql driver name") dbDSN := fs.String("db-dsn", "", "database DSN; defaults to /registry.db for sqlite") adminToken := fs.String("admin-token", os.Getenv("ANYCLAW_REGISTRY_ADMIN_TOKEN"), "admin bearer token; defaults to ANYCLAW_REGISTRY_ADMIN_TOKEN") - requireAdminToken := fs.Bool("require-admin-token", envBool("ANYCLAW_REGISTRY_REQUIRE_ADMIN_TOKEN", true), "fail startup when admin token is empty") seed := fs.Bool("seed", true, "seed fixture artifacts when the registry is empty") if err := fs.Parse(args); err != nil { return err @@ -52,13 +51,12 @@ func serve(args []string) error { defer stop() server, err := marketregistry.NewServer(ctx, marketregistry.ServerConfig{ - Addr: *addr, - DataDir: *dataDir, - DBDriver: *dbDriver, - DBDSN: *dbDSN, - AdminToken: *adminToken, - RequireAdminToken: *requireAdminToken, - Seed: *seed, + Addr: *addr, + DataDir: *dataDir, + DBDriver: *dbDriver, + DBDSN: *dbDSN, + AdminToken: *adminToken, + Seed: *seed, }) if err != nil { return err @@ -74,20 +72,5 @@ func serve(args []string) error { } func printUsage() { - fmt.Println("Usage: anyclaw-registry serve [--addr :8791] [--data-dir .anyclaw-registry] [--db-driver sqlite] [--db-dsn path-or-url] [--admin-token token] [--require-admin-token=true] [--seed=true]") -} - -func envBool(name string, fallback bool) bool { - value := os.Getenv(name) - if value == "" { - return fallback - } - switch value { - case "1", "true", "TRUE", "True", "yes", "YES", "on", "ON": - return true - case "0", "false", "FALSE", "False", "no", "NO", "off", "OFF": - return false - default: - return fallback - } + fmt.Println("Usage: anyclaw-registry serve [--addr :8791] [--data-dir .anyclaw-registry] [--db-driver sqlite] [--db-dsn path-or-url] [--admin-token token] [--seed=true]") } diff --git a/cmd/anyclaw-registry/main_test.go b/cmd/anyclaw-registry/main_test.go index 04d2f51a..43d25dd3 100644 --- a/cmd/anyclaw-registry/main_test.go +++ b/cmd/anyclaw-registry/main_test.go @@ -10,7 +10,6 @@ import ( func TestRunDefaultsToServeAndRequiresAdminToken(t *testing.T) { t.Setenv("ANYCLAW_REGISTRY_ADMIN_TOKEN", "") - t.Setenv("ANYCLAW_REGISTRY_REQUIRE_ADMIN_TOKEN", "") err := run([]string{"serve", "--data-dir", t.TempDir(), "--seed=false"}) if err == nil || !strings.Contains(err.Error(), "admin token is required") { @@ -18,16 +17,6 @@ func TestRunDefaultsToServeAndRequiresAdminToken(t *testing.T) { } } -func TestRunAllowsExplicitLocalRegistryWithoutAdminToken(t *testing.T) { - t.Setenv("ANYCLAW_REGISTRY_ADMIN_TOKEN", "") - t.Setenv("ANYCLAW_REGISTRY_REQUIRE_ADMIN_TOKEN", "") - - err := run([]string{"serve", "--addr", "127.0.0.1:bad", "--data-dir", t.TempDir(), "--seed=false", "--require-admin-token=false"}) - if err == nil || strings.Contains(err.Error(), "admin token is required") { - t.Fatalf("expected listen error after explicit insecure opt-out, got %v", err) - } -} - func TestRunHelpAndUnknownCommand(t *testing.T) { out := captureStdout(t, func() { if err := run([]string{"help"}); err != nil { @@ -42,21 +31,6 @@ func TestRunHelpAndUnknownCommand(t *testing.T) { } } -func TestEnvBoolParsesKnownValuesAndFallback(t *testing.T) { - t.Setenv("BOOL_VALUE", "yes") - if !envBool("BOOL_VALUE", false) { - t.Fatal("expected yes to parse true") - } - t.Setenv("BOOL_VALUE", "OFF") - if envBool("BOOL_VALUE", true) { - t.Fatal("expected OFF to parse false") - } - t.Setenv("BOOL_VALUE", "not-bool") - if !envBool("BOOL_VALUE", true) { - t.Fatal("expected invalid value to use fallback") - } -} - func captureStdout(t *testing.T, fn func()) string { t.Helper() original := os.Stdout diff --git a/pkg/marketregistry/server.go b/pkg/marketregistry/server.go index 79f57532..a69d6fb3 100644 --- a/pkg/marketregistry/server.go +++ b/pkg/marketregistry/server.go @@ -34,7 +34,7 @@ type Server struct { } func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { - if cfg.RequireAdminToken && strings.TrimSpace(cfg.AdminToken) == "" { + if strings.TrimSpace(cfg.AdminToken) == "" { return nil, fmt.Errorf("admin token is required") } store, err := OpenStoreWithConfig(ctx, StoreConfig{DataDir: cfg.DataDir, Driver: cfg.DBDriver, DSN: cfg.DBDSN}) @@ -282,6 +282,10 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request) { if strings.TrimSpace(req.Artifact.Publisher) == "" { req.Artifact.Publisher = publisherID } + if req.Artifact.Publisher != publisherID { + writeError(w, http.StatusForbidden, "publisher_mismatch", "publisher token cannot publish for another publisher", "") + return + } if len(req.Versions) == 0 && strings.TrimSpace(req.Artifact.LatestVersion) != "" { req.Versions = []ArtifactVersion{{ArtifactID: req.Artifact.ID, Version: req.Artifact.LatestVersion}} } @@ -471,7 +475,7 @@ func parseInt(value string, fallback int) int { func (s *Server) authorizeAdmin(r *http.Request) bool { if s == nil || s.adminToken == "" { - return true + return false } return bearerToken(r) == s.adminToken } diff --git a/pkg/marketregistry/server_test.go b/pkg/marketregistry/server_test.go index 3e0cfe48..29e00d27 100644 --- a/pkg/marketregistry/server_test.go +++ b/pkg/marketregistry/server_test.go @@ -160,6 +160,9 @@ func TestServerAdminTokenPublishQuarantineAndStats(t *testing.T) { if published.Data.ID != "cloud.skill.test-publish" { t.Fatalf("unexpected published artifact: %#v", published.Data) } + if published.Data.Publisher != "AnyClaw Labs" { + t.Fatalf("publisher = %q, want token publisher", published.Data.Publisher) + } var resolved struct { Data ResolvedArtifact `json:"data"` @@ -207,19 +210,17 @@ func TestServerAdminTokenPublishQuarantineAndStats(t *testing.T) { } } -func TestServerRequireAdminToken(t *testing.T) { +func TestServerRequiresAdminToken(t *testing.T) { _, err := NewServer(context.Background(), ServerConfig{ - DataDir: t.TempDir(), - RequireAdminToken: true, + DataDir: t.TempDir(), }) if err == nil || !strings.Contains(err.Error(), "admin token is required") { t.Fatalf("expected missing admin token error, got %v", err) } server, err := NewServer(context.Background(), ServerConfig{ - DataDir: t.TempDir(), - AdminToken: "admin-secret", - RequireAdminToken: true, + DataDir: t.TempDir(), + AdminToken: "admin-secret", }) if err != nil { t.Fatal(err) @@ -230,6 +231,29 @@ func TestServerRequireAdminToken(t *testing.T) { doJSON(t, server, http.MethodGet, "/v1/admin/audit", nil, http.StatusUnauthorized, &unauthorized) } +func TestServerPublishRejectsPublisherMismatch(t *testing.T) { + server, err := NewServer(context.Background(), ServerConfig{ + DataDir: t.TempDir(), + AdminToken: "admin-secret", + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = server.Close() }) + + var token struct { + Data PublisherToken `json:"data"` + } + doJSONWithAuth(t, server, http.MethodPost, "/v1/admin/tokens", strings.NewReader(`{"publisher_id":"publisher-a"}`), "admin-secret", http.StatusOK, &token) + + publishBody := `{"artifact":{"id":"cloud.skill.publisher-mismatch","kind":"skill","name":"Mismatch","summary":"Should fail","publisher":"publisher-b","latest_version":"1.0.0","risk_level":"low","trust_level":"verified"},"versions":[{"version":"1.0.0"}]}` + var forbidden ErrorResponse + doJSONWithAuth(t, server, http.MethodPost, "/v1/publish", strings.NewReader(publishBody), token.Data.Token, http.StatusForbidden, &forbidden) + if forbidden.Error.Code != "publisher_mismatch" { + t.Fatalf("unexpected error: %#v", forbidden) + } +} + func TestServerRevokePublisherToken(t *testing.T) { server, err := NewServer(context.Background(), ServerConfig{ DataDir: t.TempDir(), @@ -302,8 +326,9 @@ func TestServerAdminDeleteArtifact(t *testing.T) { func newTestServer(t *testing.T) *Server { t.Helper() server, err := NewServer(context.Background(), ServerConfig{ - DataDir: t.TempDir(), - Seed: true, + DataDir: t.TempDir(), + Seed: true, + AdminToken: "test-admin-token", }) if err != nil { t.Fatal(err) diff --git a/pkg/marketregistry/storage.go b/pkg/marketregistry/storage.go index d4e4289f..4906e676 100644 --- a/pkg/marketregistry/storage.go +++ b/pkg/marketregistry/storage.go @@ -43,8 +43,16 @@ func (s *LocalStorage) EnsurePackage(artifact Artifact, version ArtifactVersion) if s == nil { return PackageInfo{}, fmt.Errorf("storage is not configured") } - storageKey := filepath.ToSlash(filepath.Join(artifact.ID, version.Version, "artifact.zip")) - path := filepath.Join(s.packagesDir, artifact.ID, version.Version, "artifact.zip") + artifactSegment, err := safeStorageSegment(artifact.ID) + if err != nil { + return PackageInfo{}, fmt.Errorf("artifact id: %w", err) + } + versionSegment, err := safeStorageSegment(version.Version) + if err != nil { + return PackageInfo{}, fmt.Errorf("version: %w", err) + } + storageKey := filepath.ToSlash(filepath.Join(artifactSegment, versionSegment, "artifact.zip")) + path := filepath.Join(s.packagesDir, artifactSegment, versionSegment, "artifact.zip") if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { return PackageInfo{}, err } @@ -64,6 +72,17 @@ func (s *LocalStorage) EnsurePackage(artifact Artifact, version ArtifactVersion) return info, nil } +func safeStorageSegment(value string) (string, error) { + value = strings.TrimSpace(value) + if value == "" { + return "", fmt.Errorf("value is required") + } + if value == "." || value == ".." || strings.ContainsAny(value, `/\`) { + return "", fmt.Errorf("value must be a single path segment") + } + return value, nil +} + func (s *LocalStorage) Open(storageKey string) (*os.File, error) { if s == nil { return nil, fmt.Errorf("storage is not configured") diff --git a/pkg/marketregistry/storage_test.go b/pkg/marketregistry/storage_test.go new file mode 100644 index 00000000..996f99fa --- /dev/null +++ b/pkg/marketregistry/storage_test.go @@ -0,0 +1,33 @@ +package marketregistry + +import ( + "strings" + "testing" +) + +func TestLocalStorageRejectsUnsafePackageSegments(t *testing.T) { + storage, err := NewLocalStorage(t.TempDir()) + if err != nil { + t.Fatal(err) + } + + _, err = storage.EnsurePackage(Artifact{ + ID: "../escape", + Kind: ArtifactKindSkill, + Name: "Escape", + Summary: "unsafe", + }, ArtifactVersion{Version: "1.0.0"}) + if err == nil || !strings.Contains(err.Error(), "artifact id") { + t.Fatalf("expected unsafe artifact id error, got %v", err) + } + + _, err = storage.EnsurePackage(Artifact{ + ID: "cloud.skill.safe", + Kind: ArtifactKindSkill, + Name: "Safe", + Summary: "safe", + }, ArtifactVersion{Version: "../escape"}) + if err == nil || !strings.Contains(err.Error(), "version") { + t.Fatalf("expected unsafe version error, got %v", err) + } +}