From 12608303aebc3adb65b5393c376a0c0c56873cc6 Mon Sep 17 00:00:00 2001
From: Hangjie Mo
Date: Wed, 7 May 2025 11:44:48 +0800
Subject: [PATCH] server: check connection is available in SQLKiller (#60685)

close pingcap/tidb#57531
---
 pkg/server/conn.go                            |  8 +++
 pkg/server/conn_stmt.go                       | 11 ++++
 .../internal/util/buffered_read_conn.go       | 42 +++++++++++++
 pkg/server/tests/commontest/tidb_test.go      | 63 +++++++++++++++++++
 pkg/util/deeptest/statictesthelper.go         |  2 +-
 pkg/util/sqlkiller/BUILD.bazel                |  1 +
 pkg/util/sqlkiller/sqlkiller.go               | 27 ++++++++
 7 files changed, 153 insertions(+), 1 deletion(-)

diff --git a/pkg/server/conn.go b/pkg/server/conn.go
index 8a53eb8adb05d..e950a77778a54 100644
--- a/pkg/server/conn.go
+++ b/pkg/server/conn.go
@@ -2070,9 +2070,17 @@ func (cc *clientConn) handleStmt(
 		//nolint: errcheck
 		rs.Finish()
 	})
+	fn := func() bool {
+		if cc.bufReadConn != nil {
+			return cc.bufReadConn.IsAlive() != 0
+		}
+		return true
+	}
+	cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(&fn)
 	cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(true)
 	defer cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(false)
 	defer cc.ctx.GetSessionVars().SQLKiller.ClearFinishFunc()
+	defer cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(nil)
 	if retryable, err := cc.writeResultSet(ctx, rs, false, status, 0); err != nil {
 		return retryable, err
 	}
diff --git a/pkg/server/conn_stmt.go b/pkg/server/conn_stmt.go
index 45c7ba35dee6f..740e0bb6c8ebf 100644
--- a/pkg/server/conn_stmt.go
+++ b/pkg/server/conn_stmt.go
@@ -233,6 +233,17 @@ func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt any, args [
 	ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
 	ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{})
 	ctx = context.WithValue(ctx, util.RUDetailsCtxKey, util.NewRUDetails())
+
+	fn := func() bool {
+		if cc.bufReadConn != nil {
+			return cc.bufReadConn.IsAlive() != 0
+		}
+		return true
+	}
+	cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(&fn)
+	defer cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(nil)
+
+	//nolint:forcetypeassert
 	retryable, err := cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor)
 	if err != nil {
 		action, txnErr := sessiontxn.GetTxnManager(&cc.ctx).OnStmtErrorForNextAction(ctx, sessiontxn.StmtErrAfterQuery, err)
diff --git a/pkg/server/internal/util/buffered_read_conn.go b/pkg/server/internal/util/buffered_read_conn.go
index d6bc1f24c9bc4..e425bdb1f5b8c 100644
--- a/pkg/server/internal/util/buffered_read_conn.go
+++ b/pkg/server/internal/util/buffered_read_conn.go
@@ -16,7 +16,10 @@ package util
 
 import (
 	"bufio"
+	"io"
 	"net"
+	"sync"
+	"time"
 )
 
 // DefaultReaderSize is the default size of bufio.Reader.
@@ -26,11 +29,15 @@ const DefaultReaderSize = 16 * 1024
 type BufferedReadConn struct {
 	net.Conn
 	rb *bufio.Reader
+	// mu serializes the IsAlive() function.
+	// It ensures that SetReadDeadline is not called concurrently.
+	mu *sync.Mutex
 }
 
 // NewBufferedReadConn creates a BufferedReadConn.
 func NewBufferedReadConn(conn net.Conn) *BufferedReadConn {
 	return &BufferedReadConn{
+		mu:   &sync.Mutex{},
 		Conn: conn,
 		rb:   bufio.NewReaderSize(conn, DefaultReaderSize),
 	}
@@ -40,3 +47,38 @@ func NewBufferedReadConn(conn net.Conn) *BufferedReadConn {
 func (conn BufferedReadConn) Read(b []byte) (n int, err error) {
 	return conn.rb.Read(b)
 }
+
+// Peek returns the next n bytes from the connection without advancing the reader.
+func (conn BufferedReadConn) Peek(n int) ([]byte, error) {
+	return conn.rb.Peek(n)
+}
+
+// IsAlive reports whether the connection is still alive:
+//	a return value < 0 means unknown,
+//	0 means not alive,
+//	1 means still alive.
+func (conn BufferedReadConn) IsAlive() int {
+	if conn.mu.TryLock() {
+		defer conn.mu.Unlock()
+		err := conn.SetReadDeadline(time.Now().Add(30 * time.Microsecond))
+		if err != nil {
+			return -1
+		}
+		//nolint:errcheck
+		defer conn.SetReadDeadline(time.Time{})
+		// At the TCP level, a successful `Peek` operation doesn't guarantee
+		// the connection remains active. However, in the MySQL protocol,
+		// clients shouldn't send new data while the server is processing SQL.
+		// Therefore, we can safely assume `Peek` won't intercept any data
+		// during this period. Even if `Peek` does capture data, it only means
+		// the liveness check might be inaccurate; this won't impact the
+		// actual connection state or its operations.
+		_, err = conn.Peek(1)
+		if err == io.EOF {
+			return 0
+		} else if ne, ok := err.(net.Error); ok && ne.Timeout() {
+			return 1
+		}
+	}
+	return -1
+}
diff --git a/pkg/server/tests/commontest/tidb_test.go b/pkg/server/tests/commontest/tidb_test.go
index 3e8e28446ec64..c18af84b0ef92 100644
--- a/pkg/server/tests/commontest/tidb_test.go
+++ b/pkg/server/tests/commontest/tidb_test.go
@@ -31,6 +31,7 @@ import (
 	"sync/atomic"
 	"testing"
 	"time"
+	"unsafe"
 
 	"github.com/go-sql-driver/mysql"
 	"github.com/pingcap/errors"
@@ -3399,6 +3400,7 @@ func TestBatchGetTypeForRowExpr(t *testing.T) {
 		ts.CheckRows(t, rows, "a b\nc d")
 	})
 }
+
 func TestAuditPluginInfoForStarting(t *testing.T) {
 	ts := servertestkit.CreateTidbTestSuite(t)
 
@@ -3707,3 +3709,64 @@ func TestAuditPluginRetrying(t *testing.T) {
 		runExplicitTransactionRetry(db, true)
 	})
 }
+
+func TestIssue57531(t *testing.T) {
+	ts := servertestkit.CreateTidbTestSuite(t)
+
+	var rsCnt int
+	for i := range 2 {
+		ts.RunTests(t, nil, func(dbt *testkit.DBTestKit) {
+			var conn *sql.Conn
+			var netConn net.Conn
+			conn, _ = dbt.GetDB().Conn(context.Background())
+
+			// extract the underlying TCP connection via reflection
+			conn.Raw(func(driverConn any) error {
+				v := reflect.ValueOf(driverConn)
+				if v.Kind() == reflect.Ptr {
+					v = v.Elem()
+				}
+				f := v.FieldByName("netConn")
+				if f.IsValid() && f.Type().Implements(reflect.TypeOf((*net.Conn)(nil)).Elem()) {
+					netConn = *(*net.Conn)(unsafe.Pointer(f.UnsafeAddr()))
+				}
+				return nil
+			})
+
+			// run `select sleep(300)`: once as a plain query, once as a prepared statement
+			go func() {
+				if i == 0 {
+					conn.QueryContext(context.Background(), "select sleep(300)")
+				} else {
+					stmt, err := conn.PrepareContext(context.Background(), "select sleep(?)")
+					require.NoError(t, err)
+					stmt.Exec(300)
+				}
+			}()
+			time.Sleep(200 * time.Millisecond)
+
+			// expect two sessions: this one and the sleeping query
+			rsCnt = 0
+			rs := dbt.MustQuery("show processlist")
+			for rs.Next() {
+				rsCnt++
+			}
+			require.Equal(t, 2, rsCnt)
+
+			// close the TCP connection from the client side
+			netConn.Close()
+		})
+
+		time.Sleep(10 * time.Millisecond)
+
+		// the `select sleep(300)` session should have been killed
+		ts.RunTests(t, nil, func(dbt *testkit.DBTestKit) {
+			rsCnt = 0
+			rs := dbt.MustQuery("show processlist")
+			for rs.Next() {
+				rsCnt++
+			}
+			require.Equal(t, 1, rsCnt)
+		})
+	}
+}
diff --git a/pkg/util/deeptest/statictesthelper.go b/pkg/util/deeptest/statictesthelper.go
index b9493444037ae..5b4252dbdb3e2 100644
--- a/pkg/util/deeptest/statictesthelper.go
+++ b/pkg/util/deeptest/statictesthelper.go
@@ -142,7 +142,7 @@ func (h *staticTestHelper) assertDeepClonedEqual(t require.TestingT, valA, valB
 		for i := range valA.NumField() {
 			h.assertDeepClonedEqual(t, valA.Field(i), valB.Field(i), path+"."+valA.Type().Field(i).Name)
 		}
-	case reflect.Ptr:
+	case reflect.Ptr, reflect.UnsafePointer:
 		if valA.IsNil() && valB.IsNil() {
 			return
 		}
diff --git a/pkg/util/sqlkiller/BUILD.bazel b/pkg/util/sqlkiller/BUILD.bazel
index 5a4eacf70afd6..5a3a76d5a558e 100644
--- a/pkg/util/sqlkiller/BUILD.bazel
+++ b/pkg/util/sqlkiller/BUILD.bazel
@@ -7,6 +7,7 @@ go_library(
     visibility = ["//visibility:public"],
     deps = [
         "//pkg/util/dbterror/exeerrors",
+        "//pkg/util/intest",
         "//pkg/util/logutil",
         "@com_github_pingcap_failpoint//:failpoint",
         "@org_uber_go_zap//:zap",
diff --git a/pkg/util/sqlkiller/sqlkiller.go b/pkg/util/sqlkiller/sqlkiller.go
index da81f99ee535d..c674a1c2f02f9 100644
--- a/pkg/util/sqlkiller/sqlkiller.go
+++ b/pkg/util/sqlkiller/sqlkiller.go
@@ -18,9 +18,11 @@ import (
 	"math/rand"
 	"sync"
 	"sync/atomic"
+	"time"
 
 	"github.com/pingcap/failpoint"
 	"github.com/pingcap/tidb/pkg/util/dbterror/exeerrors"
+	"github.com/pingcap/tidb/pkg/util/intest"
 	"github.com/pingcap/tidb/pkg/util/logutil"
 	"go.uber.org/zap"
 )
@@ -51,6 +53,9 @@ type SQLKiller struct {
 	// InWriteResultSet is used to indicate whether the query is currently calling clientConn.writeResultSet().
 	// If the query is in writeResultSet and Finish() can acquire rs.finishLock, we can assume the query is waiting for the client to receive data from the server over network I/O.
 	InWriteResultSet atomic.Bool
+
+	lastCheckTime     atomic.Pointer[time.Time]
+	IsConnectionAlive atomic.Pointer[func() bool]
 }
 
 // SendKillSignal sends a kill signal to the query.
@@ -122,6 +127,27 @@ func (killer *SQLKiller) HandleSignal() error {
 			}
 		}
 	})
+
+	// Check whether the client connection is still alive.
+	// To limit overhead, probe at most once per `checkConnectionAliveDur` (1 second).
+	fn := killer.IsConnectionAlive.Load()
+	lastCheckTime := killer.lastCheckTime.Load()
+	if fn != nil {
+		checkConnectionAliveDur := time.Second
+		now := time.Now()
+		if intest.InTest {
+			checkConnectionAliveDur = time.Millisecond
+		}
+		if lastCheckTime == nil {
+			killer.lastCheckTime.Store(&now)
+		} else if now.Sub(*lastCheckTime) > checkConnectionAliveDur {
+			killer.lastCheckTime.Store(&now)
+			if !(*fn)() {
+				atomic.CompareAndSwapUint32(&killer.Signal, 0, QueryInterrupted)
+			}
+		}
+	}
+
 	status := atomic.LoadUint32(&killer.Signal)
 	err := killer.getKillError(status)
 	if status == ServerMemoryExceeded {
@@ -137,4 +163,5 @@ func (killer *SQLKiller) Reset() {
 		logutil.BgLogger().Warn("kill finished", zap.Uint64("conn", killer.ConnID.Load()))
 	}
 	atomic.StoreUint32(&killer.Signal, 0)
+	killer.lastCheckTime.Store(nil)
 }
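
The whole fix hinges on one probe: arm a very short read deadline, `Peek` one byte, and classify the error. `io.EOF` means the peer closed the socket; a timeout means the socket is idle but open; anything else is treated as unknown. The standalone sketch below is not part of the patch (`probe` and the `net.Pipe` demo are illustrative names); it replays the same technique over an in-memory pipe, using a 1ms deadline instead of the patch's 30µs so it is less sensitive to scheduling:

package main

import (
	"bufio"
	"fmt"
	"io"
	"net"
	"time"
)

// probe mirrors the shape of BufferedReadConn.IsAlive: arm a short read
// deadline, peek one byte, classify the error.
// Return values: 1 = alive, 0 = closed by the peer, -1 = unknown.
func probe(rb *bufio.Reader, c net.Conn) int {
	if err := c.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
		return -1
	}
	// Always clear the deadline so later reads are unaffected.
	//nolint:errcheck
	defer c.SetReadDeadline(time.Time{})

	_, err := rb.Peek(1)
	if err == io.EOF {
		return 0 // the peer closed its end of the connection
	}
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		return 1 // nothing readable and no EOF: the peer is still there
	}
	return -1 // data arrived (err == nil) or an unexpected error: unknown
}

func main() {
	// net.Pipe gives an in-memory duplex connection that supports
	// deadlines (since Go 1.10), so the sketch runs without real sockets.
	client, server := net.Pipe()
	rb := bufio.NewReader(server)

	fmt.Println("while open:", probe(rb, server)) // 1: deadline fires, so alive

	client.Close()
	fmt.Println("after close:", probe(rb, server)) // 0: Peek sees io.EOF, so dead
}

The `TryLock` in the real `IsAlive` serves the purpose stated in its field comment: it keeps two goroutines from calling `SetReadDeadline` on the same connection at once, and skipping the probe when one is already in flight is cheaper than queueing behind it.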
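
The other half of the change is rate limiting on the SQLKiller side: `HandleSignal` sits on a hot execution path, so the probe runs at most once per `checkConnectionAliveDur`, and a dead connection is reported by CAS-ing the signal from 0 to `QueryInterrupted` so an already-pending kill reason is not overwritten. A compact sketch of that shape, with illustrative names (`watchdog`, `check`, and `queryInterrupted` are not the patch's identifiers):

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

const queryInterrupted uint32 = 1 // stand-in for sqlkiller.QueryInterrupted

type watchdog struct {
	signal        uint32
	lastCheckTime atomic.Pointer[time.Time]
	isAlive       atomic.Pointer[func() bool]
	interval      time.Duration
}

// check is called on every HandleSignal-like tick; the liveness probe
// itself is throttled to at most once per interval.
func (w *watchdog) check() {
	fn := w.isAlive.Load()
	if fn == nil {
		return // no connection registered
	}
	now := time.Now()
	last := w.lastCheckTime.Load()
	if last == nil {
		w.lastCheckTime.Store(&now)
		return
	}
	if now.Sub(*last) <= w.interval {
		return // probed too recently; skip
	}
	w.lastCheckTime.Store(&now)
	if !(*fn)() {
		// CAS keeps an already-pending kill reason from being clobbered.
		atomic.CompareAndSwapUint32(&w.signal, 0, queryInterrupted)
	}
}

func main() {
	dead := func() bool { return false } // pretend the client has vanished
	w := &watchdog{interval: 10 * time.Millisecond}
	w.isAlive.Store(&dead)

	for range 5 {
		w.check()
		time.Sleep(6 * time.Millisecond)
	}
	fmt.Println("signal:", atomic.LoadUint32(&w.signal)) // 1 once the interval elapsed
}

Storing the probe as `atomic.Pointer[func() bool]` lets the connection layer register and unregister it without taking a lock on the query path, which is why the patch pairs every `Store(&fn)` with a deferred `Store(nil)`.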