From 0fe497af7bcb517611282bb7220bdffc3440ed83 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 5 May 2026 11:42:38 -0400 Subject: [PATCH 1/2] fix(arrow/flight): deliver response headers eagerly in streaming client middleware For streaming RPCs like Handshake, `ClientHeadersMiddleware.HeadersReceived` was only invoked from `finishFn`, which fires when `Recv()` returns `io.EOF` or an error. `AuthenticateBasicToken` calls `Recv()` only once; if the server returns a `HandshakeResponse` payload (common when the Handshake response carries auth data or a session cookie), `Recv()` returns that message rather than `io.EOF` and the middleware never observes the response headers. As a result the cookie middleware fails to capture `Set-Cookie` from Handshake, and subsequent RPCs are sent without the session cookie. This change makes `clientStream.Header()` invoke `HeadersReceived` at-most-once (guarded by an atomic flag) with the response metadata the first time headers are successfully retrieved. The existing `finishFn` path is preserved so cookies arriving in trailers are still captured. Fixes #755 --- arrow/flight/client.go | 25 ++ arrow/flight/handshake_cookie_test.go | 324 ++++++++++++++++++++++++++ 2 files changed, 349 insertions(+) create mode 100644 arrow/flight/handshake_cookie_test.go diff --git a/arrow/flight/client.go b/arrow/flight/client.go index 96eb0e6b3..92699a59e 100644 --- a/arrow/flight/client.go +++ b/arrow/flight/client.go @@ -175,6 +175,19 @@ func CreateClientMiddleware(middleware CustomClientMiddleware) ClientMiddleware desc: desc, finishFn: finishFunc, } + if isHdrs { + // Deliver response headers to the middleware as soon as they + // are first retrieved via Header(), rather than waiting for + // the stream to finish. This is necessary for streaming RPCs + // like Handshake where the caller may inspect headers (e.g. + // Set-Cookie) and issue subsequent RPCs before the stream + // reaches io.EOF (e.g. when the server sends a response + // payload that causes Recv to return a message instead of + // EOF). See GH-755. + newCS.onHeaders = func(md metadata.MD) { + hdrs.HeadersReceived(csCtx, md) + } + } // The `ClientStream` interface allows one to omit calling `Recv` if it's // known that the result will be `io.EOF`. See // http://stackoverflow.com/q/42915337 @@ -193,12 +206,24 @@ type clientStream struct { grpc.ClientStream desc *grpc.StreamDesc finishFn func(error) + + // onHeaders, when non-nil, is invoked at most once with the response + // metadata the first time Header() returns successfully. It allows + // middleware (e.g. cookie middleware) to observe server headers as + // soon as they arrive on streaming RPCs, rather than waiting for the + // stream to finish via finishFn. See GH-755. + onHeaders func(md metadata.MD) + headersObserved atomic.Bool } func (cs *clientStream) Header() (metadata.MD, error) { md, err := cs.ClientStream.Header() if err != nil { cs.finishFn(err) + return md, err + } + if cs.onHeaders != nil && cs.headersObserved.CompareAndSwap(false, true) { + cs.onHeaders(md) } return md, err } diff --git a/arrow/flight/handshake_cookie_test.go b/arrow/flight/handshake_cookie_test.go new file mode 100644 index 000000000..c5c8464ea --- /dev/null +++ b/arrow/flight/handshake_cookie_test.go @@ -0,0 +1,324 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package flight_test + +import ( + "context" + "encoding/base64" + "errors" + "io" + "strings" + "sync" + "testing" + + "github.com/apache/arrow-go/v18/arrow/flight" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" +) + +// handshakeCookieFlightServer is a flight server that emits Set-Cookie +// response headers (and trailers) during Handshake, simulating a server +// that creates a session during the authentication flow (see GH-755). +type handshakeCookieFlightServer struct { + flight.BaseFlightServer + + headerCookie string // cookie attached via SendHeader during Handshake + trailerCookie string // cookie attached via SetTrailer during Handshake + bearerToken string // authorization header returned during Handshake + sendPayload bool // if true, server sends a HandshakeResponse payload before closing + mu sync.Mutex + lastIncomingCook []string // incoming Cookie header values observed on ListFlights +} + +func (h *handshakeCookieFlightServer) Handshake(stream flight.FlightService_HandshakeServer) error { + md := metadata.MD{} + if h.headerCookie != "" { + md.Append("set-cookie", h.headerCookie) + } + if h.bearerToken != "" { + md.Append("authorization", "Bearer "+h.bearerToken) + } + if len(md) > 0 { + if err := stream.SendHeader(md); err != nil { + return err + } + } + + if h.trailerCookie != "" { + stream.SetTrailer(metadata.Pairs("set-cookie", h.trailerCookie)) + } + + if h.sendPayload { + if err := stream.Send(&flight.HandshakeResponse{Payload: []byte("handshake-ok")}); err != nil { + return err + } + } + + // Drain the client stream until it closes. + for { + if _, err := stream.Recv(); err != nil { + if errors.Is(err, io.EOF) { + return nil + } + return err + } + } +} + +func (h *handshakeCookieFlightServer) ListFlights(c *flight.Criteria, fs flight.FlightService_ListFlightsServer) error { + h.mu.Lock() + if md, ok := metadata.FromIncomingContext(fs.Context()); ok { + h.lastIncomingCook = append([]string(nil), md.Get("cookie")...) + } else { + h.lastIncomingCook = nil + } + h.mu.Unlock() + return nil +} + +func (h *handshakeCookieFlightServer) observedCookies() []string { + h.mu.Lock() + defer h.mu.Unlock() + return append([]string(nil), h.lastIncomingCook...) +} + +// TestHandshakeCookiePropagationViaAuthenticateBasicToken is a regression +// test for GH-755. It asserts that Set-Cookie headers returned by a +// Handshake/DoHandshake response are captured by the cookie middleware +// and attached to subsequent requests. +func TestHandshakeCookiePropagationViaAuthenticateBasicToken(t *testing.T) { + srv := &handshakeCookieFlightServer{ + headerCookie: "session_id=sess_header_abc", + bearerToken: "my-bearer-token", + } + + s := flight.NewServerWithMiddleware(nil) + s.Init("localhost:0") + s.RegisterFlightService(srv) + + go s.Serve() + defer s.Shutdown() + + creds := grpc.WithTransportCredentials(insecure.NewCredentials()) + client, err := flight.NewClientWithMiddleware( + s.Addr().String(), + nil, + []flight.ClientMiddleware{flight.NewClientCookieMiddleware()}, + creds, + ) + require.NoError(t, err) + defer client.Close() + + ctx, err := client.AuthenticateBasicToken(context.Background(), "user", "pass") + require.NoError(t, err) + + // Make a follow-up RPC. The cookie middleware must have captured + // Set-Cookie from the Handshake response, and StartCall should + // attach it as a Cookie header on this call. + stream, err := client.ListFlights(ctx, &flight.Criteria{}) + require.NoError(t, err) + for { + if _, err := stream.Recv(); err != nil { + if errors.Is(err, io.EOF) { + break + } + require.NoError(t, err) + } + } + + cookies := srv.observedCookies() + require.Len(t, cookies, 1, "expected exactly one Cookie header, got %v", cookies) + assert.Contains(t, cookies[0], "session_id=sess_header_abc", + "cookie middleware should propagate Set-Cookie from Handshake response headers") +} + +// TestHandshakeCookiePropagationFromTrailers ensures cookies delivered as +// gRPC trailers (instead of initial metadata headers) are also captured +// by the cookie middleware during Handshake. +func TestHandshakeCookiePropagationFromTrailers(t *testing.T) { + srv := &handshakeCookieFlightServer{ + trailerCookie: "session_id=sess_trailer_xyz", + bearerToken: "my-bearer-token", + } + + s := flight.NewServerWithMiddleware(nil) + s.Init("localhost:0") + s.RegisterFlightService(srv) + + go s.Serve() + defer s.Shutdown() + + creds := grpc.WithTransportCredentials(insecure.NewCredentials()) + client, err := flight.NewClientWithMiddleware( + s.Addr().String(), + nil, + []flight.ClientMiddleware{flight.NewClientCookieMiddleware()}, + creds, + ) + require.NoError(t, err) + defer client.Close() + + ctx, err := client.AuthenticateBasicToken(context.Background(), "user", "pass") + require.NoError(t, err) + + stream, err := client.ListFlights(ctx, &flight.Criteria{}) + require.NoError(t, err) + for { + if _, err := stream.Recv(); err != nil { + if errors.Is(err, io.EOF) { + break + } + require.NoError(t, err) + } + } + + cookies := srv.observedCookies() + require.Len(t, cookies, 1, "expected exactly one Cookie header, got %v", cookies) + assert.Contains(t, cookies[0], "session_id=sess_trailer_xyz", + "cookie middleware should propagate Set-Cookie from Handshake response trailers") +} + +// TestHandshakeCookiePropagationWithServerPayload is the precise scenario +// reported in GH-755. The server attaches a Set-Cookie header AND sends +// back a HandshakeResponse payload. AuthenticateBasicToken only calls +// stream.Recv() once, which returns the payload (not io.EOF), so the +// streaming finishFn that would normally invoke HeadersReceived never +// fires. The cookie middleware must still capture the header cookie. +func TestHandshakeCookiePropagationWithServerPayload(t *testing.T) { + srv := &handshakeCookieFlightServer{ + headerCookie: "session_id=sess_with_payload", + bearerToken: "my-bearer-token", + sendPayload: true, + } + + s := flight.NewServerWithMiddleware(nil) + s.Init("localhost:0") + s.RegisterFlightService(srv) + + go s.Serve() + defer s.Shutdown() + + creds := grpc.WithTransportCredentials(insecure.NewCredentials()) + client, err := flight.NewClientWithMiddleware( + s.Addr().String(), + nil, + []flight.ClientMiddleware{flight.NewClientCookieMiddleware()}, + creds, + ) + require.NoError(t, err) + defer client.Close() + + ctx, err := client.AuthenticateBasicToken(context.Background(), "user", "pass") + require.NoError(t, err) + + stream, err := client.ListFlights(ctx, &flight.Criteria{}) + require.NoError(t, err) + for { + if _, err := stream.Recv(); err != nil { + if errors.Is(err, io.EOF) { + break + } + require.NoError(t, err) + } + } + + cookies := srv.observedCookies() + require.Len(t, cookies, 1, + "expected exactly one Cookie header, got %v (GH-755: cookie lost when Handshake returns a payload)", cookies) + assert.Contains(t, cookies[0], "session_id=sess_with_payload") +} + +// TestHandshakeCookieProcessedBeforeRecv verifies cookies are captured +// eagerly once stream.Header() returns successfully. This models the +// scenario where an application-level Handshake flow inspects response +// headers and makes further RPCs before draining the stream. +func TestHandshakeCookieProcessedBeforeRecv(t *testing.T) { + srv := &handshakeCookieFlightServer{ + headerCookie: "session_id=eager_capture", + } + + s := flight.NewServerWithMiddleware(nil) + s.Init("localhost:0") + s.RegisterFlightService(srv) + + go s.Serve() + defer s.Shutdown() + + cookies := flight.NewCookieMiddleware() + creds := grpc.WithTransportCredentials(insecure.NewCredentials()) + client, err := flight.NewClientWithMiddleware( + s.Addr().String(), + nil, + []flight.ClientMiddleware{flight.CreateClientMiddleware(cookies)}, + creds, + ) + require.NoError(t, err) + defer client.Close() + + // Drive the Handshake manually; inspect headers before calling Recv(). + authCtx := metadata.AppendToOutgoingContext(context.Background(), + "Authorization", "Basic "+base64.RawStdEncoding.EncodeToString([]byte("user:pass"))) + + stream, err := client.Handshake(authCtx) + require.NoError(t, err) + require.NoError(t, stream.CloseSend()) + + hdr, err := stream.Header() + require.NoError(t, err) + require.Contains(t, strings.Join(hdr.Get("set-cookie"), ","), "eager_capture") + + // Clone the middleware while the original Handshake stream is still + // open. If cookies were processed eagerly from the header, the clone + // should already contain the session cookie. + cloned := cookies.Clone() + + // Using the clone, make a unary-ish request against a second client + // to observe the outgoing Cookie header. + clientB, err := flight.NewClientWithMiddleware( + s.Addr().String(), + nil, + []flight.ClientMiddleware{flight.CreateClientMiddleware(cloned)}, + creds, + ) + require.NoError(t, err) + defer clientB.Close() + + ls, err := clientB.ListFlights(context.Background(), &flight.Criteria{}) + require.NoError(t, err) + for { + if _, err := ls.Recv(); err != nil { + if errors.Is(err, io.EOF) { + break + } + require.NoError(t, err) + } + } + + got := srv.observedCookies() + require.Len(t, got, 1, "expected cloned middleware to send cookie from eagerly captured Handshake header, got %v", got) + assert.Contains(t, got[0], "session_id=eager_capture") + + // Clean up original stream. + for { + if _, err := stream.Recv(); err != nil { + break + } + } +} From 2c94c408b877c249053020c0fbc1f3ef9a398019 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 5 May 2026 12:55:51 -0400 Subject: [PATCH 2/2] test(flight/flightsql/driver): fix GH-35328 race in TestPreparedStatementNoSchema MockServer.DoPutPreparedStatementQuery has an early-return branch for ExpectedPreparedStatementSchema != nil that was introduced after the GH-35328 flake fix landed. The original drain loop (see ff339b7a) was not copied into the new branch, so the same io.EOF race - where the server returns and closes its side of the stream before the client finishes writing the parameter record batch - still occurs on the code path used by TestPreparedStatementNoSchema. The race is timing-sensitive and manifests most often on ARM64 Debian Go 1.25 with '-asan -tags assert,test,noasm' (see CI run 25386542292), but has been observed on main CI runs dating back to at least April 2026 (runs 24804648950 and 24366361129). Adds the same `for r.Next() {}` drain to the early-return branch. The comment explicitly notes this is NOT redundant with the drain below; the two guard different success paths. --- arrow/flight/flightsql/driver/driver_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/arrow/flight/flightsql/driver/driver_test.go b/arrow/flight/flightsql/driver/driver_test.go index 39d9dfd99..82c66753a 100644 --- a/arrow/flight/flightsql/driver/driver_test.go +++ b/arrow/flight/flightsql/driver/driver_test.go @@ -1819,6 +1819,11 @@ func (s *MockServer) DoPutPreparedStatementQuery(ctx context.Context, qry flight if !s.ExpectedPreparedStatementSchema.Equal(r.Schema()) { return nil, errors.New("parameter schema: unexpected") } + // See GH-35328: drain remaining batches before returning to avoid + // the io.EOF race between server close and client Write. The other + // success path below already does this; this branch must too. + for r.Next() { + } return qry.GetPreparedStatementHandle(), nil }