diff --git a/pkg/dxf/framework/storage/BUILD.bazel b/pkg/dxf/framework/storage/BUILD.bazel
index a7d93f07dfec7..66218368307a5 100644
--- a/pkg/dxf/framework/storage/BUILD.bazel
+++ b/pkg/dxf/framework/storage/BUILD.bazel
@@ -49,7 +49,7 @@ go_test(
     ],
     embed = [":storage"],
     flaky = True,
-    shard_count = 28,
+    shard_count = 29,
     deps = [
         "//pkg/config",
         "//pkg/dxf/framework/proto",
diff --git a/pkg/dxf/framework/storage/task_state_test.go b/pkg/dxf/framework/storage/task_state_test.go
index bb8e5aa40fa68..f51724a05a787 100644
--- a/pkg/dxf/framework/storage/task_state_test.go
+++ b/pkg/dxf/framework/storage/task_state_test.go
@@ -21,12 +21,14 @@ import (
 	"slices"
 	"sync/atomic"
 	"testing"
+	"time"
 
 	"github.com/pingcap/tidb/pkg/dxf/framework/proto"
 	"github.com/pingcap/tidb/pkg/dxf/framework/storage"
 	"github.com/pingcap/tidb/pkg/dxf/framework/testutil"
 	"github.com/pingcap/tidb/pkg/kv"
 	"github.com/pingcap/tidb/pkg/sessionctx"
+	"github.com/pingcap/tidb/pkg/testkit"
 	"github.com/pingcap/tidb/pkg/testkit/testfailpoint"
 	tidbutil "github.com/pingcap/tidb/pkg/util"
 	"github.com/pingcap/tidb/pkg/util/sqlexec"
@@ -135,6 +137,33 @@ func TestTaskState(t *testing.T) {
 	checkTaskStateStep(t, task, proto.TaskStateSucceed, proto.StepDone)
 }
 
+func TestWithNewTxnRollbackOnCanceledCtx(t *testing.T) {
+	_, _ = testkit.CreateMockStoreAndDomain(t)
+	gm, err := storage.GetTaskManager()
+	require.NoError(t, err)
+
+	ctx, cancel := context.WithCancel(util.WithInternalSourceType(context.Background(), kv.InternalDistTask))
+	require.NotPanics(t, func() {
+		err := gm.WithNewTxn(ctx, func(se sessionctx.Context) error {
+			timer := time.AfterFunc(100*time.Millisecond, cancel)
+			defer timer.Stop()
+
+			_, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), "select sleep(10)")
+			if err != nil {
+				return err
+			}
+			return ctx.Err()
+		})
+		require.ErrorIs(t, err, context.Canceled)
+	})
+
+	verifyCtx := util.WithInternalSourceType(context.Background(), kv.InternalDistTask)
+	require.NoError(t, gm.WithNewTxn(verifyCtx, func(se sessionctx.Context) error {
+		_, err := sqlexec.ExecSQL(verifyCtx, se.GetSQLExecutor(), "select 1")
+		return err
+	}))
+}
+
 func TestUpdateTaskExtraParams(t *testing.T) {
 	_, gm, ctx := testutil.InitTableTest(t)
 	require.NoError(t, gm.InitMeta(ctx, ":4000", ""))
diff --git a/pkg/dxf/framework/storage/task_table.go b/pkg/dxf/framework/storage/task_table.go
index a385189a37026..ffac547daa08e 100644
--- a/pkg/dxf/framework/storage/task_table.go
+++ b/pkg/dxf/framework/storage/task_table.go
@@ -213,6 +213,9 @@ func (mgr *TaskManager) WithNewSession(fn func(se sessionctx.Context) error) err
 func (mgr *TaskManager) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error {
 	ctx = clitutil.WithInternalSourceType(ctx, kv.InternalDistTask)
 	return mgr.WithNewSession(func(se sessionctx.Context) (err error) {
+		// Keep BEGIN on the SQL path so the session enters transaction mode with the usual statement semantics.
+		// Commit/rollback use session methods instead: cleanup must still run after the caller cancels the context,
+		// and issuing SQL text there can leave the pooled internal session with a live txn.
 		_, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), "begin")
 		if err != nil {
 			return err
@@ -220,14 +223,15 @@ func (mgr *TaskManager) WithNewTxn(ctx context.Context, fn func(se sessionctx.Co
 
 		success := false
 		defer func() {
-			sql := "rollback"
 			if success {
-				sql = "commit"
-			}
-			_, commitErr := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), sql)
-			if err == nil && commitErr != nil {
-				err = commitErr
+				commitErr := se.CommitTxn(ctx)
+				if err == nil && commitErr != nil {
+					err = commitErr
+				}
+				return
 			}
+
+			se.RollbackTxn(clitutil.WithInternalSourceType(context.Background(), kv.InternalDistTask))
 		}()
 
 		if err = fn(se); err != nil {
diff --git a/pkg/executor/select.go b/pkg/executor/select.go
index 2b5de8c9901df..c58538b181ad1 100644
--- a/pkg/executor/select.go
+++ b/pkg/executor/select.go
@@ -1003,6 +1003,9 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
 	vars.MemTracker.SessionID.Store(vars.ConnectionID)
 	vars.MemTracker.Killer = &vars.SQLKiller
 	vars.DiskTracker.Killer = &vars.SQLKiller
+	if vars.InRestrictedSQL && vars.InternalSQLScanUserTable {
+		failpoint.InjectCall("beforeResetSQLKillerForTTLScan", s)
+	}
 	vars.SQLKiller.Reset()
 	vars.SQLKiller.ConnID.Store(vars.ConnectionID)
 	vars.ResetRelevantOptVarsAndFixes(false)
diff --git a/pkg/executor/test/executor/executor_test.go b/pkg/executor/test/executor/executor_test.go
index 0ab155a4f01f6..03ca2edb61a82 100644
--- a/pkg/executor/test/executor/executor_test.go
+++ b/pkg/executor/test/executor/executor_test.go
@@ -2589,7 +2589,7 @@ func TestQueryWithKill(t *testing.T) {
 			}
 		}
 		if err != nil {
-			require.Equal(t, context.Canceled, err)
+			require.ErrorIs(t, err, context.Canceled)
 		}
 		if rs != nil {
 			rs.Close()
diff --git a/pkg/session/session.go b/pkg/session/session.go
index 9e4456de0d331..32e11a49e4dcc 100644
--- a/pkg/session/session.go
+++ b/pkg/session/session.go
@@ -2441,6 +2441,10 @@ func (s *session) executeStmtImpl(ctx context.Context, stmtNode ast.StmtNode) (s
 	if err := executor.ResetContextOfStmt(s, stmtNode); err != nil {
 		return nil, err
 	}
+	// ResetContextOfStmt resets the SQLKiller, so check for a canceled caller context before executing the next statement.
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
 	ruv2Metrics := execdetails.RUV2MetricsFromContext(ctx)
 	if ruv2Metrics == nil {
 		ruv2Metrics = execdetails.NewRUV2Metrics()
diff --git a/pkg/ttl/ttlworker/scan.go b/pkg/ttl/ttlworker/scan.go
index 86e74014c8dc9..10cf5384b9acc 100644
--- a/pkg/ttl/ttlworker/scan.go
+++ b/pkg/ttl/ttlworker/scan.go
@@ -137,13 +137,14 @@ func (t *ttlScanTask) doScan(ctx context.Context, delCh chan<- *ttlDeleteTask, s
 }
 
 func (t *ttlScanTask) doScanWithSession(ctx context.Context, delCh chan<- *ttlDeleteTask, rawSess session.Session) error {
-	// TODO: merge the ctx and the taskCtx in ttl scan task, to allow both "cancel" and gracefully stop workers
-	// now, the taskCtx is only check at the beginning of every loop
 	taskCtx := t.ctx
 	tracer := metrics.PhaseTracerFromCtx(ctx)
 	defer tracer.EnterPhase(tracer.Phase())
 	tracer.EnterPhase(metrics.PhaseOther)
 
+	// scanCtx is canceled when either the worker context or the TTL task is canceled, so running SQL stops in both cases.
+	scanCtx, cancelScanCtx := context.WithCancel(ctx)
+	defer cancelScanCtx()
 	doScanFinished, setDoScanFinished := context.WithCancel(context.Background())
 	wg := util.WaitGroupWrapper{}
 	wg.Run(func() {
@@ -153,6 +154,7 @@ func (t *ttlScanTask) doScanWithSession(ctx context.Context, delCh chan<- *ttlDe
 		case <-doScanFinished.Done():
 			return
 		}
+		cancelScanCtx()
 		logger := t.taskLogger(logutil.BgLogger())
 		logger.Info("kill the running statement in scan task because the task or worker cancelled")
 		rawSess.KillStmt()
@@ -201,7 +203,7 @@ func (t *ttlScanTask) doScanWithSession(ctx context.Context, delCh chan<- *ttlDe
 		)
 	}
 
-	sess, restoreSession, err := NewScanSession(rawSess, t.tbl, t.ExpireTime)
+	sess, restoreSession, err := NewScanSession(scanCtx, rawSess, t.tbl, t.ExpireTime)
 	if err != nil {
 		return err
 	}
@@ -242,11 +244,11 @@ func (t *ttlScanTask) doScanWithSession(ctx context.Context, delCh chan<- *ttlDe
 		}
 
 		sqlStart := time.Now()
-		rows, retryable, sqlErr := sess.ExecuteSQLWithCheck(ctx, sql)
+		rows, retryable, sqlErr := sess.ExecuteSQLWithCheck(scanCtx, sql)
 		selectInterval := time.Since(sqlStart)
 		if sqlErr != nil {
 			metrics.SelectErrorDuration.Observe(selectInterval.Seconds())
-			needRetry := retryable && retryTimes < scanTaskExecuteSQLMaxRetry && ctx.Err() == nil && t.ctx.Err() == nil
+			needRetry := retryable && retryTimes < scanTaskExecuteSQLMaxRetry && scanCtx.Err() == nil
 			logutil.BgLogger().Warn("execute query for ttl scan task failed",
 				zap.String("SQL", sql),
 				zap.Int("retryTimes", retryTimes),
@@ -262,8 +264,8 @@ func (t *ttlScanTask) doScanWithSession(ctx context.Context, delCh chan<- *ttlDe
 
 			tracer.EnterPhase(metrics.PhaseWaitRetry)
 			select {
-			case <-ctx.Done():
-				return ctx.Err()
+			case <-scanCtx.Done():
+				return scanCtx.Err()
 			case <-time.After(scanTaskExecuteSQLRetryInterval):
 			}
 			tracer.EnterPhase(metrics.PhaseOther)
@@ -289,8 +291,8 @@ func (t *ttlScanTask) doScanWithSession(ctx context.Context, delCh chan<- *ttlDe
 
 		tracer.EnterPhase(metrics.PhaseDispatch)
 		select {
-		case <-ctx.Done():
-			return ctx.Err()
+		case <-scanCtx.Done():
+			return scanCtx.Err()
 		case delCh <- delTask:
 			t.statistics.IncTotalRows(len(lastResult))
 		}
diff --git a/pkg/ttl/ttlworker/scan_integration_test.go b/pkg/ttl/ttlworker/scan_integration_test.go
index 40229243313b1..6bdbb2bcdde7b 100644
--- a/pkg/ttl/ttlworker/scan_integration_test.go
+++ b/pkg/ttl/ttlworker/scan_integration_test.go
@@ -23,7 +23,9 @@ import (
 	"time"
 
 	"github.com/pingcap/tidb/pkg/parser/ast"
+	"github.com/pingcap/tidb/pkg/sessionctx/vardef"
 	"github.com/pingcap/tidb/pkg/testkit"
+	"github.com/pingcap/tidb/pkg/testkit/testfailpoint"
 	"github.com/pingcap/tidb/pkg/testkit/testflag"
 	"github.com/pingcap/tidb/pkg/ttl/cache"
 	"github.com/pingcap/tidb/pkg/ttl/ttlworker"
@@ -97,3 +99,87 @@ func TestCancelWhileScan(t *testing.T) {
 	close(delCh)
 	wg.Wait()
 }
+
+func TestCancelWhileScanAtStatementBoundary(t *testing.T) {
+	store, dom := testkit.CreateMockStoreAndDomain(t)
+	tk := testkit.NewTestKit(t, store)
+
+	origBatchSize := vardef.TTLScanBatchSize.Load()
+	vardef.TTLScanBatchSize.Store(30)
+	t.Cleanup(func() {
+		vardef.TTLScanBatchSize.Store(origBatchSize)
+	})
+
+	tk.MustExec("create table test.t (id int primary key, created_at datetime) TTL= created_at + interval 1 hour")
+	tk.MustExec("split table test.t between (0) and (30000) regions 30")
+	for i := range 30 {
+		tk.MustExec(fmt.Sprintf("insert into test.t values (%d, NOW() - INTERVAL 24 HOUR)", i*1000))
+	}
+	testTable, err := dom.InfoSchema().TableByName(context.Background(), ast.NewCIStr("test"), ast.NewCIStr("t"))
+	require.NoError(t, err)
+	testPhysicalTableCache, err := cache.NewPhysicalTable(ast.NewCIStr("test"), testTable.Meta(), ast.NewCIStr(""))
+	require.NoError(t, err)
+
+	testfailpoint.Enable(t, "github.com/pingcap/tidb/pkg/store/copr/sleepCoprRequest", "return(2000)")
+
+	taskCtx, cancelTask := context.WithCancel(context.Background())
+	defer cancelTask()
+	ttlTask := ttlworker.NewTTLScanTask(taskCtx, testPhysicalTableCache, &cache.TTLTask{
+		JobID:            "test",
+		TableID:          testTable.Meta().ID,
+		ScanID:           1,
+		ScanRangeStart:   nil,
+		ScanRangeEnd:     nil,
+		ExpireTime:       time.Now().Add(-12 * time.Hour),
+		OwnerID:          "test",
+		OwnerAddr:        "test",
+		OwnerHBTime:      time.Now(),
+		Status:           cache.TaskStatusRunning,
+		StatusUpdateTime: time.Now(),
+		State:            &cache.TTLTaskState{},
+		CreatedTime:      time.Now(),
+	})
+
+	triggerCancel := make(chan struct{})
+	var cancelOnce sync.Once
+	testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/executor/beforeResetSQLKillerForTTLScan", func(stmt ast.StmtNode) {
+		if _, ok := stmt.(*ast.SelectStmt); !ok {
+			return
+		}
+
+		cancelOnce.Do(func() {
+			cancelTask()
+			close(triggerCancel)
+			time.Sleep(100 * time.Millisecond)
+		})
+	})
+
+	delCh := make(chan *ttlworker.TTLDeleteTask)
+	doneCh := make(chan struct{})
+	go func() {
+		defer close(doneCh)
+		for range delCh {
+		}
+	}()
+
+	doScanDone := make(chan struct{})
+	go func() {
+		defer close(doScanDone)
+		ttlTask.DoScan(context.Background(), delCh, dom.AdvancedSysSessionPool())
+	}()
+
+	select {
+	case <-triggerCancel:
+	case <-time.After(10 * time.Second):
+		require.FailNow(t, "TTL scan SELECT was not reached")
+	}
+
+	select {
+	case <-doScanDone:
+	case <-time.After(time.Second):
+		require.FailNow(t, "TTL scan was not canceled within 1s after statement-boundary cancel")
+	}
+
+	close(delCh)
+	<-doneCh
+}
diff --git a/pkg/ttl/ttlworker/session.go b/pkg/ttl/ttlworker/session.go
index d29b68cadeba9..5715bba88772a 100644
--- a/pkg/ttl/ttlworker/session.go
+++ b/pkg/ttl/ttlworker/session.go
@@ -196,7 +196,7 @@ func newTableSession(se session.Session, tbl *cache.PhysicalTable, expire time.T
 }
 
 // NewScanSession creates a session for scan
-func NewScanSession(se session.Session, tbl *cache.PhysicalTable, expire time.Time) (*ttlTableSession, func() error, error) {
+func NewScanSession(ctx context.Context, se session.Session, tbl *cache.PhysicalTable, expire time.Time) (*ttlTableSession, func() error, error) {
 	origConcurrency := se.GetSessionVars().DistSQLScanConcurrency()
 	origPaging := se.GetSessionVars().EnablePaging
 	se.GetSessionVars().InternalSQLScanUserTable = true
@@ -218,7 +218,7 @@ func NewScanSession(se session.Session, tbl *cache.PhysicalTable, expire time.Ti
 	}
 
 	// Set the distsql scan concurrency to 1 to reduce the number of cop tasks in TTL scan.
-	if _, err := se.ExecuteSQL(context.Background(), "set @@tidb_distsql_scan_concurrency=1"); err != nil {
+	if _, err := se.ExecuteSQL(ctx, "set @@tidb_distsql_scan_concurrency=1"); err != nil {
 		terror.Log(restore())
 		return nil, nil, err
 	}
@@ -227,7 +227,7 @@ func NewScanSession(se session.Session, tbl *cache.PhysicalTable, expire time.Ti
 	// If `tidb_enable_paging` is enabled, it may have multiple cop tasks even in one region that makes some extra
 	// processed keys in TiKV side, see issue: https://github.com/pingcap/tidb/issues/58342.
 	// Disable it to make the scan more efficient.
-	if _, err := se.ExecuteSQL(context.Background(), "set @@tidb_enable_paging=OFF"); err != nil {
+	if _, err := se.ExecuteSQL(ctx, "set @@tidb_enable_paging=OFF"); err != nil {
 		terror.Log(restore())
 		return nil, nil, err
 	}
diff --git a/pkg/ttl/ttlworker/session_integration_test.go b/pkg/ttl/ttlworker/session_integration_test.go
index 5aed8fcbdbd67..fc88d56f38abf 100644
--- a/pkg/ttl/ttlworker/session_integration_test.go
+++ b/pkg/ttl/ttlworker/session_integration_test.go
@@ -380,7 +380,7 @@ func TestNewScanSession(t *testing.T) {
 		called := false
 		require.NoError(t, ttlworker.WithSessionForTest(pool, func(se session.Session) error {
 			require.False(t, called)
-			tblSe, restore, err := ttlworker.NewScanSession(se, &cache.PhysicalTable{}, time.Now())
+			tblSe, restore, err := ttlworker.NewScanSession(context.Background(), se, &cache.PhysicalTable{}, time.Now())
 			called = true
 			if errSQL == "" {
 				// success case
@@ -424,7 +424,7 @@ func TestNewScanSession(t *testing.T) {
 		}, newFaultAfterCount(0)))
 		require.NoError(t, ttlworker.WithSessionForTest(pool, func(se session.Session) error {
 			require.False(t, called)
-			tblSe, restore, err := ttlworker.NewScanSession(se, &cache.PhysicalTable{}, time.Now())
+			tblSe, restore, err := ttlworker.NewScanSession(context.Background(), se, &cache.PhysicalTable{}, time.Now())
 			called = true
 			require.NoError(t, err)
 			require.NotNil(t, tblSe)
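
For reference, a minimal Go sketch (not TiDB code; withTxn and its begin/commit/rollback callbacks are hypothetical names) of the cleanup pattern the WithNewTxn change above adopts: commit stays on the caller's context, while rollback runs on a background-derived context so a caller that has already canceled ctx cannot leave the pooled session with a live transaction.

package txnsketch

import "context"

// withTxn begins a transaction, runs fn, and guarantees the transaction is closed even when
// the caller's context is canceled while fn is still running.
func withTxn(
	ctx context.Context,
	begin func(context.Context) error,
	commit func(context.Context) error,
	rollback func(context.Context),
	fn func(context.Context) error,
) (err error) {
	if err = begin(ctx); err != nil {
		return err
	}
	success := false
	defer func() {
		if success {
			// Commit on the caller's context; surface the commit error only if fn itself succeeded.
			if commitErr := commit(ctx); err == nil && commitErr != nil {
				err = commitErr
			}
			return
		}
		// ctx may already be canceled here, so roll back on a fresh context to make sure
		// the cleanup still runs and the session is not left inside a transaction.
		rollback(context.Background())
	}()
	if err = fn(ctx); err != nil {
		return err
	}
	success = true
	return nil
}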