diff --git a/README.md b/README.md index 10a70e4a..9cf3136f 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,73 @@ Additionally, you can manage heartbeat logic within the (Codec)OnCron method in If you're using WebSocket, you don't need to worry about heartbeat request/response, as Getty handles this task within session.go's (Session)handleLoop method by sending and receiving WebSocket ping/pong frames. Your responsibility is to check whether the WebSocket session has timed out or not within codec.go's (Codec)OnCron method using session.go's (Session)GetActive method. -For code examples, you can refer to https://github.com/AlexStocks/getty-examples. +For code examples, you can refer to [getty-examples](https://github.com/AlexStocks/getty-examples). + +## Callback System + +Getty provides a robust callback system that allows you to register and manage callback functions for session lifecycle events. This is particularly useful for cleanup operations, resource management, and custom event handling. + +### Key Features + +- **Thread-safe operations**: All callback operations are protected by mutex locks +- **Replace semantics**: Adding with the same (handler, key) replaces the existing callback in place (position preserved) +- **Panic safety**: During session close, callbacks run in a dedicated goroutine with defer/recover; panics are logged with stack traces and do not escape the close path +- **Ordered execution**: Callbacks are executed in the order they were added + +### Usage Example + +```go +// Add a close callback +session.AddCloseCallback("cleanup", "resources", func() { + // Cleanup resources when session closes + cleanupResources() +}) + +// Remove a specific callback +// Safe to call even if the pair was never added (no-op) +session.RemoveCloseCallback("cleanup", "resources") + +// Callbacks are automatically executed when the session closes +``` + +**Note**: During session shutdown, callbacks are executed sequentially in a dedicated goroutine to preserve add-order, with defer/recover to log panics without letting them escape the close path. + +### Callback Management + +- **AddCloseCallback**: Register a callback to be executed when the session closes +- **RemoveCloseCallback**: Remove a previously registered callback (no-op if not found; safe to call multiple times) +- **Thread Safety**: All operations are thread-safe and can be called concurrently + +### Type Requirements + +The `handler` and `key` parameters must be **comparable types** that support the `==` operator: + +**✅ Supported types:** +- **Basic types**: `string`, `int`, `int8`, `int16`, `int32`, `int64`, `uint`, `uint8`, `uint16`, `uint32`, `uint64`, `uintptr`, `float32`, `float64`, `bool`, `complex64`, `complex128` + - ⚠️ Avoid `float*`/`complex*` as keys due to NaN and precision semantics; prefer strings/ints +- **Pointer types**: Pointers to any type (e.g., `*int`, `*string`, `*MyStruct`) +- **Interface types**: Interface types are comparable only when their dynamic values are comparable types; using "==" with non-comparable dynamic values will be safely ignored with error log +- **Channel types**: Channel types (compared by channel identity) +- **Array types**: Arrays of comparable elements (e.g., `[3]int`, `[2]string`) +- **Struct types**: Structs where all fields are comparable types + +**⚠️ Non-comparable types (will be safely ignored with error log):** +- `map` types (e.g., `map[string]int`) +- `slice` types (e.g., `[]int`, `[]string`) +- `func` types (e.g., `func()`, `func(int) string`) +- Structs containing non-comparable fields (maps, slices, functions) + +**Examples:** +```go +// ✅ Valid usage +session.AddCloseCallback("user", "cleanup", callback) +session.AddCloseCallback(123, "cleanup", callback) +session.AddCloseCallback(true, false, callback) + +// ⚠️ Non-comparable types (safely ignored with error log) +session.AddCloseCallback(map[string]int{"a": 1}, "key", callback) // Logged and ignored +session.AddCloseCallback([]int{1, 2, 3}, "key", callback) // Logged and ignored +``` ## About network transmission in getty diff --git a/README_CN.md b/README_CN.md index ccc98d87..c88ab627 100644 --- a/README_CN.md +++ b/README_CN.md @@ -18,7 +18,73 @@ Getty 是一个使用 Golang 开发的异步网络 I/O 库。它适用于 TCP、 如果您使用 WebSocket,您无需担心心跳请求/响应,因为 Getty 在 session.go 的 (Session)handleLoop 方法内通过发送和接收 WebSocket ping/pong 帧来处理此任务。您只需在 codec.go 的 (Codec)OnCron 方法内使用 session.go 的 (Session)GetActive 方法检查 WebSocket 会话是否已超时。 -有关代码示例,请参阅 https://github.com/AlexStocks/getty-examples。 +有关代码示例,请参阅 [AlexStocks/getty-examples](https://github.com/AlexStocks/getty-examples)。 + +## 回调系统 + +Getty 提供了一个强大的回调系统,允许您为会话生命周期事件注册和管理回调函数。这对于清理操作、资源管理和自定义事件处理特别有用。 + +### 主要特性 + +- **线程安全操作**:所有回调操作都受到互斥锁保护 +- **替换语义**:使用相同的 (handler, key) 添加会替换现有回调并保持位置不变 +- **Panic 安全性**:在会话关闭期间,回调在专用 goroutine 中运行,带有 defer/recover;panic 会被记录堆栈跟踪且不会逃逸出关闭路径 +- **有序执行**:回调按照添加的顺序执行 + +### 使用示例 + +```go +// 添加关闭回调 +session.AddCloseCallback("cleanup", "resources", func() { + // 当会话关闭时清理资源 + cleanupResources() +}) + +// 移除特定回调 +// 即使从未添加过该对也可以安全调用(无操作) +session.RemoveCloseCallback("cleanup", "resources") + +// 当会话关闭时,回调会自动执行 +``` + +**注意**:在会话关闭期间,回调在专用 goroutine 中顺序执行以保持添加顺序,带有 defer/recover 来记录 panic 而不让它们逃逸出关闭路径。 + +### 回调管理 + +- **AddCloseCallback**:注册一个在会话关闭时执行的回调 +- **RemoveCloseCallback**:移除之前注册的回调(未找到时无操作;可安全多次调用) +- **线程安全**:所有操作都是线程安全的,可以并发调用 + +### 类型要求 + +`handler` 和 `key` 参数必须是**可比较的类型**,支持 `==` 操作符: + +**✅ 支持的类型:** +- **基本类型**:`string`、`int`、`int8`、`int16`、`int32`、`int64`、`uint`、`uint8`、`uint16`、`uint32`、`uint64`、`uintptr`、`float32`、`float64`、`bool`、`complex64`、`complex128` + - ⚠️ 避免使用 `float*`/`complex*` 作为键,因为 NaN 和精度语义问题;建议使用字符串/整数 +- **指针类型**:指向任何类型的指针(如 `*int`、`*string`、`*MyStruct`) +- **接口类型**:仅当其动态值为可比较类型时可比较;若动态值不可比较,使用"=="将被安全忽略并记录错误日志 +- **通道类型**:通道类型(按通道标识比较) +- **数组类型**:可比较元素的数组(如 `[3]int`、`[2]string`) +- **结构体类型**:所有字段都是可比较类型的结构体 + +**⚠️ 不可比较类型(将被安全忽略并记录错误日志):** +- `map` 类型(如 `map[string]int`) +- `slice` 类型(如 `[]int`、`[]string`) +- `func` 类型(如 `func()`、`func(int) string`) +- 包含不可比较字段的结构体(maps、slices、functions) + +**示例:** +```go +// ✅ 有效用法 +session.AddCloseCallback("user", "cleanup", callback) +session.AddCloseCallback(123, "cleanup", callback) +session.AddCloseCallback(true, false, callback) + +// ⚠️ 不可比较类型(安全忽略并记录错误日志) +session.AddCloseCallback(map[string]int{"a": 1}, "key", callback) // 记录日志并忽略 +session.AddCloseCallback([]int{1, 2, 3}, "key", callback) // 记录日志并忽略 +``` ## 关于 Getty 中的网络传输 diff --git a/transport/callback.go b/transport/callback.go new file mode 100644 index 00000000..4ea94c0d --- /dev/null +++ b/transport/callback.go @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package getty + +import ( + "fmt" + "reflect" +) + +import ( + perrors "github.com/pkg/errors" +) + +import ( + log "github.com/AlexStocks/getty/util" +) + +// callbackNode represents a node in the callback linked list +// Each node contains handler identifier, key, callback function and pointer to next node +type callbackNode struct { + handler any // Handler identifier, used to identify the source or type of callback + key any // Unique identifier key for callback, used in combination with handler + call func() // Actual callback function to be executed + next *callbackNode // Pointer to next node, forming linked list structure +} + +// callbacks is a singly linked list structure for managing multiple callback functions +// Supports dynamic addition, removal and execution of callbacks +type callbacks struct { + first *callbackNode // Pointer to the first node of the linked list + last *callbackNode // Pointer to the last node of the linked list, used for quick addition of new nodes + cbNum int // Number of callback functions in the linked list +} + +// isComparable checks if a value is comparable using Go's == operator +// Returns true if the value can be safely compared, false otherwise +func isComparable(v any) bool { + if v == nil { + return true + } + return reflect.TypeOf(v).Comparable() +} + +// Add adds a new callback function to the callback linked list +// Parameters: +// - handler: Handler identifier, can be any type +// - key: Unique identifier key for callback, used in combination with handler +// - callback: Callback function to be executed, ignored if nil +// +// Note: If a callback with the same handler and key already exists, it will be replaced +func (t *callbacks) Add(handler, key any, callback func()) { + // Prevent adding empty callback function + if callback == nil { + return + } + + // Guard: avoid runtime panic on non-comparable types + if !isComparable(handler) || !isComparable(key) { + log.Error(perrors.New(fmt.Sprintf("callbacks.Add: non-comparable handler/key: %T, %T; ignored", handler, key))) + return + } + + // Check if a callback with the same handler and key already exists + for cb := t.first; cb != nil; cb = cb.next { + if cb.handler == handler && cb.key == key { + // Replace existing callback + cb.call = callback + return + } + } + + // Create new callback node + newItem := &callbackNode{handler, key, callback, nil} + + if t.first == nil { + // If linked list is empty, new node becomes the first node + t.first = newItem + } else { + // Otherwise add new node to the end of linked list + t.last.next = newItem + } + // Update pointer to last node + t.last = newItem + // Increment callback count + t.cbNum++ +} + +// Remove removes the specified callback function from the callback linked list +// Parameters: +// - handler: Handler identifier of the callback to be removed +// - key: Unique identifier key of the callback to be removed +// +// Note: If no matching callback is found, this method has no effect +func (t *callbacks) Remove(handler, key any) { + // Guard: avoid runtime panic on non-comparable types + if !isComparable(handler) || !isComparable(key) { + log.Error(perrors.New(fmt.Sprintf("callbacks.Remove: non-comparable handler/key: %T, %T; ignored", handler, key))) + return + } + + var prev *callbackNode + + // Traverse linked list to find the node to be removed + for callback := t.first; callback != nil; prev, callback = callback, callback.next { + // Found matching node + if callback.handler == handler && callback.key == key { + if t.first == callback { + // If it's the first node, update first pointer + t.first = callback.next + } else if prev != nil { + // If it's a middle node, update the next pointer of the previous node + prev.next = callback.next + } + + if t.last == callback { + // If it's the last node, update last pointer + t.last = prev + } + + // Decrement callback count + t.cbNum-- + + // Return immediately after finding and removing + return + } + } +} + +// Invoke executes all registered callback functions in the linked list +// Executes each callback in the order they were added +// Note: If a callback function is nil, it will be skipped +// If a callback panics, it will be handled by the outer caller's panic recovery +func (t *callbacks) Invoke() { + // Traverse the entire linked list starting from the head node + for callback := t.first; callback != nil; callback = callback.next { + // Ensure callback function is not nil before executing + if callback.call != nil { + callback.call() + } + } +} + +// Len returns the number of callback functions in the linked list +// Return value: Total number of currently registered callback functions +func (t *callbacks) Len() int { + return t.cbNum +} diff --git a/transport/callback_test.go b/transport/callback_test.go new file mode 100644 index 00000000..e5ac3053 --- /dev/null +++ b/transport/callback_test.go @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package getty + +import ( + "testing" +) + +func TestCallback(t *testing.T) { + // Test empty list + cb := &callbacks{} + if cb.Len() != 0 { + t.Errorf("Expected count for empty list is 0, but got %d", cb.Len()) + } + + // Ensure invoking on an empty registry is a no-op (no panic). + cb.Invoke() + + // Test adding callback functions + var count, expected, remove, totalCount int + totalCount = 10 + remove = 5 + + // Add multiple callback functions + for i := 1; i < totalCount; i++ { + expected = expected + i + func(ii int) { + cb.Add(ii, ii, func() { count = count + ii }) + }(i) + } + + // Verify count after adding + expectedCallbacks := totalCount - 1 + if cb.Len() != expectedCallbacks { + t.Errorf("Expected callback count is %d, but got %d", expectedCallbacks, cb.Len()) + } + + // Test adding nil callback + cb.Add(remove, remove, nil) + if cb.Len() != expectedCallbacks { + t.Errorf("Expected count after adding nil callback is %d, but got %d", expectedCallbacks, cb.Len()) + } + + // Replace an existing callback with a non-nil one; count should remain unchanged. + cb.Add(remove, remove, func() { count += remove }) + if cb.Len() != expectedCallbacks { + t.Errorf("Expected count after replacing existing callback is %d, but got %d", expectedCallbacks, cb.Len()) + } + + // Remove specified callback + cb.Remove(remove, remove) + + // Try to remove non-existent callback + cb.Remove(remove+1, remove+2) + + // Execute all callbacks + cb.Invoke() + + // Verify execution result + expectedSum := expected - remove + if count != expectedSum { + t.Errorf("Expected execution result is %d, but got %d", expectedSum, count) + } + + // Test string type handler and key + cb2 := &callbacks{} + + // Add callbacks + cb2.Add("handler1", "key1", func() {}) + cb2.Add("handler2", "key2", func() {}) + cb2.Add("handler3", "key3", func() {}) + + if cb2.Len() != 3 { + t.Errorf("Expected callback count is 3, but got %d", cb2.Len()) + } + + // Remove middle callback + cb2.Remove("handler2", "key2") + if cb2.Len() != 2 { + t.Errorf("Expected count after removing middle callback is 2, but got %d", cb2.Len()) + } + + // Remove first callback + cb2.Remove("handler1", "key1") + if cb2.Len() != 1 { + t.Errorf("Expected count after removing first callback is 1, but got %d", cb2.Len()) + } + + // Remove last callback + cb2.Remove("handler3", "key3") + if cb2.Len() != 0 { + t.Errorf("Expected count after removing last callback is 0, but got %d", cb2.Len()) + } + + // Test removing non-existent callback + cb2.Add("handler1", "key1", func() {}) + cb2.Remove("handler2", "key2") // Try to remove non-existent callback + + // Should still have 1 callback + if cb2.Len() != 1 { + t.Errorf("Expected callback count is 1, but got %d", cb2.Len()) + } +} + +func TestCallbackInvokePanicPropagation(t *testing.T) { + cb := &callbacks{} + cb.Add("h", "k1", func() { panic("boom") }) + + // Test that panic is propagated (not swallowed by Invoke) + defer func() { + if r := recover(); r != nil { + if r != "boom" { + t.Errorf("Expected panic 'boom', got %v", r) + } + } else { + t.Errorf("Expected panic to be propagated, but it was swallowed") + } + }() + + // This should panic and be caught by the defer above + cb.Invoke() +} + +func TestCallbackNonComparableTypes(t *testing.T) { + cb := &callbacks{} + + // Test with non-comparable types (slice, map, function) + nonComparableTypes := []struct { + name string + handler any + key any + expected bool // whether the callback should be added + }{ + {"slice_handler", []int{1, 2, 3}, "key", false}, + {"map_handler", map[string]int{"a": 1}, "key", false}, + {"func_handler", func() {}, "key", false}, + {"slice_key", "handler", []int{1, 2, 3}, false}, + {"map_key", "handler", map[string]int{"a": 1}, false}, + {"func_key", "handler", func() {}, false}, + {"both_non_comparable", []int{1}, map[string]int{"a": 1}, false}, + {"comparable_types", "handler", "key", true}, + {"nil_values", nil, nil, true}, + {"mixed_comparable", "handler", 123, true}, + } + + for _, tt := range nonComparableTypes { + t.Run(tt.name, func(t *testing.T) { + initialCount := cb.Len() + + // Try to add callback + cb.Add(tt.handler, tt.key, func() {}) + + // Check if callback was added + finalCount := cb.Len() + if tt.expected { + if finalCount != initialCount+1 { + t.Errorf("Expected callback to be added, but count remained %d", initialCount) + } + // Clean up for next test + cb.Remove(tt.handler, tt.key) + } else { + if finalCount != initialCount { + t.Errorf("Expected callback to be ignored, but count changed from %d to %d", initialCount, finalCount) + } + } + }) + } + + // Test Remove with non-comparable types + t.Run("RemoveNonComparable", func(t *testing.T) { + initialCount := cb.Len() + + // Try to remove with non-comparable types + cb.Remove([]int{1, 2, 3}, map[string]int{"a": 1}) + + // Count should remain unchanged + if cb.Len() != initialCount { + t.Errorf("Expected count to remain %d after removing non-comparable types, but got %d", initialCount, cb.Len()) + } + }) +} diff --git a/transport/session.go b/transport/session.go index e5328dbb..61159d19 100644 --- a/transport/session.go +++ b/transport/session.go @@ -101,6 +101,9 @@ type Session interface { WriteBytes([]byte) (int, error) WriteBytesArray(...[]byte) (int, error) Close() + + AddCloseCallback(handler, key any, callback CallBackFunc) + RemoveCloseCallback(handler, key any) } // getty base session @@ -135,6 +138,10 @@ type session struct { grNum uatomic.Int32 lock sync.RWMutex packetLock sync.RWMutex + + // callbacks + closeCallback callbacks + closeCallbackMutex sync.RWMutex } func newSession(endPoint EndPoint, conn Connection) *session { @@ -861,6 +868,22 @@ func (s *session) stop() { } } close(s.done) + + go func(sessionToken string) { + defer func() { + if r := recover(); r != nil { + const size = 64 << 10 + rBuf := make([]byte, size) + rBuf = rBuf[:runtime.Stack(rBuf, false)] + err := perrors.WithStack(fmt.Errorf("[session.invokeCloseCallbacks] panic session %s: err=%v\n%s", + sessionToken, r, rBuf)) + log.Error(err) + } + }() + + s.invokeCloseCallbacks() + }(s.sessionToken()) + clt, cltFound := s.GetAttribute(sessionClientKey).(*client) ignoreReconnect, flagFound := s.GetAttribute(ignoreReconnectKey).(bool) if cltFound && flagFound && !ignoreReconnect { diff --git a/transport/session_callback.go b/transport/session_callback.go new file mode 100644 index 00000000..5fe91eba --- /dev/null +++ b/transport/session_callback.go @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package getty + +// AddCloseCallback adds a close callback function to the Session +// +// Parameters: +// - handler: handler identifier, used to identify the source or type of the callback +// - key: unique identifier key for the callback, used in combination with handler +// - f: callback function to be executed when the session is closed +// +// Notes: +// - If the session is already closed, this addition will be ignored +// - The combination of handler and key must be unique, otherwise it will override previous callbacks +// - Callback functions will be executed in the order they were added when the session closes +func (s *session) AddCloseCallback(handler, key any, f CallBackFunc) { + if f == nil { + return + } + s.closeCallbackMutex.Lock() + defer s.closeCallbackMutex.Unlock() + if s.IsClosed() { + return + } + s.closeCallback.Add(handler, key, f) +} + +// RemoveCloseCallback removes the specified Session close callback function +// +// Parameters: +// - handler: handler identifier of the callback to be removed +// - key: unique identifier key of the callback to be removed +// +// Return value: none +// +// Notes: +// - If the session is already closed, this removal operation will be ignored +// - If no matching callback is found, this operation will have no effect +// - The removal operation is thread-safe +func (s *session) RemoveCloseCallback(handler, key any) { + s.closeCallbackMutex.Lock() + defer s.closeCallbackMutex.Unlock() + if s.IsClosed() { + return + } + s.closeCallback.Remove(handler, key) +} + +// invokeCloseCallbacks executes all registered close callback functions +// +// Function description: +// - Executes all registered close callbacks in the order they were added +// - Uses read lock to protect the callback list, ensuring concurrency safety +// - This method is typically called automatically when the session closes +// +// Notes: +// - This is an internal method, not recommended for external direct calls +// - If panic occurs during callback execution, it will be caught and logged +// - Callback functions should avoid long blocking operations, async processing is recommended for time-consuming tasks +func (s *session) invokeCloseCallbacks() { + s.closeCallbackMutex.RLock() + defer s.closeCallbackMutex.RUnlock() + s.closeCallback.Invoke() +} + +// CallBackFunc defines the callback function type when Session closes +// +// Usage notes: +// - Callback function accepts no parameters +// - Callback function returns no values +// - Callback function should handle resource cleanup, state updates, etc. +// - It's recommended to avoid accessing closed session state in callback functions +type CallBackFunc func() diff --git a/transport/session_callback_test.go b/transport/session_callback_test.go new file mode 100644 index 00000000..68228fb2 --- /dev/null +++ b/transport/session_callback_test.go @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package getty + +import ( + "fmt" + "sync" + "testing" + "time" +) + +func TestSessionCallback(t *testing.T) { + // Test basic add, remove and execute callback functionality + t.Run("BasicCallback", func(t *testing.T) { + s := &session{ + once: &sync.Once{}, + done: make(chan struct{}), + closeCallback: callbacks{}, + } + + var callbackExecuted bool + var callbackMutex sync.Mutex + + callback := func() { + callbackMutex.Lock() + callbackExecuted = true + callbackMutex.Unlock() + } + + // Add callback + s.AddCloseCallback("testHandler", "testKey", callback) + if s.closeCallback.Len() != 1 { + t.Errorf("Expected callback count is 1, but got %d", s.closeCallback.Len()) + } + + // Test removing callback + s.RemoveCloseCallback("testHandler", "testKey") + if s.closeCallback.Len() != 0 { + t.Errorf("Expected callback count is 0, but got %d", s.closeCallback.Len()) + } + + // Re-add callback + s.AddCloseCallback("testHandler", "testKey", callback) + + // Test callback execution when closing + go func() { + time.Sleep(10 * time.Millisecond) + s.stop() + }() + + // Wait for callback execution + time.Sleep(50 * time.Millisecond) + + callbackMutex.Lock() + if !callbackExecuted { + t.Error("Callback function was not executed") + } + callbackMutex.Unlock() + }) + + // Test adding, removing and executing multiple callbacks + t.Run("MultipleCallbacks", func(t *testing.T) { + s := &session{ + once: &sync.Once{}, + done: make(chan struct{}), + closeCallback: callbacks{}, + } + + var callbackCount int + var callbackMutex sync.Mutex + + // Add multiple callbacks + totalCallbacks := 3 + for i := 0; i < totalCallbacks; i++ { + index := i // Capture loop variable + callback := func() { + callbackMutex.Lock() + callbackCount++ + callbackMutex.Unlock() + } + s.AddCloseCallback(fmt.Sprintf("handler%d", index), fmt.Sprintf("key%d", index), callback) + } + + if s.closeCallback.Len() != totalCallbacks { + t.Errorf("Expected callback count is %d, but got %d", totalCallbacks, s.closeCallback.Len()) + } + + // Remove one callback + s.RemoveCloseCallback("handler0", "key0") + expectedAfterRemove := totalCallbacks - 1 + if s.closeCallback.Len() != expectedAfterRemove { + t.Errorf("Expected callback count is %d, but got %d", expectedAfterRemove, s.closeCallback.Len()) + } + + // Test execution of remaining callbacks when closing + go func() { + time.Sleep(10 * time.Millisecond) + s.stop() + }() + + time.Sleep(50 * time.Millisecond) + + callbackMutex.Lock() + if callbackCount != expectedAfterRemove { + t.Errorf("Expected executed callback count is %d, but got %d", expectedAfterRemove, callbackCount) + } + callbackMutex.Unlock() + }) + + // Test invokeCloseCallbacks functionality + t.Run("InvokeCloseCallbacks", func(t *testing.T) { + s := &session{ + once: &sync.Once{}, + done: make(chan struct{}), + closeCallback: callbacks{}, + } + + var callbackResults []string + var callbackMutex sync.Mutex + + // Add multiple different types of close callbacks + callbacks := []struct { + handler string + key string + action string + }{ + {"cleanup", "resources", "Clean resources"}, + {"cleanup", "connections", "Close connections"}, + {"logging", "audit", "Log audit info"}, + {"metrics", "stats", "Update statistics"}, + } + + // Register all callbacks + for _, cb := range callbacks { + cbCopy := cb // Capture loop variable + callback := func() { + callbackMutex.Lock() + callbackResults = append(callbackResults, cbCopy.action) + callbackMutex.Unlock() + } + s.AddCloseCallback(cbCopy.handler, cbCopy.key, callback) + } + + // Verify callback count + expectedCount := len(callbacks) + if s.closeCallback.Len() != expectedCount { + t.Errorf("Expected callback count is %d, but got %d", expectedCount, s.closeCallback.Len()) + } + + // Manually invoke close callbacks (simulate invokeCloseCallbacks) + callbackMutex.Lock() + callbackResults = nil // Clear previous results + callbackMutex.Unlock() + + // Execute all close callbacks + s.closeCallback.Invoke() + + // Wait for callback execution to complete + time.Sleep(10 * time.Millisecond) + + // Verify all callbacks were executed + callbackMutex.Lock() + if len(callbackResults) != expectedCount { + t.Errorf("Expected to execute %d callbacks, but executed %d", expectedCount, len(callbackResults)) + } + + // Verify callback execution order (should execute in order of addition) + expectedActions := []string{"Clean resources", "Close connections", "Log audit info", "Update statistics"} + for i, result := range callbackResults { + if i < len(expectedActions) && result != expectedActions[i] { + t.Errorf("Position %d: Expected to execute '%s', but executed '%s'", i, expectedActions[i], result) + } + } + callbackMutex.Unlock() + + // Test execution after removing a callback + s.RemoveCloseCallback("cleanup", "resources") + + callbackMutex.Lock() + callbackResults = nil + callbackMutex.Unlock() + + // Execute callbacks again + s.closeCallback.Invoke() + time.Sleep(10 * time.Millisecond) + + // Verify execution results after removal + callbackMutex.Lock() + expectedAfterRemove := expectedCount - 1 + if len(callbackResults) != expectedAfterRemove { + t.Errorf("Expected to execute %d callbacks after removal, but executed %d", expectedAfterRemove, len(callbackResults)) + } + callbackMutex.Unlock() + }) + + // Test edge cases + t.Run("EdgeCases", func(t *testing.T) { + // Test empty callback list scenario + s := &session{ + once: &sync.Once{}, + done: make(chan struct{}), + closeCallback: callbacks{}, + } + + // Verify empty list + if s.closeCallback.Len() != 0 { + t.Errorf("Expected count for empty list is 0, but got %d", s.closeCallback.Len()) + } + + // Execute empty callback list (should not panic) + s.closeCallback.Invoke() + + // Add a callback then remove it, execute again + s.AddCloseCallback("test", "key", func() {}) + s.RemoveCloseCallback("test", "key") + + // Execute empty list after removal (should not panic) + s.closeCallback.Invoke() + + if s.closeCallback.Len() != 0 { + t.Errorf("Expected count after removal is 0, but got %d", s.closeCallback.Len()) + } + }) +}