diff --git a/.gitignore b/.gitignore index ca11e22..983946f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ bin/ vendor/ .DS_Store +coverage.out diff --git a/client/ym/client.go b/client/ym/client.go index 1e6b845..daa43f7 100644 --- a/client/ym/client.go +++ b/client/ym/client.go @@ -19,6 +19,7 @@ import ( const defaultBaseURL = "https://botapi.messenger.yandex.net" // HttpDoer is an interface for executing HTTP requests, typically satisfied by *http.Client. +// Implementations must be safe for concurrent use if the Client is shared across goroutines. type HttpDoer interface { Do(*http.Request) (*http.Response, error) } @@ -57,6 +58,7 @@ func NewClientWithHTTP(cfg Config, httpClient HttpDoer) *Client { // DoRequest sends an HTTP request to the Yandex Messenger API with automatic // retry and rate-limit handling according to the client configuration. +// On success (2xx), the caller is responsible for closing the returned response body. func (c *Client) DoRequest(ctx context.Context, method, path string, body any) (*http.Response, error) { var payload []byte var err error @@ -146,6 +148,84 @@ func (c *Client) DoRequest(ctx context.Context, method, path string, body any) ( return nil, fmt.Errorf("yandex-messenger/client: retries exhausted for %s %s", method, path) } +// DoMultipartRequest sends an HTTP request with a pre-built body and content type, +// applying the same retry and rate-limit logic as DoRequest. +// On success (2xx), the caller is responsible for closing the returned response body. +func (c *Client) DoMultipartRequest(ctx context.Context, method, path, contentType string, body []byte) (*http.Response, error) { + url := strings.TrimRight(c.cfg.BaseURL, "/") + path + retryCfg := c.cfg.ErrorHandling.RetryStrategy + rateCfg := c.cfg.ErrorHandling.RateLimitHandling + + attempts := retryCfg.MaxAttempts + if attempts < 1 { + attempts = 1 + } + backoff := retryCfg.InitialBackoff + if backoff <= 0 { + backoff = 500 * time.Millisecond + } + + for attempt := 1; attempt <= attempts; attempt++ { + req, reqErr := http.NewRequestWithContext(ctx, method, url, bytes.NewReader(body)) + if reqErr != nil { + return nil, fmt.Errorf("yandex-messenger/client: build request: %w", reqErr) + } + if c.cfg.Token != "" { + req.Header.Set("Authorization", "OAuth "+c.cfg.Token) + } + req.Header.Set("Content-Type", contentType) + + resp, doErr := c.http.Do(req) + if doErr != nil { + if ctxErr := ctx.Err(); ctxErr != nil { + return nil, fmt.Errorf("yandex-messenger/client: %w for %s %s", ctxErr, method, path) + } + var netErr net.Error + if errors.As(doErr, &netErr) && retryCfg.RetryNetwork && attempt < attempts { + time.Sleep(backoff) + backoff = NextBackoff(backoff, retryCfg.MaxBackoff) + + continue + } + + return nil, fmt.Errorf("yandex-messenger/client: %w for %s %s", doErr, method, path) + } + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return resp, nil + } + + apiErr, parseErr := c.newAPIError(method, path, resp) + if parseErr != nil { + return nil, parseErr + } + + if apiErr.Kind == ymerrors.KindRateLimited && attempt < attempts { + sleep := rateCfg.DefaultBackoff + if rateCfg.UseRetryAfter && apiErr.RetryAfter > 0 { + sleep = apiErr.RetryAfter + } + if sleep <= 0 { + sleep = rateCfg.DefaultBackoff + } + time.Sleep(sleep) + + continue + } + + if ShouldRetryHTTP(apiErr.HTTPStatus, retryCfg.RetryHTTP) && attempt < attempts { + time.Sleep(backoff) + backoff = NextBackoff(backoff, retryCfg.MaxBackoff) + + continue + } + + return nil, apiErr + } + + return nil, fmt.Errorf("yandex-messenger/client: retries exhausted for %s %s", method, path) +} + func (c *Client) newAPIError(method, path string, resp *http.Response) (*ymerrors.APIError, error) { defer resp.Body.Close() @@ -171,6 +251,12 @@ func (c *Client) newAPIError(method, path string, resp *http.Response) (*ymerror kind = ymerrors.KindInvalidToken case http.StatusBadRequest: kind = ymerrors.KindBadRequest + case http.StatusNotFound: + kind = ymerrors.KindNotFound + case http.StatusConflict: + kind = ymerrors.KindConflict + case http.StatusRequestEntityTooLarge: + kind = ymerrors.KindPayloadTooLarge default: if resp.StatusCode >= 500 { kind = ymerrors.KindNetwork @@ -282,6 +368,11 @@ func parseRetryAfter(value string) time.Duration { if secs, err := strconv.Atoi(value); err == nil && secs > 0 { return time.Duration(secs) * time.Second } + if t, err := time.Parse(time.RFC1123, value); err == nil { + if d := time.Until(t); d > 0 { + return d + } + } return 0 } diff --git a/client/ym/client_error_test.go b/client/ym/client_error_test.go index ad3e68a..79dce33 100644 --- a/client/ym/client_error_test.go +++ b/client/ym/client_error_test.go @@ -70,6 +70,57 @@ func TestNewAPIErrorInvalidToken(t *testing.T) { } } +func TestNewAPIErrorNotFound(t *testing.T) { + client := &Client{} + resp := &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(bytes.NewBufferString(`{"ok":false,"description":"not found"}`)), + Header: http.Header{}, + } + + apiErr, err := client.newAPIError("GET", "/path", resp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if apiErr.Kind != ymerrors.KindNotFound { + t.Fatalf("expected KindNotFound, got %v", apiErr.Kind) + } +} + +func TestNewAPIErrorConflict(t *testing.T) { + client := &Client{} + resp := &http.Response{ + StatusCode: http.StatusConflict, + Body: io.NopCloser(bytes.NewBufferString(`{"ok":false,"description":"conflict"}`)), + Header: http.Header{}, + } + + apiErr, err := client.newAPIError("POST", "/path", resp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if apiErr.Kind != ymerrors.KindConflict { + t.Fatalf("expected KindConflict, got %v", apiErr.Kind) + } +} + +func TestNewAPIErrorPayloadTooLarge(t *testing.T) { + client := &Client{} + resp := &http.Response{ + StatusCode: http.StatusRequestEntityTooLarge, + Body: io.NopCloser(bytes.NewBufferString(`{"ok":false,"description":"too large"}`)), + Header: http.Header{}, + } + + apiErr, err := client.newAPIError("POST", "/path", resp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if apiErr.Kind != ymerrors.KindPayloadTooLarge { + t.Fatalf("expected KindPayloadTooLarge, got %v", apiErr.Kind) + } +} + func TestNewAPIErrorNoBody(t *testing.T) { client := &Client{} resp := &http.Response{ diff --git a/client/ym/client_multipart_test.go b/client/ym/client_multipart_test.go new file mode 100644 index 0000000..bb37839 --- /dev/null +++ b/client/ym/client_multipart_test.go @@ -0,0 +1,125 @@ +package ym + +import ( + "bytes" + "context" + "io" + "net/http" + "testing" + + "github.com/rekurt/ymsdk/client/ym/ymerrors" + "github.com/rekurt/ymsdk/internal/testutil" +) + +func TestDoMultipartRequestSuccess(t *testing.T) { + doer := &testutil.FakeDoer{ + Responses: []*http.Response{ + { + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{"ok":true,"message_id":1}`)), + Header: http.Header{}, + }, + }, + } + client := NewClientWithHTTP(Config{ + BaseURL: "http://example.com", + Token: "tok", + ErrorHandling: ymerrors.ErrorHandlingConfig{ + RetryStrategy: ymerrors.RetryStrategy{MaxAttempts: 1}, + }, + }, doer) + + resp, err := client.DoMultipartRequest(context.Background(), http.MethodPost, "/upload", "multipart/form-data; boundary=abc", []byte("payload")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + req := doer.Requests[0] + if req.Header.Get("Content-Type") != "multipart/form-data; boundary=abc" { + t.Fatalf("unexpected content type: %s", req.Header.Get("Content-Type")) + } + if req.Header.Get("Authorization") != "OAuth tok" { + t.Fatalf("unexpected auth: %s", req.Header.Get("Authorization")) + } +} + +func TestDoMultipartRequestRetryOn500(t *testing.T) { + doer := &testutil.FakeDoer{ + Responses: []*http.Response{ + { + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(bytes.NewBufferString(`{"ok":false}`)), + Header: http.Header{}, + }, + { + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{"ok":true}`)), + Header: http.Header{}, + }, + }, + } + client := NewClientWithHTTP(Config{ + BaseURL: "http://example.com", + ErrorHandling: ymerrors.ErrorHandlingConfig{ + RetryStrategy: ymerrors.RetryStrategy{ + MaxAttempts: 2, + InitialBackoff: 1, + }, + }, + }, doer) + + resp, err := client.DoMultipartRequest(context.Background(), http.MethodPost, "/upload", "multipart/form-data", []byte("data")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if doer.CallCount() != 2 { + t.Fatalf("expected 2 attempts, got %d", doer.CallCount()) + } +} + +func TestDoMultipartRequestRateLimitFallback(t *testing.T) { + doer := &testutil.FakeDoer{ + Responses: []*http.Response{ + { + StatusCode: http.StatusTooManyRequests, + Body: io.NopCloser(bytes.NewBufferString(`{"ok":false}`)), + Header: http.Header{"Retry-After": []string{"0"}}, + }, + { + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{"ok":true}`)), + Header: http.Header{}, + }, + }, + } + client := NewClientWithHTTP(Config{ + BaseURL: "http://example.com", + ErrorHandling: ymerrors.ErrorHandlingConfig{ + RetryStrategy: ymerrors.RetryStrategy{ + MaxAttempts: 2, + InitialBackoff: 1, + }, + RateLimitHandling: ymerrors.RateLimitHandling{ + UseRetryAfter: true, + DefaultBackoff: 1, + }, + }, + }, doer) + + resp, err := client.DoMultipartRequest(context.Background(), http.MethodPost, "/upload", "multipart/form-data", []byte("data")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if doer.CallCount() != 2 { + t.Fatalf("expected 2 attempts (rate limit + retry), got %d", doer.CallCount()) + } +} diff --git a/client/ym/files/service.go b/client/ym/files/service.go index 42c65e4..cdc2602 100644 --- a/client/ym/files/service.go +++ b/client/ym/files/service.go @@ -4,20 +4,21 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" "mime/multipart" - "net" "net/http" "net/textproto" "strings" - "time" "github.com/rekurt/ymsdk/client/ym" "github.com/rekurt/ymsdk/client/ym/ymerrors" ) +func sanitizeFilename(name string) string { + return strings.NewReplacer(`"`, `\"`, `\`, `\\`).Replace(name) +} + // Service provides low-level file sending with raw byte payloads. // For higher-level file operations with io.Reader, use the messages service. type Service struct { @@ -70,7 +71,7 @@ func (s *Service) send( return nil, fmt.Errorf("yandex-messenger/files: build multipart: %w", err) } - resp, err := s.doMultipartWithRetry(ctx, http.MethodPost, "/bot/v1/messages/sendFile", boundaryContentType, body) + resp, err := s.client.DoMultipartRequest(ctx, http.MethodPost, "/bot/v1/messages/sendFile", boundaryContentType, body) if err != nil { return nil, err } @@ -113,7 +114,7 @@ func buildMultipartBody( } headers := textproto.MIMEHeader{} - headers.Set("Content-Disposition", fmt.Sprintf(`form-data; name="document"; filename="%s"`, filename)) + headers.Set("Content-Disposition", fmt.Sprintf(`form-data; name="document"; filename="%s"`, sanitizeFilename(filename))) if ct != "" { headers.Set("Content-Type", ct) } @@ -131,82 +132,3 @@ func buildMultipartBody( return buf.Bytes(), writer.FormDataContentType(), nil } - -func (s *Service) doMultipartWithRetry( - ctx context.Context, method, path, contentType string, body []byte, -) (*http.Response, error) { - cfg := s.client.Config() - retryCfg := cfg.ErrorHandling.RetryStrategy - rateCfg := cfg.ErrorHandling.RateLimitHandling - - attempts := retryCfg.MaxAttempts - if attempts < 1 { - attempts = 1 - } - backoff := retryCfg.InitialBackoff - if backoff <= 0 { - backoff = 500 * time.Millisecond - } - - url := strings.TrimRight(cfg.BaseURL, "/") + path - - for attempt := 1; attempt <= attempts; attempt++ { - req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewReader(body)) - if err != nil { - return nil, fmt.Errorf("yandex-messenger/files: build request: %w", err) - } - if token := cfg.Token; token != "" { - req.Header.Set("Authorization", "OAuth "+token) - } - req.Header.Set("Content-Type", contentType) - - resp, doErr := s.client.HTTPDoer().Do(req) - if doErr != nil { - if ctxErr := ctx.Err(); ctxErr != nil { - return nil, fmt.Errorf("yandex-messenger/files: %w for %s %s", ctxErr, method, path) - } - var netErr net.Error - if errors.As(doErr, &netErr) && retryCfg.RetryNetwork && attempt < attempts { - time.Sleep(backoff) - backoff = ym.NextBackoff(backoff, retryCfg.MaxBackoff) - - continue - } - - return nil, fmt.Errorf("yandex-messenger/files: %w for %s %s", doErr, method, path) - } - - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - return resp, nil - } - - apiErr, parseErr := s.client.NewAPIError(method, path, resp) - if parseErr != nil { - return nil, parseErr - } - - if apiErr.Kind == ymerrors.KindRateLimited && attempt < attempts { - sleep := rateCfg.DefaultBackoff - if rateCfg.UseRetryAfter && apiErr.RetryAfter > 0 { - sleep = apiErr.RetryAfter - } - if sleep <= 0 { - sleep = rateCfg.DefaultBackoff - } - time.Sleep(sleep) - - continue - } - - if ym.ShouldRetryHTTP(apiErr.HTTPStatus, retryCfg.RetryHTTP) && attempt < attempts { - time.Sleep(backoff) - backoff = ym.NextBackoff(backoff, retryCfg.MaxBackoff) - - continue - } - - return nil, apiErr - } - - return nil, fmt.Errorf("yandex-messenger/files: retries exhausted for %s %s", method, path) -} diff --git a/client/ym/messages/attachments.go b/client/ym/messages/attachments.go index cc027c0..3547014 100644 --- a/client/ym/messages/attachments.go +++ b/client/ym/messages/attachments.go @@ -8,25 +8,31 @@ import ( "fmt" "io" "mime/multipart" - "net" "net/http" "net/textproto" "net/url" "strings" - "time" "github.com/rekurt/ymsdk/client/ym" "github.com/rekurt/ymsdk/client/ym/ymerrors" ) +const maxGalleryImages = 10 + +// sanitizeFilename escapes special characters in a filename for Content-Disposition headers. +func sanitizeFilename(name string) string { + return strings.NewReplacer(`"`, `\"`, `\`, `\\`).Replace(name) +} + // SendFileRequest contains parameters for sending a file attachment. // Exactly one of ChatID or Login must be set. type SendFileRequest struct { - ChatID *ym.ChatID - Login *ym.UserLogin - ThreadID *ym.ThreadID - Document io.Reader - Filename string + ChatID *ym.ChatID + Login *ym.UserLogin + ThreadID *ym.ThreadID + Document io.Reader + Filename string + SuggestButtons *ym.SuggestButtons } // FileMeta holds metadata about a downloaded file. @@ -39,11 +45,12 @@ type FileMeta struct { // SendImageRequest contains parameters for sending an image attachment. // Exactly one of ChatID or Login must be set. type SendImageRequest struct { - ChatID *ym.ChatID - Login *ym.UserLogin - ThreadID *ym.ThreadID - Image io.Reader - Filename string + ChatID *ym.ChatID + Login *ym.UserLogin + ThreadID *ym.ThreadID + Image io.Reader + Filename string + SuggestButtons *ym.SuggestButtons } // FilePart represents a single file in a gallery upload. @@ -55,10 +62,12 @@ type FilePart struct { // SendGalleryRequest contains parameters for sending a gallery of images. // Exactly one of ChatID or Login must be set. type SendGalleryRequest struct { - ChatID *ym.ChatID - Login *ym.UserLogin - ThreadID *ym.ThreadID - Images []FilePart + ChatID *ym.ChatID + Login *ym.UserLogin + ThreadID *ym.ThreadID + Images []FilePart + Text string + SuggestButtons *ym.SuggestButtons } // DeleteMessageRequest contains parameters for deleting a message. @@ -79,7 +88,7 @@ func (s *Service) SendFile(ctx context.Context, req *SendFileRequest) (*ym.Messa return nil, errors.New("document and filename are required") } payload, contentType, err := buildSingleFilePayload( - req.ChatID, req.Login, req.ThreadID, "document", req.Filename, req.Document, + req.ChatID, req.Login, req.ThreadID, "document", req.Filename, req.Document, req.SuggestButtons, ) if err != nil { return nil, err @@ -97,7 +106,7 @@ func (s *Service) SendImage(ctx context.Context, req *SendImageRequest) (*ym.Mes return nil, errors.New("image and filename are required") } payload, contentType, err := buildSingleFilePayload( - req.ChatID, req.Login, req.ThreadID, "image", req.Filename, req.Image, + req.ChatID, req.Login, req.ThreadID, "image", req.Filename, req.Image, req.SuggestButtons, ) if err != nil { return nil, err @@ -114,6 +123,9 @@ func (s *Service) SendGallery(ctx context.Context, req *SendGalleryRequest) (*ym if len(req.Images) == 0 { return nil, errors.New("at least one image is required") } + if len(req.Images) > maxGalleryImages { + return nil, fmt.Errorf("gallery images limit exceeded: %d (max %d)", len(req.Images), maxGalleryImages) + } var buf bytes.Buffer writer := multipart.NewWriter(&buf) @@ -132,12 +144,26 @@ func (s *Service) SendGallery(ctx context.Context, req *SendGalleryRequest) (*ym return nil, err } } + if req.Text != "" { + if err := writer.WriteField("text", req.Text); err != nil { + return nil, err + } + } + if req.SuggestButtons != nil { + sb, err := json.Marshal(req.SuggestButtons) + if err != nil { + return nil, fmt.Errorf("marshal suggest_buttons: %w", err) + } + if err := writer.WriteField("suggest_buttons", string(sb)); err != nil { + return nil, err + } + } for i, img := range req.Images { if img.Reader == nil || img.Filename == "" { return nil, fmt.Errorf("image %d missing reader or filename", i) } headers := textproto.MIMEHeader{} - headers.Set("Content-Disposition", fmt.Sprintf(`form-data; name="images"; filename="%s"`, img.Filename)) + headers.Set("Content-Disposition", fmt.Sprintf(`form-data; name="images"; filename="%s"`, sanitizeFilename(img.Filename))) part, err := writer.CreatePart(headers) if err != nil { return nil, err @@ -228,7 +254,8 @@ func (s *Service) GetFile(ctx context.Context, fileID string) (io.ReadCloser, *F } func buildSingleFilePayload( - chatID *ym.ChatID, login *ym.UserLogin, threadID *ym.ThreadID, field, filename string, reader io.Reader, + chatID *ym.ChatID, login *ym.UserLogin, threadID *ym.ThreadID, + field, filename string, reader io.Reader, suggestButtons *ym.SuggestButtons, ) ([]byte, string, error) { var buf bytes.Buffer writer := multipart.NewWriter(&buf) @@ -247,8 +274,17 @@ func buildSingleFilePayload( return nil, "", err } } + if suggestButtons != nil { + sb, err := json.Marshal(suggestButtons) + if err != nil { + return nil, "", fmt.Errorf("marshal suggest_buttons: %w", err) + } + if err := writer.WriteField("suggest_buttons", string(sb)); err != nil { + return nil, "", err + } + } headers := textproto.MIMEHeader{} - headers.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, field, filename)) + headers.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, field, sanitizeFilename(filename))) part, err := writer.CreatePart(headers) if err != nil { return nil, "", err @@ -264,95 +300,25 @@ func buildSingleFilePayload( } func (s *Service) doMultipart(ctx context.Context, path, contentType string, payload []byte) (*ym.Message, error) { - cfg := s.client.Config() - retryCfg := cfg.ErrorHandling.RetryStrategy - rateCfg := cfg.ErrorHandling.RateLimitHandling + resp, err := s.client.DoMultipartRequest(ctx, http.MethodPost, path, contentType, payload) + if err != nil { + return nil, err + } + defer resp.Body.Close() - attempts := retryCfg.MaxAttempts - if attempts < 1 { - attempts = 1 + var parsed struct { + OK bool `json:"ok"` + MessageID ym.MessageID `json:"message_id"` } - backoff := retryCfg.InitialBackoff - if backoff <= 0 { - backoff = 500 * time.Millisecond + if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil { + return nil, fmt.Errorf("%w: decode multipart response: %w", ymerrors.ErrInvalidResponse, err) } - - baseUrl := strings.TrimRight(cfg.BaseURL, "/") + path - - for attempt := 1; attempt <= attempts; attempt++ { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseUrl, bytes.NewReader(payload)) - if err != nil { - return nil, fmt.Errorf("yandex-messenger/messages: build request: %w", err) - } - - if cfg.Token != "" { - req.Header.Set("Authorization", "OAuth "+cfg.Token) - } - req.Header.Set("Content-Type", contentType) - - resp, doErr := s.client.HTTPDoer().Do(req) - if doErr != nil { - if ctxErr := ctx.Err(); ctxErr != nil { - return nil, fmt.Errorf("yandex-messenger/messages: %w for %s", ctxErr, path) - } - var netErr net.Error - if errors.As(doErr, &netErr) && retryCfg.RetryNetwork && attempt < attempts { - time.Sleep(backoff) - backoff = ym.NextBackoff(backoff, retryCfg.MaxBackoff) - - continue - } - - return nil, fmt.Errorf("yandex-messenger/messages: %w for %s", doErr, path) - } - - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - var parsed struct { - OK bool `json:"ok"` - Message *ym.Message `json:"message"` - MessageID ym.MessageID `json:"message_id"` - } - if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil { - _ = resp.Body.Close() - - return nil, fmt.Errorf("%w: decode multipart response: %w", ymerrors.ErrInvalidResponse, err) - } - - _ = resp.Body.Close() - - if parsed.Message != nil { - return parsed.Message, nil - } - if parsed.MessageID != 0 { - return &ym.Message{ID: parsed.MessageID}, nil - } - - return nil, fmt.Errorf("%w: ok=%v message missing", ymerrors.ErrInvalidResponse, parsed.OK) - } - - apiErr, parseErr := s.client.NewAPIError(http.MethodPost, path, resp) - if parseErr != nil { - return nil, parseErr - } - - if apiErr.Kind == ymerrors.KindRateLimited && attempt < attempts { - sleep := rateCfg.DefaultBackoff - if rateCfg.UseRetryAfter && apiErr.RetryAfter > 0 { - sleep = apiErr.RetryAfter - } - time.Sleep(sleep) - - continue - } - if ym.ShouldRetryHTTP(apiErr.HTTPStatus, retryCfg.RetryHTTP) && attempt < attempts { - time.Sleep(backoff) - backoff = ym.NextBackoff(backoff, retryCfg.MaxBackoff) - - continue - } - - return nil, apiErr + if !parsed.OK { + return nil, fmt.Errorf("%w: ok=false", ymerrors.ErrInvalidResponse) + } + if parsed.MessageID != 0 { + return &ym.Message{ID: parsed.MessageID}, nil } - return nil, fmt.Errorf("yandex-messenger/messages: retries exhausted for %s", path) + return nil, fmt.Errorf("%w: message_id missing", ymerrors.ErrInvalidResponse) } diff --git a/client/ym/messages/attachments_test.go b/client/ym/messages/attachments_test.go index 1bd64fe..b508f1f 100644 --- a/client/ym/messages/attachments_test.go +++ b/client/ym/messages/attachments_test.go @@ -29,7 +29,7 @@ func TestSendFileSuccess(t *testing.T) { }, }, &testutil.FakeDoer{ Responses: []*http.Response{ - testutil.NewResponse(http.StatusOK, `{"ok":true,"message":{"message_id":1,"chat":{"id":"c1","type":"private"},"from":{"login":"u1"},"text":"file"}}`), + testutil.NewResponse(http.StatusOK, `{"ok":true,"message_id":1}`), }, }) svc := NewService(client) @@ -84,4 +84,124 @@ func TestGetFileJSONError(t *testing.T) { } } +func TestSendImageSuccess(t *testing.T) { + client := ym.NewClientWithHTTP(ym.Config{ + BaseURL: "http://example.com", + ErrorHandling: ymerrors.ErrorHandlingConfig{ + RetryStrategy: ymerrors.RetryStrategy{MaxAttempts: 1}, + }, + }, &testutil.FakeDoer{ + Responses: []*http.Response{ + testutil.NewResponse(http.StatusOK, `{"ok":true,"message_id":42}`), + }, + }) + svc := NewService(client) + msg, err := svc.SendImage(context.Background(), &SendImageRequest{ + ChatID: ptrChat("c1"), + Image: bytes.NewBufferString("png-data"), + Filename: "photo.png", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if msg == nil || msg.ID != 42 { + t.Fatalf("expected message with id=42, got %v", msg) + } +} + +func TestSendImageValidation(t *testing.T) { + svc := NewService(nil) + _, err := svc.SendImage(context.Background(), &SendImageRequest{}) + if err == nil { + t.Fatal("expected validation error") + } +} + +func TestSendGallerySuccess(t *testing.T) { + client := ym.NewClientWithHTTP(ym.Config{ + BaseURL: "http://example.com", + ErrorHandling: ymerrors.ErrorHandlingConfig{ + RetryStrategy: ymerrors.RetryStrategy{MaxAttempts: 1}, + }, + }, &testutil.FakeDoer{ + Responses: []*http.Response{ + testutil.NewResponse(http.StatusOK, `{"ok":true,"message_id":99}`), + }, + }) + svc := NewService(client) + msg, err := svc.SendGallery(context.Background(), &SendGalleryRequest{ + ChatID: ptrChat("c1"), + Images: []FilePart{ + {Reader: bytes.NewBufferString("img1"), Filename: "a.jpg"}, + {Reader: bytes.NewBufferString("img2"), Filename: "b.jpg"}, + }, + Text: "gallery caption", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if msg == nil || msg.ID != 99 { + t.Fatalf("expected message with id=99, got %v", msg) + } +} + +func TestSendGalleryTooManyImages(t *testing.T) { + svc := NewService(nil) + images := make([]FilePart, 11) + for i := range images { + images[i] = FilePart{Reader: bytes.NewBufferString("x"), Filename: "x.jpg"} + } + _, err := svc.SendGallery(context.Background(), &SendGalleryRequest{ + ChatID: ptrChat("c1"), + Images: images, + }) + if err == nil { + t.Fatal("expected error for too many images") + } +} + +func TestGetFileSuccess(t *testing.T) { + resp := testutil.NewResponse(http.StatusOK, "file-content") + resp.Header.Set("Content-Type", "application/octet-stream") + resp.ContentLength = 12 + client := ym.NewClientWithHTTP(ym.Config{ + BaseURL: "http://example.com", + ErrorHandling: ymerrors.ErrorHandlingConfig{ + RetryStrategy: ymerrors.RetryStrategy{MaxAttempts: 1}, + }, + }, &testutil.FakeDoer{ + Responses: []*http.Response{resp}, + }) + svc := NewService(client) + body, meta, err := svc.GetFile(context.Background(), "disk/abc") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer body.Close() + if meta.FileID != "disk/abc" { + t.Fatalf("expected file_id=disk/abc, got %s", meta.FileID) + } + if meta.ContentType != "application/octet-stream" { + t.Fatalf("unexpected content type: %s", meta.ContentType) + } +} + +func TestSanitizeFilename(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {`normal.txt`, `normal.txt`}, + {`file"name.txt`, `file\"name.txt`}, + {`path\file.txt`, `path\\file.txt`}, + {`"evil\".txt`, `\"evil\\\".txt`}, + } + for _, tc := range tests { + got := sanitizeFilename(tc.input) + if got != tc.expected { + t.Errorf("sanitizeFilename(%q) = %q, want %q", tc.input, got, tc.expected) + } + } +} + func ptrChat(id ym.ChatID) *ym.ChatID { return &id } diff --git a/client/ym/messages/service.go b/client/ym/messages/service.go index 116cf9a..8312010 100644 --- a/client/ym/messages/service.go +++ b/client/ym/messages/service.go @@ -22,21 +22,31 @@ func NewService(client *ym.Client) *Service { // SendMessageOptions holds optional parameters for text message sending. type SendMessageOptions struct { - MarkImportant bool - ReplyToMessageID string + PayloadID string + ReplyToMessageID *ym.MessageID + DisableNotification *bool + Important *bool + DisableWebPagePreview *bool + ThreadID *ym.ThreadID + SuggestButtons *ym.SuggestButtons } type sendMessageRequest struct { - ChatID ym.ChatID `json:"chat_id,omitempty"` - Login ym.UserLogin `json:"login,omitempty"` - Text string `json:"text"` - MarkImportant bool `json:"mark_important,omitempty"` - ReplyToMessageID string `json:"reply_to_message_id,omitempty"` + ChatID ym.ChatID `json:"chat_id,omitempty"` + Login ym.UserLogin `json:"login,omitempty"` + Text string `json:"text"` + PayloadID string `json:"payload_id,omitempty"` + ReplyMessageID *ym.MessageID `json:"reply_message_id,omitempty"` + DisableNotification *bool `json:"disable_notification,omitempty"` + Important *bool `json:"important,omitempty"` + DisableWebPagePreview *bool `json:"disable_web_page_preview,omitempty"` + ThreadID *ym.ThreadID `json:"thread_id,omitempty"` + SuggestButtons *ym.SuggestButtons `json:"suggest_buttons,omitempty"` } type sendMessageResponse struct { - OK bool `json:"ok"` - Message *ym.Message `json:"message"` + OK bool `json:"ok"` + MessageID ym.MessageID `json:"message_id"` } // SendToChat sends a text message to a chat identified by chatID. @@ -60,7 +70,7 @@ func (s *Service) SendToLogin( } func (s *Service) send(ctx context.Context, reqBody sendMessageRequest) (*ym.Message, error) { - resp, err := s.client.DoRequest(ctx, http.MethodPost, "/bot/v1/messages/sendText", reqBody) + resp, err := s.client.DoRequest(ctx, http.MethodPost, "/bot/v1/messages/sendText/", reqBody) if err != nil { return nil, err } @@ -70,13 +80,11 @@ func (s *Service) send(ctx context.Context, reqBody sendMessageRequest) (*ym.Mes if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil { return nil, fmt.Errorf("%w: decode sendText response: %w", ymerrors.ErrInvalidResponse, err) } - if !parsed.OK || parsed.Message == nil { - return nil, fmt.Errorf( - "%w: ok=%v message_present=%v", ymerrors.ErrInvalidResponse, parsed.OK, parsed.Message != nil, - ) + if !parsed.OK { + return nil, fmt.Errorf("%w: ok=false", ymerrors.ErrInvalidResponse) } - return parsed.Message, nil + return &ym.Message{ID: parsed.MessageID}, nil } func buildRequest(text string, opts *SendMessageOptions) sendMessageRequest { @@ -85,8 +93,13 @@ func buildRequest(text string, opts *SendMessageOptions) sendMessageRequest { } return sendMessageRequest{ - Text: text, - MarkImportant: opts.MarkImportant, - ReplyToMessageID: opts.ReplyToMessageID, + Text: text, + PayloadID: opts.PayloadID, + ReplyMessageID: opts.ReplyToMessageID, + DisableNotification: opts.DisableNotification, + Important: opts.Important, + DisableWebPagePreview: opts.DisableWebPagePreview, + ThreadID: opts.ThreadID, + SuggestButtons: opts.SuggestButtons, } } diff --git a/client/ym/messages/service_test.go b/client/ym/messages/service_test.go index ae82a01..efd96e9 100644 --- a/client/ym/messages/service_test.go +++ b/client/ym/messages/service_test.go @@ -19,7 +19,7 @@ func TestSendToChatSuccess(t *testing.T) { }, }, &testutil.FakeDoer{ Responses: []*http.Response{ - testutil.NewResponse(http.StatusOK, `{"ok":true,"message":{"message_id":1,"chat":{"id":"c1","type":"private"},"from":{"login":"u1"},"text":"hi"}}`), + testutil.NewResponse(http.StatusOK, `{"ok":true,"message_id":1}`), }, }) diff --git a/client/ym/polls/service.go b/client/ym/polls/service.go index beb842f..2624f2f 100644 --- a/client/ym/polls/service.go +++ b/client/ym/polls/service.go @@ -26,18 +26,19 @@ func NewService(client *ym.Client) *Service { // CreatePollRequest contains parameters for creating a new poll. // Exactly one of ChatID or Login must be set. type CreatePollRequest struct { - ChatID *ym.ChatID `json:"chat_id,omitempty"` - Login *ym.UserLogin `json:"login,omitempty"` - Title string `json:"title"` - Answers []string `json:"answers"` - MaxChoices *int `json:"max_choices,omitempty"` - IsAnonymous *bool `json:"is_anonymous,omitempty"` - PayloadID *string `json:"payload_id,omitempty"` - ReplyMessageID *ym.MessageID `json:"reply_message_id,omitempty"` - DisableNotification *bool `json:"disable_notification,omitempty"` - Important *bool `json:"important,omitempty"` - DisableWebPagePreview *bool `json:"disable_web_page_preview,omitempty"` - ThreadID *ym.ThreadID `json:"thread_id,omitempty"` + ChatID *ym.ChatID `json:"chat_id,omitempty"` + Login *ym.UserLogin `json:"login,omitempty"` + Title string `json:"title"` + Answers []string `json:"answers"` + MaxChoices *int `json:"max_choices,omitempty"` + IsAnonymous *bool `json:"is_anonymous,omitempty"` + PayloadID *string `json:"payload_id,omitempty"` + ReplyMessageID *ym.MessageID `json:"reply_message_id,omitempty"` + DisableNotification *bool `json:"disable_notification,omitempty"` + Important *bool `json:"important,omitempty"` + DisableWebPagePreview *bool `json:"disable_web_page_preview,omitempty"` + ThreadID *ym.ThreadID `json:"thread_id,omitempty"` + SuggestButtons *ym.SuggestButtons `json:"suggest_buttons,omitempty"` } // Create sends a new poll to a chat or user. diff --git a/client/ym/polls/service_test.go b/client/ym/polls/service_test.go index 1e4ee85..f7ed6f0 100644 --- a/client/ym/polls/service_test.go +++ b/client/ym/polls/service_test.go @@ -65,6 +65,56 @@ func TestGetResultsError(t *testing.T) { } } +func TestGetVotersPageSuccess(t *testing.T) { + doer := &testutil.FakeDoer{ + Responses: []*http.Response{ + testutil.NewResponse(http.StatusOK, `{"ok":true,"answer_id":1,"voted_count":2,"cursor":100,"votes":[{"timestamp":1000,"user":{"login":"u1"}},{"timestamp":1001,"user":{"login":"u2"}}]}`), + }, + } + client := ym.NewClientWithHTTP(ym.Config{ + BaseURL: "http://example.com", + ErrorHandling: ymerrors.ErrorHandlingConfig{ + RetryStrategy: ymerrors.RetryStrategy{MaxAttempts: 1}, + }, + }, doer) + svc := NewService(client) + page, err := svc.GetVotersPage(context.Background(), PollVotersParams{ + ChatID: ptrChat("c1"), + MessageID: 1, + AnswerID: 1, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if page.AnswerID != 1 || page.VotedCount != 2 || len(page.Votes) != 2 { + t.Fatalf("unexpected result: %+v", page) + } +} + +func TestGetVotersPageError(t *testing.T) { + doer := &testutil.FakeDoer{ + Responses: []*http.Response{ + testutil.NewResponse(http.StatusOK, `{"ok":false,"description":"poll not found"}`), + }, + } + client := ym.NewClientWithHTTP(ym.Config{ + BaseURL: "http://example.com", + ErrorHandling: ymerrors.ErrorHandlingConfig{ + RetryStrategy: ymerrors.RetryStrategy{MaxAttempts: 1}, + }, + }, doer) + svc := NewService(client) + _, err := svc.GetVotersPage(context.Background(), PollVotersParams{ + ChatID: ptrChat("c1"), + MessageID: 1, + AnswerID: 1, + }) + var apiErr *ymerrors.APIError + if !errors.As(err, &apiErr) { + t.Fatalf("expected api error, got %v", err) + } +} + func ptrChat(id ym.ChatID) *ym.ChatID { return &id } diff --git a/client/ym/types.go b/client/ym/types.go index 091533b..e2705da 100644 --- a/client/ym/types.go +++ b/client/ym/types.go @@ -1,6 +1,9 @@ package ym -import "time" +import ( + "encoding/json" + "time" +) // ChatType represents the type of a Yandex Messenger chat. type ChatType string @@ -60,10 +63,11 @@ type Sticker struct { // Image represents an image attachment in a message. type Image struct { - ID string `json:"id,omitempty"` - URL string `json:"url,omitempty"` + FileID string `json:"file_id,omitempty"` Width int `json:"width,omitempty"` Height int `json:"height,omitempty"` + Size int64 `json:"size,omitempty"` + Name string `json:"name,omitempty"` } // File represents a document attachment in a message. @@ -93,18 +97,19 @@ type Message struct { // Update represents an incoming update from the getUpdates endpoint. type Update struct { - UpdateID int64 `json:"update_id"` - Chat *Chat `json:"chat,omitempty"` - From *Sender `json:"from,omitempty"` - Text string `json:"text,omitempty"` - Timestamp int64 `json:"timestamp,omitempty"` - MessageID MessageID `json:"message_id,omitempty"` - ThreadID *ThreadID `json:"thread_id,omitempty"` - Forward *ForwardInfo `json:"forward,omitempty"` - Sticker *Sticker `json:"sticker,omitempty"` - Image *Image `json:"image,omitempty"` - Gallery []Image `json:"gallery,omitempty"` - Document *File `json:"document,omitempty"` + UpdateID int64 `json:"update_id"` + Chat *Chat `json:"chat,omitempty"` + From *Sender `json:"from,omitempty"` + Text string `json:"text,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` + MessageID MessageID `json:"message_id,omitempty"` + ThreadID *ThreadID `json:"thread_id,omitempty"` + Forward *ForwardInfo `json:"forward,omitempty"` + Sticker *Sticker `json:"sticker,omitempty"` + Image *Image `json:"image,omitempty"` + Images []Image `json:"images,omitempty"` + Document *File `json:"document,omitempty"` + BotRequest *BotRequest `json:"bot_request,omitempty"` } // ToMessage converts an Update to a Message by promoting its fields. @@ -124,11 +129,65 @@ func (u *Update) ToMessage() *Message { Forward: u.Forward, Sticker: u.Sticker, Image: u.Image, - Gallery: u.Gallery, + Gallery: u.Images, Document: u.Document, } } +// DirectiveType constants for button actions. +const ( + DirectiveOpenURI = "open_uri" + DirectiveSendMessage = "send_message" + DirectiveServerAction = "server_action" + DirectiveSetElementsState = "set_elements_state" +) + +// Directive describes an action triggered by a button press. +type Directive struct { + Type string `json:"type"` + URI string `json:"uri,omitempty"` + Text string `json:"text,omitempty"` + Name string `json:"name,omitempty"` + Payload json.RawMessage `json:"payload,omitempty"` + IDs []string `json:"ids,omitempty"` + State string `json:"state,omitempty"` + TimeoutSeconds *int `json:"timeout_seconds,omitempty"` +} + +// InlineSuggestButton is a single button in a SuggestButtons keyboard. +type InlineSuggestButton struct { + ID string `json:"id,omitempty"` + Title string `json:"title,omitempty"` + Directives []Directive `json:"directives,omitempty"` +} + +// SuggestButtons is a keyboard of interactive buttons attached to a message. +type SuggestButtons struct { + Layout *string `json:"layout,omitempty"` + Persist *bool `json:"persist,omitempty"` + Buttons [][]InlineSuggestButton `json:"buttons"` +} + +// ServerAction represents a callback action triggered by a button. +type ServerAction struct { + Name string `json:"name"` + Payload json.RawMessage `json:"payload,omitempty"` +} + +// BotRequestError describes an error that occurred processing a button directive. +type BotRequestError struct { + Type string `json:"type"` + Name string `json:"name,omitempty"` + Message string `json:"message,omitempty"` +} + +// BotRequest contains callback data from interactive button presses. +type BotRequest struct { + ServerAction *ServerAction `json:"server_action,omitempty"` + ElementID string `json:"element_id,omitempty"` + Errors []BotRequestError `json:"errors,omitempty"` +} + // UserRef identifies a user by login, used in member lists. type UserRef struct { Login UserLogin `json:"login"` diff --git a/client/ym/updates/service.go b/client/ym/updates/service.go index 9f72ab2..eabc111 100644 --- a/client/ym/updates/service.go +++ b/client/ym/updates/service.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "math/rand/v2" "net/http" "net/url" "strconv" @@ -129,7 +130,12 @@ func (s *Service) PollLoop( return err } } + if ctx.Err() != nil { + return ctx.Err() + } offset = &nextOffset - time.Sleep(time.Second) + //nolint:gosec // jitter for thundering herd prevention, not security + jitter := time.Duration(rand.IntN(500)) * time.Millisecond + time.Sleep(time.Second + jitter) } } diff --git a/client/ym/ymerrors/config.go b/client/ym/ymerrors/config.go index e7ca09b..0c4b38b 100644 --- a/client/ym/ymerrors/config.go +++ b/client/ym/ymerrors/config.go @@ -4,16 +4,29 @@ import "time" // RetryStrategy configures automatic retry behavior for transient failures. type RetryStrategy struct { - MaxAttempts int `json:"max_attempts" yaml:"max_attempts"` + // MaxAttempts is the total number of attempts including the initial request. + // Default: 1 (no retries). Set to 3 for typical production use. + MaxAttempts int `json:"max_attempts" yaml:"max_attempts"` + // InitialBackoff is the delay before the first retry. Doubles on each subsequent attempt. + // Default: 500ms. InitialBackoff time.Duration `json:"initial_backoff" yaml:"initial_backoff"` - MaxBackoff time.Duration `json:"max_backoff" yaml:"max_backoff"` - RetryHTTP []int `json:"retry_http" yaml:"retry_http"` - RetryNetwork bool `json:"retry_network" yaml:"retry_network"` + // MaxBackoff caps the exponential backoff growth. + // Default: 10s. + MaxBackoff time.Duration `json:"max_backoff" yaml:"max_backoff"` + // RetryHTTP lists HTTP status codes that trigger a retry. + // Default: [500, 502, 503, 504]. + RetryHTTP []int `json:"retry_http" yaml:"retry_http"` + // RetryNetwork enables automatic retry on network-level errors (DNS, TCP). + // Default: false. + RetryNetwork bool `json:"retry_network" yaml:"retry_network"` } // RateLimitHandling configures how the client reacts to HTTP 429 responses. type RateLimitHandling struct { - UseRetryAfter bool `json:"use_retry_after" yaml:"use_retry_after"` + // UseRetryAfter respects the server's Retry-After header when present. + UseRetryAfter bool `json:"use_retry_after" yaml:"use_retry_after"` + // DefaultBackoff is the fallback delay when Retry-After is not provided. + // Default: 1s. DefaultBackoff time.Duration `json:"default_backoff" yaml:"default_backoff"` } diff --git a/client/ym/ymerrors/errors.go b/client/ym/ymerrors/errors.go index a5d3800..ca70191 100644 --- a/client/ym/ymerrors/errors.go +++ b/client/ym/ymerrors/errors.go @@ -23,6 +23,12 @@ const ( KindBadRequest // KindNetwork indicates a transport-level failure (DNS, TCP, 5xx). KindNetwork + // KindNotFound indicates the requested resource was not found (HTTP 404). + KindNotFound + // KindConflict indicates a conflict with the current state (HTTP 409). + KindConflict + // KindPayloadTooLarge indicates the request body exceeds the size limit (HTTP 413). + KindPayloadTooLarge ) // Sentinel errors for use with errors.Is. @@ -30,6 +36,10 @@ var ( ErrRateLimited = errors.New("yandex-messenger: rate limited") ErrInvalidToken = errors.New("yandex-messenger: invalid token") ErrUnauthorized = errors.New("yandex-messenger: unauthorized") + ErrBadRequest = errors.New("yandex-messenger: bad request") + ErrNotFound = errors.New("yandex-messenger: not found") + ErrConflict = errors.New("yandex-messenger: conflict") + ErrPayloadTooLarge = errors.New("yandex-messenger: payload too large") ErrRequestTimeout = errors.New("yandex-messenger: request timeout") ErrNetworkError = errors.New("yandex-messenger: network error") ErrInvalidResponse = errors.New("yandex-messenger: invalid response") @@ -98,8 +108,16 @@ func (e *APIError) Unwrap() error { return ErrInvalidToken case KindUnauthorized: return ErrUnauthorized + case KindBadRequest: + return ErrBadRequest case KindNetwork: return ErrNetworkError + case KindNotFound: + return ErrNotFound + case KindConflict: + return ErrConflict + case KindPayloadTooLarge: + return ErrPayloadTooLarge default: return nil } diff --git a/client/ym/ymerrors/errors_test.go b/client/ym/ymerrors/errors_test.go index 131e161..efa40d9 100644 --- a/client/ym/ymerrors/errors_test.go +++ b/client/ym/ymerrors/errors_test.go @@ -46,3 +46,45 @@ func TestAPIErrorAs(t *testing.T) { t.Fatalf("unexpected kind: %v", target.Kind) } } + +func TestAPIErrorUnwrapBadRequest(t *testing.T) { + err := &APIError{Kind: KindBadRequest} + if !errors.Is(err, ErrBadRequest) { + t.Fatalf("expected errors.Is to match ErrBadRequest") + } +} + +func TestAPIErrorUnwrapNetwork(t *testing.T) { + err := &APIError{Kind: KindNetwork} + if !errors.Is(err, ErrNetworkError) { + t.Fatalf("expected errors.Is to match ErrNetworkError") + } +} + +func TestAPIErrorUnwrapNotFound(t *testing.T) { + err := &APIError{Kind: KindNotFound} + if !errors.Is(err, ErrNotFound) { + t.Fatalf("expected errors.Is to match ErrNotFound") + } +} + +func TestAPIErrorUnwrapConflict(t *testing.T) { + err := &APIError{Kind: KindConflict} + if !errors.Is(err, ErrConflict) { + t.Fatalf("expected errors.Is to match ErrConflict") + } +} + +func TestAPIErrorUnwrapPayloadTooLarge(t *testing.T) { + err := &APIError{Kind: KindPayloadTooLarge} + if !errors.Is(err, ErrPayloadTooLarge) { + t.Fatalf("expected errors.Is to match ErrPayloadTooLarge") + } +} + +func TestAPIErrorUnwrapUnknown(t *testing.T) { + err := &APIError{Kind: KindUnknown} + if err.Unwrap() != nil { + t.Fatalf("expected KindUnknown to unwrap to nil") + } +} diff --git a/examples/basic_send/main.go b/examples/basic_send/main.go index 1700120..e2e0690 100644 --- a/examples/basic_send/main.go +++ b/examples/basic_send/main.go @@ -7,6 +7,7 @@ import ( "log" "os" "os/signal" + "strconv" "time" "go.uber.org/zap" @@ -70,9 +71,17 @@ func main() { var opts *messages.SendMessageOptions if *replyTo != "" || *important { - opts = &messages.SendMessageOptions{ - ReplyToMessageID: *replyTo, - MarkImportant: *important, + opts = &messages.SendMessageOptions{} + if *important { + opts.Important = ym.Ptr(true) + } + if *replyTo != "" { + v, err := strconv.ParseInt(*replyTo, 10, 64) + if err != nil { + log.Fatalf("invalid reply-to message ID: %v", err) + } + mid := ym.MessageID(v) + opts.ReplyToMessageID = &mid } } diff --git a/examples/integration/main.go b/examples/integration/main.go index d5e81cd..70bc6d3 100644 --- a/examples/integration/main.go +++ b/examples/integration/main.go @@ -86,7 +86,7 @@ func main() { if chatID != "" { section("sendText to chat") msg, err := cs.Messages.SendToChat(ctx, chatID, "integration: hello from ymsdk", &messages.SendMessageOptions{ - MarkImportant: true, + Important: ym.Ptr(true), }) if err != nil { logErr("sendText(chat)", err) diff --git a/examples/poller/main.go b/examples/poller/main.go index a3f59d5..3510aca 100644 --- a/examples/poller/main.go +++ b/examples/poller/main.go @@ -95,11 +95,11 @@ func logUpdate(logger *zap.Logger, u ym.Update) { case u.Sticker != nil: log.Printf("[%s] %s sent sticker: %s (id=%s)", chatID, sender, u.Sticker.Emoji, u.Sticker.ID) - case len(u.Gallery) > 0: - log.Printf("[%s] %s sent gallery with %d images", chatID, sender, len(u.Gallery)) + case len(u.Images) > 0: + log.Printf("[%s] %s sent gallery with %d images", chatID, sender, len(u.Images)) case u.Image != nil: - log.Printf("[%s] %s sent image: %dx%d (id=%s)", chatID, sender, u.Image.Width, u.Image.Height, u.Image.ID) + log.Printf("[%s] %s sent image: %dx%d (id=%s)", chatID, sender, u.Image.Width, u.Image.Height, u.Image.FileID) case u.Document != nil: log.Printf("[%s] %s sent file: %s (%s, %d bytes)", chatID, sender, u.Document.Name, u.Document.MimeType, u.Document.Size) diff --git a/examples/webhook/main.go b/examples/webhook/main.go index f08a928..88973a1 100644 --- a/examples/webhook/main.go +++ b/examples/webhook/main.go @@ -149,8 +149,8 @@ func processUpdate(ctx context.Context, cs *client.YMClient, upd ym.Update) { replyText = fmt.Sprintf("Nice sticker! %s", upd.Sticker.Emoji) case upd.Image != nil: replyText = fmt.Sprintf("Got your image (%dx%d)", upd.Image.Width, upd.Image.Height) - case len(upd.Gallery) > 0: - replyText = fmt.Sprintf("Got %d images in gallery", len(upd.Gallery)) + case len(upd.Images) > 0: + replyText = fmt.Sprintf("Got %d images in gallery", len(upd.Images)) case upd.Document != nil: replyText = fmt.Sprintf("Got file: %s (%d bytes)", upd.Document.Name, upd.Document.Size) case upd.Forward != nil: @@ -163,8 +163,9 @@ func processUpdate(ctx context.Context, cs *client.YMClient, upd ym.Update) { return } + replyID := upd.MessageID opts := &messages.SendMessageOptions{ - ReplyToMessageID: fmt.Sprintf("%d", upd.MessageID), + ReplyToMessageID: &replyID, } if _, sendErr := cs.Messages.SendToChat(ctx, target, replyText, opts); sendErr != nil { diff --git a/internal/testutil/fake_doer.go b/internal/testutil/fake_doer.go index ab68f35..7e6745b 100644 --- a/internal/testutil/fake_doer.go +++ b/internal/testutil/fake_doer.go @@ -1,21 +1,30 @@ package testutil -import "net/http" +import ( + "net/http" + "sync" +) // FakeDoer is a mock HTTP client for testing. It replays paired Responses and // Errors slices in order: the first call returns Responses[0]/Errors[0], the // second returns Responses[1]/Errors[1], and so on. If an index exceeds // the length of either slice, nil is returned for that component. // All incoming requests are recorded in the Requests slice. +// +// FakeDoer is safe for concurrent use. type FakeDoer struct { Responses []*http.Response Errors []error Requests []*http.Request idx int + mu sync.Mutex } // Do records the request and returns the next Response/Error pair. func (f *FakeDoer) Do(req *http.Request) (*http.Response, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.Requests = append(f.Requests, req) if f.idx >= len(f.Responses) && f.idx >= len(f.Errors) { return nil, nil @@ -39,11 +48,17 @@ func (f *FakeDoer) Do(req *http.Request) (*http.Response, error) { // Reset clears recorded requests and resets the response index to zero, // allowing the FakeDoer to be reused across subtests. func (f *FakeDoer) Reset() { + f.mu.Lock() + defer f.mu.Unlock() + f.idx = 0 f.Requests = nil } // CallCount returns the number of requests that have been made. func (f *FakeDoer) CallCount() int { + f.mu.Lock() + defer f.mu.Unlock() + return len(f.Requests) } diff --git a/internal/testutil/fake_doer_test.go b/internal/testutil/fake_doer_test.go new file mode 100644 index 0000000..b6860ec --- /dev/null +++ b/internal/testutil/fake_doer_test.go @@ -0,0 +1,50 @@ +package testutil + +import ( + "net/http" + "sync" + "testing" +) + +func TestFakeDoerConcurrentAccess(t *testing.T) { + t.Parallel() + + responses := make([]*http.Response, 100) + for i := range responses { + responses[i] = &http.Response{StatusCode: http.StatusOK} + } + + fd := &FakeDoer{Responses: responses} + + var wg sync.WaitGroup + for range 100 { + wg.Add(1) + go func() { + defer wg.Done() + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + _, _ = fd.Do(req) + }() + } + wg.Wait() + + if fd.CallCount() != 100 { + t.Fatalf("expected 100 calls, got %d", fd.CallCount()) + } +} + +func TestFakeDoerReset(t *testing.T) { + fd := &FakeDoer{ + Responses: []*http.Response{ + {StatusCode: http.StatusOK}, + }, + } + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + _, _ = fd.Do(req) + if fd.CallCount() != 1 { + t.Fatalf("expected 1 call, got %d", fd.CallCount()) + } + fd.Reset() + if fd.CallCount() != 0 { + t.Fatalf("expected 0 calls after reset, got %d", fd.CallCount()) + } +} diff --git a/middleware/http_logger.go b/middleware/http_logger.go index 14f3e66..1d5ddc5 100644 --- a/middleware/http_logger.go +++ b/middleware/http_logger.go @@ -8,6 +8,7 @@ import ( ) // HTTPLogger wraps an http.Client to log raw request/response bodies. +// It is safe for concurrent use because each Do call operates on distinct request/response objects. type HTTPLogger struct { client *http.Client debugLogger *DebugLogger diff --git a/middleware/http_logger_test.go b/middleware/http_logger_test.go new file mode 100644 index 0000000..bd30b3f --- /dev/null +++ b/middleware/http_logger_test.go @@ -0,0 +1,55 @@ +package middleware + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "go.uber.org/zap" +) + +func TestHTTPLoggerDo(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer server.Close() + + logger, _ := zap.NewDevelopment() + debugLogger := NewDebugLogger(logger, LogLevelDebug) + hl := NewHTTPLogger(server.Client(), debugLogger) + + req, err := http.NewRequest(http.MethodPost, server.URL+"/test", strings.NewReader(`{"text":"hello"}`)) + if err != nil { + t.Fatalf("build request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := hl.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + if string(body) != `{"ok":true}` { + t.Fatalf("body mismatch: %s", body) + } +} + +func TestHTTPLoggerNilDefaults(t *testing.T) { + hl := NewHTTPLogger(nil, nil) + if hl.client == nil || hl.debugLogger == nil { + t.Fatal("expected non-nil defaults") + } +}