Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pkg/server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2076,9 +2076,17 @@ func (cc *clientConn) handleStmt(
//nolint: errcheck
rs.Finish()
})
fn := func() bool {
if cc.bufReadConn != nil {
return cc.bufReadConn.IsAlive() != 0
}
return true
}
cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(&fn)
cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(true)
defer cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(false)
defer cc.ctx.GetSessionVars().SQLKiller.ClearFinishFunc()
defer cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(nil)
if retryable, err := cc.writeResultSet(ctx, rs, false, status, 0); err != nil {
return retryable, err
}
Expand Down
10 changes: 10 additions & 0 deletions pkg/server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,16 @@ func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt any, args [
ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{})
ctx = context.WithValue(ctx, util.RUDetailsCtxKey, util.NewRUDetails())

fn := func() bool {
if cc.bufReadConn != nil {
return cc.bufReadConn.IsAlive() != 0
}
return true
}
cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(&fn)
defer cc.ctx.GetSessionVars().SQLKiller.IsConnectionAlive.Store(nil)

//nolint:forcetypeassert
retryable, err := cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor)
if err != nil {
Expand Down
42 changes: 42 additions & 0 deletions pkg/server/internal/util/buffered_read_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ package util

import (
	"bufio"
	"errors"
	"io"
	"net"
	"sync"
	"time"
)

// DefaultReaderSize is the default size of bufio.Reader.
Expand All @@ -26,11 +29,15 @@ const DefaultReaderSize = 16 * 1024
type BufferedReadConn struct {
net.Conn
rb *bufio.Reader
// `mu` is for `IsAlive()` function.
// We use this to ensure that `SetReadDeadline` is not called concurrently.
mu *sync.Mutex
Comment thread
xhebox marked this conversation as resolved.
}

// NewBufferedReadConn creates a BufferedReadConn.
func NewBufferedReadConn(conn net.Conn) *BufferedReadConn {
return &BufferedReadConn{
mu: &sync.Mutex{},
Conn: conn,
rb: bufio.NewReaderSize(conn, DefaultReaderSize),
}
Expand All @@ -40,3 +47,38 @@ func NewBufferedReadConn(conn net.Conn) *BufferedReadConn {
// Read fills b from the buffered reader wrapping the underlying connection.
func (conn BufferedReadConn) Read(b []byte) (n int, err error) {
	n, err = conn.rb.Read(b)
	return n, err
}

// Peek returns the next n buffered bytes from the connection without
// advancing the reader.
func (conn BufferedReadConn) Peek(n int) ([]byte, error) {
	data, err := conn.rb.Peek(n)
	return data, err
}

// IsAlive detects the connection is alive or not.
// return value < 0, means unknow
Comment thread
xhebox marked this conversation as resolved.
// return value = 0, means not alive
// return value = 1, means still alive
func (conn BufferedReadConn) IsAlive() int {
if conn.mu.TryLock() {
defer conn.mu.Unlock()
err := conn.SetReadDeadline(time.Now().Add(30 * time.Microsecond))
if err != nil {
return -1
}
// nolint:errcheck
defer conn.SetReadDeadline(time.Time{})
// At the TCP level, a successful `Peek` operation doesn't guarantee
// the connection remains active. However, in the MySQL protocol,
// clients shouldn't send new data while the server is processing SQL.
// Therefore, we can safely assume `Peek` won't intercept any data
// during this period. Even if `Peek` does capture data, it only means
// the liveness check might be inaccurate - this won't impact the
// actual connection state or its operations.
_, err = conn.Peek(1)
if err == io.EOF {
return 0
} else if ne, ok := err.(net.Error); ok && ne.Timeout() {
return 1
}
}
return -1
}
62 changes: 62 additions & 0 deletions pkg/server/tests/commontest/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"sync/atomic"
"testing"
"time"
"unsafe"

"github.com/go-sql-driver/mysql"
"github.com/pingcap/errors"
Expand Down Expand Up @@ -3397,3 +3398,64 @@ func TestBatchGetTypeForRowExpr(t *testing.T) {
ts.CheckRows(t, rows, "a b\nc d")
})
}

// TestIssue57531 verifies that a long-running query is killed when the
// client's TCP connection is closed, for both the text protocol (plain
// query, i == 0) and the binary protocol (prepared statement, i == 1).
func TestIssue57531(t *testing.T) {
	ts := servertestkit.CreateTidbTestSuite(t)

	var rsCnt int
	for i := range 2 {
		ts.RunTests(t, nil, func(dbt *testkit.DBTestKit) {
			var conn *sql.Conn
			var netConn net.Conn
			conn, _ = dbt.GetDB().Conn(context.Background())

			// Extract the underlying net.Conn from the driver connection via
			// reflection so the raw TCP connection can be closed later.
			conn.Raw(func(driverConn any) error {
				v := reflect.ValueOf(driverConn)
				if v.Kind() == reflect.Ptr {
					v = v.Elem()
				}
				f := v.FieldByName("netConn")
				if f.IsValid() && f.Type().Implements(reflect.TypeOf((*net.Conn)(nil)).Elem()) {
					netConn = *(*net.Conn)(unsafe.Pointer(f.UnsafeAddr()))
				}
				return nil
			})

			// Execute `select sleep(300)` in the background. Errors from the
			// query itself are expected once the connection is killed, so the
			// results are deliberately ignored.
			go func() {
				if i == 0 {
					//nolint:errcheck
					conn.QueryContext(context.Background(), "select sleep(300)")
				} else {
					stmt, err := conn.PrepareContext(context.Background(), "select sleep(?)")
					require.NoError(t, err)
					//nolint:errcheck
					stmt.Exec(300)
				}
			}()
			time.Sleep(200 * time.Millisecond)

			// Both the control session and the sleeping session must be visible.
			rsCnt = 0
			rs := dbt.MustQuery("show processlist")
			for rs.Next() {
				rsCnt++
			}
			require.NoError(t, rs.Close())
			require.Equal(t, 2, rsCnt)

			// Close the raw TCP connection to simulate a client disconnect.
			netConn.Close()
		})

		time.Sleep(10 * time.Millisecond)

		// The `select sleep(300)` should now be killed, leaving only the
		// fresh control session.
		ts.RunTests(t, nil, func(dbt *testkit.DBTestKit) {
			rsCnt = 0
			rs := dbt.MustQuery("show processlist")
			for rs.Next() {
				rsCnt++
			}
			require.NoError(t, rs.Close())
			require.Equal(t, 1, rsCnt)
		})
	}
}
2 changes: 1 addition & 1 deletion pkg/util/deeptest/statictesthelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (h *staticTestHelper) assertDeepClonedEqual(t require.TestingT, valA, valB
for i := range valA.NumField() {
h.assertDeepClonedEqual(t, valA.Field(i), valB.Field(i), path+"."+valA.Type().Field(i).Name)
}
case reflect.Ptr:
case reflect.Ptr, reflect.UnsafePointer:
if valA.IsNil() && valB.IsNil() {
return
}
Expand Down
1 change: 1 addition & 0 deletions pkg/util/sqlkiller/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/util/dbterror/exeerrors",
"//pkg/util/intest",
"//pkg/util/logutil",
"@com_github_pingcap_failpoint//:failpoint",
"@org_uber_go_zap//:zap",
Expand Down
27 changes: 27 additions & 0 deletions pkg/util/sqlkiller/sqlkiller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ import (
"math/rand"
"sync"
"sync/atomic"
"time"

"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/util/dbterror/exeerrors"
"github.com/pingcap/tidb/pkg/util/intest"
"github.com/pingcap/tidb/pkg/util/logutil"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -51,6 +53,9 @@ type SQLKiller struct {
// InWriteResultSet is used to indicate whether the query is currently calling clientConn.writeResultSet().
// If the query is in writeResultSet and Finish() can acquire rs.finishLock, we can assume the query is waiting for the client to receive data from the server over network I/O.
InWriteResultSet atomic.Bool

lastCheckTime atomic.Pointer[time.Time]
IsConnectionAlive atomic.Pointer[func() bool]
}

// SendKillSignal sends a kill signal to the query.
Expand Down Expand Up @@ -122,6 +127,27 @@ func (killer *SQLKiller) HandleSignal() error {
}
}
})

// Checks if the connection is alive.
// For performance reasons, the check interval should be at least `checkConnectionAliveDur`(1 second).
fn := killer.IsConnectionAlive.Load()
lastCheckTime := killer.lastCheckTime.Load()
if fn != nil {
var checkConnectionAliveDur time.Duration = time.Second
now := time.Now()
if intest.InTest {
checkConnectionAliveDur = time.Millisecond
}
if lastCheckTime == nil {
killer.lastCheckTime.Store(&now)
} else if now.Sub(*lastCheckTime) > checkConnectionAliveDur {
killer.lastCheckTime.Store(&now)
if !(*fn)() {
atomic.CompareAndSwapUint32(&killer.Signal, 0, QueryInterrupted)
}
}
}

status := atomic.LoadUint32(&killer.Signal)
err := killer.getKillError(status)
if status == ServerMemoryExceeded {
Expand All @@ -137,4 +163,5 @@ func (killer *SQLKiller) Reset() {
logutil.BgLogger().Warn("kill finished", zap.Uint64("conn", killer.ConnID.Load()))
}
atomic.StoreUint32(&killer.Signal, 0)
killer.lastCheckTime.Store(nil)
}