From 0ef4a893654e786eab984286b89160d7709fdde7 Mon Sep 17 00:00:00 2001 From: jvjhfhg Date: Tue, 17 Mar 2026 20:04:39 +0800 Subject: [PATCH] Fix: refactor TensorMap rings and route cross-ring INOUT - refactor owner TensorMap storage into per-ring buckets, pools, and cleanup tracking - add a fallback tensormap for external tensors and cross-ring INOUT modifiers while keeping owner-ring history ring-local - route lookup and removal across owner and fallback sources and bind make_tensor() to the current scope ring - update paged attention to treat oi_batch as INOUT in the example and matching device test --- .../orchestration/paged_attention_orch.cpp | 2 +- .../orchestration/pto_orchestration_api.h | 70 ++- .../runtime/pto_orchestrator.cpp | 88 ++- .../runtime/pto_orchestrator.h | 3 + .../runtime/pto_runtime2_types.h | 4 + .../runtime/pto_tensormap.cpp | 357 +++++++---- .../runtime/pto_tensormap.h | 558 ++++++++++++------ .../tensormap_and_ringbuffer/runtime/tensor.h | 54 +- .../orchestration/paged_attention_orch.cpp | 2 +- 9 files changed, 807 insertions(+), 331 deletions(-) diff --git a/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp b/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp index 2c472f51..96ba0b4c 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp +++ b/examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp @@ -191,7 +191,7 @@ void aicpu_orchestration_entry(PTO2Runtime* rt, uint64_t* args, int arg_count, i params_up.add_input(oi_new_b); params_up.add_inout(mi_batch); params_up.add_inout(li_batch); - params_up.add_output(oi_batch); + params_up.add_inout(oi_batch); params_up.add_output(out); params_up.add_scalar(is_first); params_up.add_scalar(is_last); diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/orchestration/pto_orchestration_api.h b/src/a2a3/runtime/tensormap_and_ringbuffer/orchestration/pto_orchestration_api.h index ff7d2b18..560ae6b9 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/orchestration/pto_orchestration_api.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/orchestration/pto_orchestration_api.h @@ -23,9 +23,18 @@ // Type headers needed by orchestration #include "pto_types.h" // PTOParam, PTOTensorEntry, PTOParamType -#include "tensor.h" // Tensor, make_tensor, make_tensor_external +#include "tensor.h" // Tensor struct #include "pto_submit_types.h" // MixedKernels, INVALID_KERNEL_ID, subtask slots +// Multi-ring: number of independent ring layers (HeapRing + TaskRing + DepPool per layer) +// Scope depth maps to ring index via: min(scope_depth, PTO2_MAX_RING_DEPTH - 1) +#define PTO2_MAX_RING_DEPTH 4 + +// Thread-local scope depth for tensor factory functions. +// Incremented/decremented by PTO2ScopeGuard and standalone scope wrappers. +// Tensor ring selection clamps this depth to the runtime's valid ring range. +static thread_local uint8_t __pto2_ring_id = 0; + // ============================================================================= // Ops Table and Opaque Runtime // ============================================================================= @@ -99,10 +108,12 @@ static inline void pto2_rt_submit_aiv_task(PTO2Runtime* rt, int32_t kernel_id, static inline void pto2_rt_scope_begin(PTO2Runtime* rt) { rt->ops->scope_begin(rt); + __pto2_ring_id++; } static inline void pto2_rt_scope_end(PTO2Runtime* rt) { rt->ops->scope_end(rt); + __pto2_ring_id--; } static inline void pto2_rt_orchestration_done(PTO2Runtime* rt) { @@ -113,6 +124,59 @@ static inline bool pto2_rt_is_fatal(PTO2Runtime* rt) { return rt->ops->is_fatal(rt); } +// ============================================================================= +// Tensor Factory Functions +// ============================================================================= + +/** + * Create a Tensor for pre-allocated external memory. + */ +static inline Tensor make_tensor_external(void* addr, + const uint32_t shapes[], + uint32_t ndims, + DataType dtype = DataType::FLOAT32, + bool manual_dep = false, + int32_t version = 0) { + static uint32_t zero_offsets[RUNTIME_MAX_TENSOR_DIMS] = {}; + uint64_t total = 1; + for (uint32_t i = 0; i < ndims; i++) { + total *= shapes[i]; + } + return Tensor(addr, total * get_element_size(dtype), shapes, shapes, zero_offsets, ndims, dtype, version, + /*is_all_offset_zero=*/true, /*is_raw_eq_shapes=*/true, manual_dep, + TENSOR_RING_ID_NONE); +} + +static inline Tensor make_tensor_with_ring(const uint32_t shapes[], + uint32_t ndims, + DataType dtype, + bool manual_dep, + int32_t version, + uint8_t ring_id) { + static uint32_t zero_offsets[RUNTIME_MAX_TENSOR_DIMS] = {}; + uint64_t total = 1; + for (uint32_t i = 0; i < ndims; i++) { + total *= shapes[i]; + } + return Tensor(0, total * get_element_size(dtype), shapes, shapes, zero_offsets, ndims, dtype, version, + /*is_all_offset_zero=*/true, /*is_raw_eq_shapes=*/true, manual_dep, ring_id); +} + +static inline uint8_t current_tensor_ring_id() { + return __pto2_ring_id < PTO2_MAX_RING_DEPTH ? __pto2_ring_id : PTO2_MAX_RING_DEPTH - 1; +} + +/** + * Create a Tensor for runtime-allocated output (addr=0). + * Uses the thread-local scope depth set by PTO2ScopeGuard, clamped to the + * runtime ring range to match PTO2OrchestratorState::current_ring_id(). + */ +static inline Tensor make_tensor(const uint32_t shapes[], uint32_t ndims, + DataType dtype = DataType::FLOAT32, bool manual_dep = false, + int32_t version = 0) { + return make_tensor_with_ring(shapes, ndims, dtype, manual_dep, version, current_tensor_ring_id()); +} + // ============================================================================= // Logging Macros for Orchestration (call through ops table) // ============================================================================= @@ -133,10 +197,10 @@ static inline bool pto2_rt_is_fatal(PTO2Runtime* rt) { class PTO2ScopeGuard { public: PTO2ScopeGuard(PTO2Runtime* rt) : rt_(rt) { - rt_->ops->scope_begin(rt_); + pto2_rt_scope_begin(rt_); } ~PTO2ScopeGuard() { - rt_->ops->scope_end(rt_); + pto2_rt_scope_end(rt_); } private: PTO2Runtime* rt_; diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.cpp b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.cpp index d62e0f9a..d1834369 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.cpp +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.cpp @@ -161,9 +161,16 @@ bool pto2_orchestrator_init( int32_t init_cap = PTO2_SCOPE_TASKS_INIT_CAP; orch->scope_tasks = (PTO2TaskSlotState**)malloc(init_cap * sizeof(PTO2TaskSlotState*)); orch->scope_begins = (int32_t*)malloc(max_depth * sizeof(int32_t)); - if (!orch->scope_tasks || !orch->scope_begins) { + orch->scope_escape_tasks = (PTO2TaskSlotState***)calloc(max_depth, sizeof(PTO2TaskSlotState**)); + orch->scope_escape_counts = (int32_t*)calloc(max_depth, sizeof(int32_t)); + orch->scope_escape_capacities = (int32_t*)calloc(max_depth, sizeof(int32_t)); + if (!orch->scope_tasks || !orch->scope_begins || !orch->scope_escape_tasks || + !orch->scope_escape_counts || !orch->scope_escape_capacities) { free(orch->scope_tasks); free(orch->scope_begins); + free(orch->scope_escape_tasks); + free(orch->scope_escape_counts); + free(orch->scope_escape_capacities); for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) { free(orch->rings[r].dep_pool.base); } @@ -190,6 +197,17 @@ void pto2_orchestrator_destroy(PTO2OrchestratorState* orch) { orch->scope_tasks = NULL; free(orch->scope_begins); orch->scope_begins = NULL; + if (orch->scope_escape_tasks) { + for (uint64_t i = 0; i < orch->scope_stack_capacity; ++i) { + free(orch->scope_escape_tasks[i]); + } + free(orch->scope_escape_tasks); + orch->scope_escape_tasks = NULL; + } + free(orch->scope_escape_counts); + orch->scope_escape_counts = NULL; + free(orch->scope_escape_capacities); + orch->scope_escape_capacities = NULL; } void pto2_orchestrator_set_scheduler(PTO2OrchestratorState* orch, PTO2SchedulerState* scheduler) { @@ -211,12 +229,29 @@ static void scope_tasks_push(PTO2OrchestratorState* orch, PTO2TaskSlotState *tas orch->scope_tasks[orch->scope_tasks_size++] = task_slot_state; } +static void scope_escape_tasks_push(PTO2OrchestratorState* orch, int32_t scope_idx, + PTO2TaskSlotState* task_slot_state) { + always_assert(scope_idx >= 0 && static_cast(scope_idx) < orch->scope_stack_capacity); + int32_t count = orch->scope_escape_counts[scope_idx]; + int32_t capacity = orch->scope_escape_capacities[scope_idx]; + if (count >= capacity) { + int32_t new_cap = capacity > 0 ? capacity * 2 : PTO2_SCOPE_TASKS_INIT_CAP; + PTO2TaskSlotState** new_buf = (PTO2TaskSlotState**)realloc( + orch->scope_escape_tasks[scope_idx], new_cap * sizeof(PTO2TaskSlotState*)); + assert(new_buf && "Failed to grow scope escape task buffer"); + orch->scope_escape_tasks[scope_idx] = new_buf; + orch->scope_escape_capacities[scope_idx] = new_cap; + } + orch->scope_escape_tasks[scope_idx][orch->scope_escape_counts[scope_idx]++] = task_slot_state; +} + void pto2_scope_begin(PTO2OrchestratorState* orch) { if (orch->fatal) { return; } assert(orch->scope_stack_top < (int32_t)(orch->scope_stack_capacity - 1) && "Scope stack overflow"); ++orch->scope_stack_top; orch->scope_begins[orch->scope_stack_top] = orch->scope_tasks_size; + orch->scope_escape_counts[orch->scope_stack_top] = 0; } void pto2_scope_end(PTO2OrchestratorState* orch) { @@ -227,12 +262,18 @@ void pto2_scope_end(PTO2OrchestratorState* orch) { uint64_t _se0 = get_sys_cnt_aicpu(); #endif + int32_t scope_idx = orch->scope_stack_top; int32_t begin = orch->scope_begins[orch->scope_stack_top--]; int32_t count = orch->scope_tasks_size - begin; if (orch->scheduler && count > 0) { orch->scheduler->on_scope_end(&orch->scope_tasks[begin], count); } + if (orch->scheduler && orch->scope_escape_counts[scope_idx] > 0) { + orch->scheduler->on_scope_end( + orch->scope_escape_tasks[scope_idx], orch->scope_escape_counts[scope_idx]); + } + orch->scope_escape_counts[scope_idx] = 0; // Rewind the task buffer — these entries are no longer needed orch->scope_tasks_size = begin; @@ -271,7 +312,7 @@ void pto2_submit_mixed_task( return; } - + // Determine which ring this task belongs to uint8_t ring_id = orch->current_ring_id(); auto& task_ring = orch->rings[ring_id].task_ring; @@ -343,7 +384,23 @@ void pto2_submit_mixed_task( PTO2TaskId mixed_task_id = pto2_make_task_id(ring_id, static_cast(local_id)); PTO2TaskDescriptor& task = task_ring.get_task_by_slot(slot); - PTO2TaskPayload* payload = &orch->sm_handle->task_payloads[ring_id][slot]; + PTO2TaskPayload* payload = &orch->sm_handle->task_payloads[ring_id][slot]; + bool extra_scope_refs[PTO2_MAX_RING_DEPTH] = {}; + int32_t extra_scope_ref_count = 0; + for (int i = 0; i < params.tensor_count; i++) { + if (params.tensor_types[i] != PTOParamType::INOUT || params.tensors[i]->manual_dep) { + continue; + } + uint8_t owner_ring = params.tensors[i]->ring_id; + if (owner_ring == TENSOR_RING_ID_NONE || owner_ring >= PTO2_MAX_RING_DEPTH || owner_ring >= ring_id) { + continue; + } + if (!extra_scope_refs[owner_ring]) { + always_assert(static_cast(owner_ring) <= orch->scope_stack_top); + extra_scope_refs[owner_ring] = true; + extra_scope_ref_count++; + } + } // Early write-prefetch payload GM cache lines to issue RFO in background. // ~130 lines of computation (output_size, lookup, insert) follow before @@ -367,8 +424,9 @@ void pto2_submit_mixed_task( slot_state.fanin_count = 0; slot_state.fanout_head = nullptr; slot_state.fanout_lock.store(0, std::memory_order_relaxed); - // Initial fanout_count = 1 (the owning scope holds one reference) - slot_state.fanout_count = 1; + // Direct owner scope always holds one reference. Cross-scope INOUT + // additionally pins the task to each escaped owner scope. + slot_state.fanout_count = 1 + extra_scope_ref_count; slot_state.fanout_refcount.store(0, std::memory_order_release); slot_state.fanin_refcount.store(0, std::memory_order_release); slot_state.payload = payload; @@ -377,6 +435,11 @@ void pto2_submit_mixed_task( slot_state.subtask_done_mask.store(0, std::memory_order_relaxed); slot_state.ring_id = ring_id; scope_tasks_push(orch, &slot_state); + for (int32_t owner_scope = 0; owner_scope < PTO2_MAX_RING_DEPTH; ++owner_scope) { + if (extra_scope_refs[owner_scope]) { + scope_escape_tasks_push(orch, owner_scope, &slot_state); + } + } } else { scope_tasks_push(orch, nullptr); } @@ -416,7 +479,7 @@ void pto2_submit_mixed_task( // Read current last_task_alive from shared memory for this ring int32_t sm_last_task_alive = fc.last_task_alive.load(std::memory_order_acquire); - orch->tensor_map.sync_tensormap(ring_id, sm_last_task_alive); + orch->tensor_map.sync_tensormap(); if (sched) { orch->rings[ring_id].dep_pool.reclaim(*sched, ring_id, sm_last_task_alive); @@ -474,11 +537,20 @@ void pto2_submit_mixed_task( case PTOParamType::OUTPUT: { Tensor& tensor = *params.tensors[i]; - if (tensor.buffer.addr == 0) { + bool needs_alloc = tensor.buffer.addr == 0; + if (needs_alloc) { uint64_t alloc_addr = reinterpret_cast((char*)local_packed_base + offset); tensor.buffer.addr = alloc_addr; offset += PTO2_ALIGN_UP(tensor.buffer.size, PTO2_PACKED_OUTPUT_ALIGN); } + if (tensor.ring_id == TENSOR_RING_ID_NONE) { + always_assert(!needs_alloc && + "Internal OUTPUT tensor must have ring_id assigned before submit"); + } else { + always_assert(tensor.ring_id < PTO2_MAX_RING_DEPTH); + always_assert(static_cast(tensor.ring_id) == ring_id && + "OUTPUT tensor ring_id must match submit ring"); + } break; } } @@ -491,7 +563,7 @@ void pto2_submit_mixed_task( PTOParamType ptype = params.tensor_types[i]; if (ptype == PTOParamType::OUTPUT || ptype == PTOParamType::INOUT) { if (!params.tensors[i]->manual_dep) { - orch->tensor_map.insert(*params.tensors[i], mixed_task_id, ptype == PTOParamType::OUTPUT); + orch->tensor_map.insert(*params.tensors[i], mixed_task_id, ptype == PTOParamType::OUTPUT, ring_id); } } } diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.h index a2d4898d..bbc972a0 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.h @@ -53,6 +53,9 @@ struct PTO2OrchestratorState { int32_t scope_tasks_size; // Number of task IDs currently in the buffer int32_t scope_tasks_capacity; // Allocated capacity of scope_tasks int32_t* scope_begins; // scope_begins[i] = start index of scope i in scope_tasks + PTO2TaskSlotState*** scope_escape_tasks; // Cross-scope INOUT refs held by owner scope + int32_t* scope_escape_counts; // Number of escape refs per scope + int32_t* scope_escape_capacities; // Allocated capacity per scope int32_t scope_stack_top; // Current top of stack (-1 = no scope open) uint64_t scope_stack_capacity; // Max nesting depth (PTO2_MAX_SCOPE_DEPTH) diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h index 141be544..e06a74a9 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h @@ -89,6 +89,10 @@ #define PTO2_TENSORMAP_POOL_SIZE (65536) // TensorMap entry pool #define PTO2_TENSORMAP_NUM_BUCKETS 65536 // Power of 2 for fast hash +// Fallback TensorMap (for cross-ring INOUT and external tensor entries) +#define PTO2_FALLBACK_POOL_SIZE 4096 // Fallback entry pool (rare path) +#define PTO2_FALLBACK_NUM_BUCKETS 4096 // Power of 2 for fast hash + // Scope management #define PTO2_MAX_SCOPE_DEPTH 64 // Maximum nesting depth #define PTO2_SCOPE_TASKS_INIT_CAP 65536 // Initial capacity for scope task buffer diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.cpp b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.cpp index 50cd57b1..387ce644 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.cpp +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.cpp @@ -1,14 +1,8 @@ /** * PTO Runtime2 - TensorMap Implementation * - * Implements TensorMap with ring buffer pool, lazy invalidation, - * and chain truncation optimization. - * - * Key features: - * 1. O(1) insert at bucket head - * 2. O(valid_entries) lookup with chain truncation - * 3. Automatic stale entry cleanup during lookup - * 4. Periodic explicit cleanup for long chains + * Implements per-ring TensorMap with independent buckets, pools, and + * break-on-stale optimization. * * Based on: docs/runtime_buffer_manager_methods.md */ @@ -35,44 +29,32 @@ uint64_t g_insert_count = 0; #endif // ============================================================================= -// Initialization and Destruction +// PTO2TensorMapRing — Initialization and Destruction // ============================================================================= -bool PTO2TensorMap::init(int32_t new_num_buckets, int32_t new_pool_size, const int32_t new_task_window_sizes[PTO2_MAX_RING_DEPTH]) { - // Validate power of 2 for fast modulo +bool PTO2TensorMapRing::init(int32_t new_num_buckets, int32_t new_pool_size, int32_t new_task_window_size) { if ((new_num_buckets & (new_num_buckets - 1)) != 0) { - return false; // num_buckets must be power of 2 - } - - // Allocate buckets - buckets = (PTO2TensorMapEntry**)malloc(new_num_buckets * sizeof(PTO2TensorMapEntry*)); - if (!buckets) { return false; } - // Initialize all buckets to empty (-1) + buckets = (PTO2TensorMapEntry**)malloc(new_num_buckets * sizeof(PTO2TensorMapEntry*)); + if (!buckets) return false; for (int32_t i = 0; i < new_num_buckets; i++) { buckets[i] = nullptr; } - num_buckets = new_num_buckets; - // Allocate entry pool (64-byte aligned for cache-line-aligned entries) entry_pool = (PTO2TensorMapEntry*)aligned_alloc(alignof(PTO2TensorMapEntry), new_pool_size * sizeof(PTO2TensorMapEntry)); if (!entry_pool) { - free(buckets); - buckets = NULL; + free(buckets); buckets = nullptr; return false; } memset(entry_pool, 0, new_pool_size * sizeof(PTO2TensorMapEntry)); - // Allocate free entry list free_entry_list = (PTO2TensorMapEntry**)calloc(new_pool_size, sizeof(PTO2TensorMapEntry*)); if (!free_entry_list) { - free(buckets); - free(entry_pool); - buckets = NULL; - entry_pool = NULL; + free(buckets); buckets = nullptr; + free(entry_pool); entry_pool = nullptr; return false; } @@ -80,7 +62,6 @@ bool PTO2TensorMap::init(int32_t new_num_buckets, int32_t new_pool_size, const i next_entry_idx = 0; free_num = 0; - // Initialize all entries as not in bucket for (int32_t i = 0; i < pool_size; i++) { entry_pool[i].bucket_index = -1; entry_pool[i].next_in_bucket = nullptr; @@ -90,70 +71,174 @@ bool PTO2TensorMap::init(int32_t new_num_buckets, int32_t new_pool_size, const i entry_pool[i].producer_task_id = PTO2TaskId{}; } - // Allocate per-ring per-task entry tracking (each ring has its own window size) - for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) { - task_entry_heads[r] = (PTO2TensorMapEntry**)malloc(new_task_window_sizes[r] * sizeof(PTO2TensorMapEntry*)); - if (!task_entry_heads[r]) { - // Cleanup previously allocated rings - for (int j = 0; j < r; j++) { - free(task_entry_heads[j]); - task_entry_heads[j] = NULL; + task_entry_heads = (PTO2TensorMapEntry**)malloc(new_task_window_size * sizeof(PTO2TensorMapEntry*)); + if (!task_entry_heads) { + free(buckets); buckets = nullptr; + free(entry_pool); entry_pool = nullptr; + free(free_entry_list); free_entry_list = nullptr; + return false; + } + for (int32_t i = 0; i < new_task_window_size; i++) { + task_entry_heads[i] = nullptr; + } + task_window_size = new_task_window_size; + + last_task_alive = 0; + last_cleanup = 0; + + return true; +} + +void PTO2TensorMapRing::destroy() { + if (buckets) { free(buckets); buckets = nullptr; } + if (entry_pool) { free(entry_pool); entry_pool = nullptr; } + if (free_entry_list) { free(free_entry_list); free_entry_list = nullptr; } + if (task_entry_heads) { free(task_entry_heads); task_entry_heads = nullptr; } +} + +// ============================================================================= +// PTO2TensorMapRing — Debug Utilities +// ============================================================================= + +void PTO2TensorMapRing::print_stats() { + int32_t valid = 0; + int32_t stale = 0; + int32_t empty_buckets = 0; + int32_t max_chain = 0; + int64_t total_chain = 0; + int32_t non_empty_buckets = 0; + + for (int32_t i = 0; i < pool_size; i++) { + if (entry_pool[i].bucket_index != -1) { + if (entry_valid(entry_pool[i])) { + valid++; + } else { + stale++; } - free(entry_pool); - free(buckets); - free(free_entry_list); - entry_pool = NULL; - buckets = NULL; - free_entry_list = NULL; - return false; } - for (int32_t i = 0; i < new_task_window_sizes[r]; i++) { - task_entry_heads[r][i] = nullptr; - } - task_window_sizes[r] = new_task_window_sizes[r]; } - for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) { - last_task_alives[r] = 0; - last_cleanup[r] = 0; + for (int32_t b = 0; b < num_buckets; b++) { + int32_t chain_len = 0; + auto cur_entry = buckets[b]; + while (cur_entry != nullptr) { + chain_len++; + cur_entry = cur_entry->next_in_bucket; + } + if (chain_len == 0) { + empty_buckets++; + } else { + non_empty_buckets++; + total_chain += chain_len; + if (chain_len > max_chain) max_chain = chain_len; + } } - return true; + LOG_INFO(" Pool size: %d", pool_size); + LOG_INFO(" Pool next entry idx: %d", next_entry_idx); + LOG_INFO(" Pool free_num: %d", free_num); + LOG_INFO(" Num buckets: %d", num_buckets); + LOG_INFO(" Valid entries: %d", valid); + LOG_INFO(" Stale entries: %d", stale); + LOG_INFO(" Empty buckets: %d", empty_buckets); + LOG_INFO(" Max chain len: %d", max_chain); + LOG_INFO(" Avg chain len: %.2f", non_empty_buckets > 0 ? (float)total_chain / non_empty_buckets : 0); + LOG_INFO(" Last task alive: %d", last_task_alive); + LOG_INFO(" Last cleanup: %d", last_cleanup); } -bool PTO2TensorMap::init_default(const int32_t new_task_window_sizes[PTO2_MAX_RING_DEPTH]) { - return init(PTO2_TENSORMAP_NUM_BUCKETS, PTO2_TENSORMAP_POOL_SIZE, new_task_window_sizes); +int32_t PTO2TensorMapRing::valid_count() { + int32_t count = 0; + for (int32_t i = 0; i < pool_size; i++) { + if (entry_pool[i].bucket_index != -1 && entry_valid(entry_pool[i])) { + count++; + } + } + return count; } -void PTO2TensorMap::destroy() { - if (buckets) { - free(buckets); - buckets = NULL; +// ============================================================================= +// PTO2FallbackTensorMap — Initialization and Destruction +// ============================================================================= + +bool PTO2FallbackTensorMap::init(int32_t new_num_buckets, int32_t new_pool_size) { + if ((new_num_buckets & (new_num_buckets - 1)) != 0) { + return false; } - if (entry_pool) { - free(entry_pool); - entry_pool = NULL; + buckets = (PTO2TensorMapEntry**)malloc(new_num_buckets * sizeof(PTO2TensorMapEntry*)); + if (!buckets) return false; + for (int32_t i = 0; i < new_num_buckets; ++i) { + buckets[i] = nullptr; } + num_buckets = new_num_buckets; - for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) { - if (task_entry_heads[r]) { - free(task_entry_heads[r]); - task_entry_heads[r] = NULL; - } + entry_pool = (PTO2TensorMapEntry*)aligned_alloc(alignof(PTO2TensorMapEntry), + new_pool_size * sizeof(PTO2TensorMapEntry)); + if (!entry_pool) { + free(buckets); buckets = nullptr; + return false; + } + memset(entry_pool, 0, new_pool_size * sizeof(PTO2TensorMapEntry)); + + free_entry_list = (PTO2TensorMapEntry**)calloc(new_pool_size, sizeof(PTO2TensorMapEntry*)); + if (!free_entry_list) { + free(buckets); buckets = nullptr; + free(entry_pool); entry_pool = nullptr; + return false; } - if (free_entry_list) { - free(free_entry_list); - free_entry_list = NULL; + pool_size = new_pool_size; + next_entry_idx = 0; + free_num = 0; + lifecycle_window_size = new_pool_size; + next_lifecycle_key = 0; + last_task_alive = 0; + last_insert_task_raw = UINT64_MAX; + last_insert_lifecycle_key = -1; + for (int32_t i = 0; i < pool_size; ++i) { + entry_pool[i].bucket_index = -1; + entry_pool[i].next_in_bucket = nullptr; + entry_pool[i].prev_in_bucket = nullptr; + entry_pool[i].next_in_task = nullptr; + entry_pool[i].prev_in_task = nullptr; + entry_pool[i].producer_task_id = PTO2TaskId{}; } + + lifecycle_entry_heads = + (PTO2TensorMapEntry**)calloc(lifecycle_window_size, sizeof(PTO2TensorMapEntry*)); + lifecycle_producer_ids = (PTO2TaskId*)calloc(lifecycle_window_size, sizeof(PTO2TaskId)); + lifecycle_entry_keys = (int32_t*)malloc(lifecycle_window_size * sizeof(int32_t)); + if (!lifecycle_entry_heads || !lifecycle_producer_ids || !lifecycle_entry_keys) { + free(lifecycle_entry_heads); lifecycle_entry_heads = nullptr; + free(lifecycle_producer_ids); lifecycle_producer_ids = nullptr; + free(lifecycle_entry_keys); lifecycle_entry_keys = nullptr; + free(free_entry_list); free_entry_list = nullptr; + free(entry_pool); entry_pool = nullptr; + free(buckets); buckets = nullptr; + return false; + } + for (int32_t i = 0; i < lifecycle_window_size; ++i) { + lifecycle_entry_keys[i] = -1; + } + + return true; +} + +void PTO2FallbackTensorMap::destroy() { + if (buckets) { free(buckets); buckets = nullptr; } + if (entry_pool) { free(entry_pool); entry_pool = nullptr; } + if (free_entry_list) { free(free_entry_list); free_entry_list = nullptr; } + if (lifecycle_entry_heads) { free(lifecycle_entry_heads); lifecycle_entry_heads = nullptr; } + if (lifecycle_producer_ids) { free(lifecycle_producer_ids); lifecycle_producer_ids = nullptr; } + if (lifecycle_entry_keys) { free(lifecycle_entry_keys); lifecycle_entry_keys = nullptr; } } // ============================================================================= -// Debug Utilities +// PTO2FallbackTensorMap — Debug Utilities // ============================================================================= -void PTO2TensorMap::print_stats() { +void PTO2FallbackTensorMap::print_stats(const PTO2TensorMapRing owner_rings[PTO2_MAX_RING_DEPTH]) { int32_t valid = 0; int32_t stale = 0; int32_t empty_buckets = 0; @@ -161,10 +246,9 @@ void PTO2TensorMap::print_stats() { int64_t total_chain = 0; int32_t non_empty_buckets = 0; - // Count entries - for (int32_t i = 0; i < pool_size; i++) { + for (int32_t i = 0; i < pool_size; ++i) { if (entry_pool[i].bucket_index != -1) { - if (entry_valid(entry_pool[i])) { + if (producer_alive(entry_pool[i], owner_rings)) { valid++; } else { stale++; @@ -172,62 +256,129 @@ void PTO2TensorMap::print_stats() { } } - // Count bucket stats - for (int32_t b = 0; b < num_buckets; b++) { + for (int32_t b = 0; b < num_buckets; ++b) { int32_t chain_len = 0; auto cur_entry = buckets[b]; - while (cur_entry != nullptr) { chain_len++; cur_entry = cur_entry->next_in_bucket; } - if (chain_len == 0) { empty_buckets++; } else { non_empty_buckets++; total_chain += chain_len; - if (chain_len > max_chain) { - max_chain = chain_len; + if (chain_len > max_chain) max_chain = chain_len; + } + } + + LOG_INFO(" Pool size: %d", pool_size); + LOG_INFO(" Pool next entry idx: %d", next_entry_idx); + LOG_INFO(" Pool free_num: %d", free_num); + LOG_INFO(" Num buckets: %d", num_buckets); + LOG_INFO(" Valid entries: %d", valid); + LOG_INFO(" Stale entries: %d", stale); + LOG_INFO(" Empty buckets: %d", empty_buckets); + LOG_INFO(" Max chain len: %d", max_chain); + LOG_INFO(" Avg chain len: %.2f", non_empty_buckets > 0 ? (float)total_chain / non_empty_buckets : 0); + LOG_INFO(" Lifecycle next key: %d", next_lifecycle_key); + LOG_INFO(" Fallback last alive: %d", last_task_alive); +} + +int32_t PTO2FallbackTensorMap::valid_count(const PTO2TensorMapRing owner_rings[PTO2_MAX_RING_DEPTH]) { + int32_t count = 0; + for (int32_t i = 0; i < pool_size; ++i) { + if (entry_pool[i].bucket_index != -1 && producer_alive(entry_pool[i], owner_rings)) { + count++; + } + } + return count; +} + +// ============================================================================= +// PTO2TensorMap — Initialization and Destruction +// ============================================================================= + +bool PTO2TensorMap::init(int32_t new_num_buckets, int32_t new_pool_size, const int32_t new_task_window_sizes[PTO2_MAX_RING_DEPTH]) { + for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) { + if (!rings[r].init(new_num_buckets, new_pool_size, new_task_window_sizes[r])) { + for (int j = 0; j < r; j++) { + rings[j].destroy(); } + return false; + } + } + if (!fallback_map.init(PTO2_FALLBACK_NUM_BUCKETS, PTO2_FALLBACK_POOL_SIZE)) { + for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) { + rings[r].destroy(); } + return false; + } + return true; +} + +bool PTO2TensorMap::init_default(const int32_t new_task_window_sizes[PTO2_MAX_RING_DEPTH]) { + return init(PTO2_TENSORMAP_NUM_BUCKETS, PTO2_TENSORMAP_POOL_SIZE, new_task_window_sizes); +} + +void PTO2TensorMap::destroy() { + for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) { + rings[r].destroy(); } + fallback_map.destroy(); +} + +// ============================================================================= +// PTO2TensorMap — Debug Utilities +// ============================================================================= +void PTO2TensorMap::print_stats() { LOG_INFO("=== TensorMap Statistics ==="); - LOG_INFO("Pool size: %d", pool_size); - LOG_INFO("Pool next entry idx: %d", next_entry_idx); - LOG_INFO("Pool free_num: %d", free_num); - LOG_INFO("Num buckets: %d", num_buckets); - LOG_INFO("Valid entries: %d", valid); - LOG_INFO("Stale entries: %d", stale); - LOG_INFO("Empty buckets: %d", empty_buckets); - LOG_INFO("Max chain len: %d", max_chain); - LOG_INFO("Avg chain len: %.2f", non_empty_buckets > 0 ? (float)total_chain / non_empty_buckets : 0); for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) { - LOG_INFO("Last task alive[%d]: %d", r, last_task_alives[r]); + LOG_INFO("--- Ring %d ---", r); + rings[r].print_stats(); } + LOG_INFO("--- Fallback ---"); + fallback_map.print_stats(rings); LOG_INFO("============================"); } int32_t PTO2TensorMap::valid_count() { - int32_t count = 0; + int32_t total = 0; + for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) { + total += rings[r].valid_count(); + } + total += fallback_map.valid_count(rings); + return total; +} - for (int32_t i = 0; i < pool_size; i++) { - if (entry_pool[i].bucket_index != -1 && entry_valid(entry_pool[i])) { - count++; +void PTO2TensorMap::sync_tensormap() { + constexpr int MIN_FREE_NUM = 1024; + always_assert(orch != nullptr); + while (true) { + bool did_cleanup = false; + + // Sync owner rings (must happen before fallback so last_task_alive is fresh) + for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) { + int32_t new_last_task_alive = + orch->sm_handle->header->rings[r].fc.last_task_alive.load(std::memory_order_acquire); + rings[r].sync_validity(new_last_task_alive); + if (new_last_task_alive <= rings[r].last_cleanup) continue; + if ((rings[r].pool_size - rings[r].next_entry_idx + rings[r].free_num < MIN_FREE_NUM) || + new_last_task_alive - rings[r].last_cleanup >= PTO2_TENSORMAP_CLEANUP_INTERVAL) { + rings[r].cleanup_retired(rings[r].last_cleanup, new_last_task_alive); + rings[r].last_cleanup = new_last_task_alive; + did_cleanup = true; + } } - } - return count; -} + int32_t prev_fallback_last_alive = fallback_map.last_task_alive; + fallback_map.advance_frontier(rings); + if (fallback_map.last_task_alive != prev_fallback_last_alive) { + did_cleanup = true; + } -void PTO2TensorMap::sync_tensormap(uint8_t ring_id, int32_t sm_last_task_alive) { - sync_validity(ring_id, sm_last_task_alive); - // Only attempt cleanup when last_task_alive has actually advanced; - // otherwise cleanup_retired would empty-loop and we'd spin forever. - if (sm_last_task_alive - last_cleanup[ring_id] >= PTO2_TENSORMAP_CLEANUP_INTERVAL) { - cleanup_retired(ring_id, last_cleanup[ring_id], sm_last_task_alive); - last_cleanup[ring_id] = sm_last_task_alive; + if (!did_cleanup) break; } } diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.h index 9d1bb56a..41c58d63 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.h @@ -6,13 +6,15 @@ * - Used by pto_submit_task() to find dependencies * * Key design features: - * 1. Ring buffer pool for entries (no malloc/free) + * 1. Per-ring buckets, pools, and free lists for full isolation * 2. Lazy invalidation (entries become stale when producer retires) - * 3. Per-task per-ring entry tracking for efficient cleanup + * 3. Per-task entry tracking for efficient cleanup * 4. OVERLAP DETECTION: Detects dependencies for overlapping sub-regions + * 5. Break-on-stale: per-ring chains are ordered newest-first, so + * encountering a stale entry means all subsequent entries are also stale * * Hash table with chaining: - * - buckets[] array of head offsets + * - buckets[] array of head pointers * - Entries linked via next_in_bucket * - Insert at head (newest first) for sorted chains * @@ -54,7 +56,7 @@ extern uint64_t g_insert_count; #endif // ============================================================================= -// TensorMap Structure +// TensorMap Entry Structure // ============================================================================= /** @@ -82,9 +84,10 @@ struct alignas(64) PTO2TensorMapEntry { int32_t bucket_index; // 4B: bucket index (-1 if unlinked) bool is_all_offset_zero; // 1B: fast-path flag bool with_alloc; // 1B: true=OUTPUT, false=INOUT - // padding: 2B + uint8_t owner_ring; // 1B: which ring/map owns this entry (TENSOR_RING_ID_NONE = fallback) + // padding: 1B uint32_t shapes[RUNTIME_MAX_TENSOR_DIMS]; // 20B: shape per dimension - // padding: 4B to fill 64B + int32_t lifecycle_key; // 4B: producer local_id in the cleanup lifecycle domain // === Cache line 2 (64B) — insert/remove/slow-path === PTO2TensorMapEntry* prev_in_bucket; // 8B: prev in hash bucket chain @@ -154,7 +157,7 @@ static_assert(sizeof(PTO2TensorMapEntry) == 128, "TensorMapEntry must be exactly /** * Stack-allocated lookup result (avoids heap allocation per lookup) */ -#define PTO2_LOOKUP_MAX_RESULTS 16 +#define PTO2_LOOKUP_MAX_RESULTS 64 // ============================================================================= // TensorMap Lookup Chain Length Statistics (compile-time toggle) // ============================================================================= @@ -167,43 +170,56 @@ struct PTO2LookupResult { int32_t count{0}; void push(PTO2TensorMapEntry* entry, OverlapStatus s) { - if (count < PTO2_LOOKUP_MAX_RESULTS) { - entries[count++] = {entry, s}; - } + always_assert(count < PTO2_LOOKUP_MAX_RESULTS && + "TensorMap lookup overflow: too many overlapping producers"); + entries[count++] = {entry, s}; } }; +// ============================================================================= +// Per-Ring TensorMap — independent buckets, pool, and task tracking per ring +// ============================================================================= + /** - * TensorMap structure + * Per-ring TensorMap structure * - * Hash table with ring buffer entry pool and lazy invalidation. + * Each ring owns its own hash table, entry pool, free list, and per-task + * entry heads. Bucket chains contain only entries from this ring, enabling + * break-on-stale optimization (chains are ordered newest-first by head-insert). */ -struct PTO2TensorMap { +struct PTO2TensorMapRing { // Hash table buckets (fixed size, power of 2) - PTO2TensorMapEntry** buckets; // Array of offsets into entry_pool (-1 = empty) - int32_t num_buckets; // Must be power of 2 for fast modulo + PTO2TensorMapEntry** buckets{nullptr}; // Array of bucket head pointers (nullptr = empty) + int32_t num_buckets{0}; // Must be power of 2 for fast modulo // Entry pool as ring buffer - PTO2TensorMapEntry* entry_pool; // Ring buffer of entries - PTO2TensorMapEntry** free_entry_list; // free entry ids - int32_t pool_size; // Total pool capacity - int32_t next_entry_idx; // id when next entry insert - int32_t free_num; // free entry number in entry pool + PTO2TensorMapEntry* entry_pool{nullptr}; // Ring buffer of entries + PTO2TensorMapEntry** free_entry_list{nullptr}; // Free entry pointers + int32_t pool_size{0}; // Total pool capacity + int32_t next_entry_idx{0}; // Next fresh entry index + int32_t free_num{0}; // Free entry count - // Per-ring per-task entry tracking (for efficient bucket cleanup) - // Indexed by [ring_id][local_id & (task_window_sizes[ring_id] - 1)] - PTO2TensorMapEntry** task_entry_heads[PTO2_MAX_RING_DEPTH]; - int32_t task_window_sizes[PTO2_MAX_RING_DEPTH]; // Per-ring task window size (for slot masking) + // Per-task entry tracking (for efficient cleanup) + PTO2TensorMapEntry** task_entry_heads{nullptr}; // Indexed by local_id & (task_window_size - 1) + int32_t task_window_size{0}; // Task window size (for slot masking) - // Per-ring validity threshold (for lazy invalidation) - int32_t last_task_alives[PTO2_MAX_RING_DEPTH]; // Cached from shared memory per ring + // Validity threshold (for lazy invalidation) + int32_t last_task_alive{0}; // Cached from shared memory - // Per-ring cleanup progress (for periodic cleanup_retired) - int32_t last_cleanup[PTO2_MAX_RING_DEPTH]{}; + // Cleanup cursor (tracks how far cleanup has advanced) + int32_t last_cleanup{0}; - PTO2OrchestratorState* orch{nullptr}; + // ============================================================================= + // Initialization / Destruction + // ============================================================================= + + bool init(int32_t num_buckets, int32_t pool_size, int32_t task_window_size); + void destroy(); + + // ============================================================================= + // Entry Allocation + // ============================================================================= - // new_entry目前不负责分配属性,仅分配内存 PTO2TensorMapEntry* new_entry() { if (free_num > 0) { PTO2TensorMapEntry* res = free_entry_list[--free_num]; @@ -217,12 +233,10 @@ struct PTO2TensorMap { } void free_entry(PTO2TensorMapEntry& entry) { - always_assert(entry.bucket_index != -1); // 必须保证仍在桶中 + always_assert(entry.bucket_index != -1); // Update predecessor's next pointer (O(1) via prev_in_bucket) if (entry.prev_in_bucket == nullptr) { - // Entry is the head of its bucket chain, update bucket head - // Must compute hash BEFORE clearing tensor buckets[entry.bucket_index] = entry.next_in_bucket; } else { entry.prev_in_bucket->next_in_bucket = entry.next_in_bucket; @@ -242,81 +256,32 @@ struct PTO2TensorMap { } // ============================================================================= - // TensorMap API + // Lookup (with break-on-stale optimization) // ============================================================================= - /** - * Initialize TensorMap - * - * @param num_buckets Number of hash buckets (must be power of 2) - * @param pool_size Size of entry pool - * @return true on success, false on allocation failure - */ - bool init(int32_t num_buckets, int32_t pool_size, const int32_t task_window_sizes[PTO2_MAX_RING_DEPTH]); - - /** - * Initialize TensorMap with default sizes - */ - bool init_default(const int32_t task_window_sizes[PTO2_MAX_RING_DEPTH]); - - /** - * Destroy TensorMap and free resources - */ - void destroy(); - - /** - * Update validity threshold from shared memory - * Called periodically to refresh the lazy invalidation threshold. - * - * @param last_task_alive Current value from shared memory - */ - void sync_validity(int32_t ring_id, int32_t last_task_alive) { - this->last_task_alives[ring_id] = last_task_alive; - } - - /** - * Lookup producer for a tensor region - * - * Searches the hash table for a matching region. - * Returns producer entry if found and valid. - * Stale entries from different rings are skipped (not truncated). - * - * @param tensor Tensor to look up - * @param result Output: stack-allocated result buffer - */ void lookup(const Tensor& tensor, PTO2LookupResult& result) { uint32_t bucket_index = hash(tensor.buffer.addr); PTO2TensorMapEntry* cur_entry = buckets[bucket_index]; - result.count = 0; #if PTO2_TENSORMAP_PROFILING g_lookup_count++; int32_t chain_len = 0; #endif while (cur_entry != nullptr) { - // Prefetch next entry to hide pointer-chasing latency. - // entry_valid() + is_overlap() computation provides hide time. PTO2TensorMapEntry* next_entry = cur_entry->next_in_bucket; if (next_entry) __builtin_prefetch(next_entry, 0, 0); #if PTO2_TENSORMAP_PROFILING chain_len++; #endif - // Skip stale entries (no chain truncation — entries from different - // rings can be interleaved, so a stale entry from one ring does NOT - // imply subsequent entries from other rings are also stale) + // Per-ring chain: entries are ordered newest-first (head-insert). + // A stale entry means all subsequent entries are also stale — break. if (!entry_valid(*cur_entry)) { - cur_entry = next_entry; - continue; + break; } - // Entry is valid - check if regions OVERLAP (not just exact match) - // Since we hash only by base_ptr, all entries in this bucket have - // potential to overlap. We must check actual byte-range overlap. if (tensor.buffer.addr == cur_entry->buffer_addr) { - // Double prefetch: check_overlap provides enough hide time - // to also warm up the entry after next. if (next_entry) { PTO2TensorMapEntry* next_next = next_entry->next_in_bucket; if (next_next) __builtin_prefetch(next_next, 0, 0); @@ -333,7 +298,6 @@ struct PTO2TensorMap { } } - // Move to next entry cur_entry = next_entry; } #if PTO2_TENSORMAP_PROFILING @@ -342,110 +306,82 @@ struct PTO2TensorMap { #endif } - /** - * Insert a new entry (called when task produces output) - * - * Allocates from ring buffer pool, may overwrite stale entries. - * Inserts at head of hash bucket chain (maintains task_id ordering). - * - * @param tensor Tensor produced - * @param producer_task_id Task ID of producer - */ - void insert(const Tensor& tensor, PTO2TaskId producer_task_id, bool with_alloc) { + // ============================================================================= + // Insert + // ============================================================================= + + void insert(const Tensor& tensor, PTO2TaskId producer_task_id, bool with_alloc, + int32_t lifecycle_key, uint8_t owner_ring) { #if PTO2_TENSORMAP_PROFILING g_insert_count++; #endif - // Prefetch bucket head and task_entry_head early; new_entry() + field - // initialization below provides hide time for these RFOs. uint32_t bucket_index = hash(tensor.buffer.addr); __builtin_prefetch(&buckets[bucket_index], 1, 0); - auto ring_id = producer_task_id.ring(); - auto local_id = producer_task_id.local(); - int32_t task_slot = local_id & (task_window_sizes[ring_id] - 1); - __builtin_prefetch(&task_entry_heads[ring_id][task_slot], 1, 0); + int32_t task_slot = lifecycle_key & (task_window_size - 1); + __builtin_prefetch(&task_entry_heads[task_slot], 1, 0); - // Allocate entry from ring buffer pool PTO2TensorMapEntry* entry = new_entry(); - // Initialize new entry entry->copy_from_tensor(tensor); entry->producer_task_id = producer_task_id; entry->with_alloc = with_alloc; + entry->lifecycle_key = lifecycle_key; + entry->owner_ring = owner_ring; // Insert at head of hash bucket (maintains task_id descending order) entry->bucket_index = bucket_index; entry->next_in_bucket = buckets[bucket_index]; - // Update old head's prev pointer if (entry->next_in_bucket != nullptr) { entry->next_in_bucket->prev_in_bucket = entry; } buckets[entry->bucket_index] = entry; - entry->prev_in_bucket = nullptr; // New head has no predecessor + entry->prev_in_bucket = nullptr; - // Link to task's entry list (for cleanup), indexed by ring and local slot - entry->next_in_task = task_entry_heads[ring_id][task_slot]; - entry->prev_in_task = nullptr; // New head has no predecessor - // Update old head's prev pointer + // Link to task's entry list (for cleanup) + entry->next_in_task = task_entry_heads[task_slot]; + entry->prev_in_task = nullptr; if (entry->next_in_task != nullptr) { entry->next_in_task->prev_in_task = entry; } - task_entry_heads[ring_id][task_slot] = entry; + task_entry_heads[task_slot] = entry; } - /** - * Cleanup stale entries for retired tasks - * - * Called periodically by Orchestrator when last_task_alive advances. - * Removes entries from bucket chains for tasks in [old, new) range. - * - * @param old_last_task_alive Previous threshold - * @param new_last_task_alive New threshold - */ - void cleanup_retired(int32_t ring_id, int32_t old_last_task_alive, int32_t new_last_task_alive) { - // Iterate through retired tasks on this ring and remove their entries - for (int32_t local_id = old_last_task_alive; local_id < new_last_task_alive; local_id++) { - int32_t task_slot = local_id & (task_window_sizes[ring_id] - 1); - PTO2TensorMapEntry* cur_entry = task_entry_heads[ring_id][task_slot]; + // ============================================================================= + // Cleanup + // ============================================================================= + + void cleanup_retired(int32_t old_key, int32_t new_key) { + for (int32_t key = old_key; key < new_key; key++) { + int32_t task_slot = key & (task_window_size - 1); + PTO2TensorMapEntry* cur_entry = task_entry_heads[task_slot]; while (cur_entry != nullptr) { - PTO2TensorMapEntry* next_entry = cur_entry->next_in_task; // Save before clearing - // Only remove if this entry belongs to the retiring task - // (slot may have been reused by a newer task) - debug_assert(cur_entry->producer_task_id == - pto2_make_task_id(static_cast(ring_id), - static_cast(local_id))); + PTO2TensorMapEntry* next_entry = cur_entry->next_in_task; + debug_assert(cur_entry->lifecycle_key == key); free_entry(*cur_entry); cur_entry = next_entry; } - // Clear task's entry head (slot will be reused by local_id + task_window_sizes[ring_id]) - task_entry_heads[ring_id][task_slot] = nullptr; + task_entry_heads[task_slot] = nullptr; } } // ============================================================================= - // Internal Helpers (exposed for testing) + // Internal Helpers // ============================================================================= - /** - * Compute hash for tensor addr - */ uint32_t hash(uint64_t key) { - // Improve distribution by mixing bits (pointers often have aligned low bits) key = key ^ (key >> 16); key = key ^ (key >> 32); - - // Use bitwise AND for power-of-2 modulo (faster than %) return (uint32_t)(key & (num_buckets - 1)); } - /** - * Check if entry is valid (producer has not retired) - */ bool entry_valid(const PTO2TensorMapEntry& entry) const { - int32_t ring_id = pto2_task_id_ring(entry.producer_task_id); - int32_t local_id = static_cast(pto2_task_id_local(entry.producer_task_id)); - return local_id >= last_task_alives[ring_id]; + return entry.lifecycle_key >= last_task_alive; + } + + void sync_validity(int32_t new_last_task_alive) { + last_task_alive = new_last_task_alive; } void remove_entry(PTO2TensorMapEntry& entry) { @@ -453,24 +389,15 @@ struct PTO2TensorMap { free_entry(entry); } - /** - * Remove entry from its task chain (O(1) with prev pointer) - * Called during pool wrap-around to unlink reused entries. - */ void remove_from_task(PTO2TensorMapEntry& entry) { - always_assert(entry.bucket_index != -1); // 必须保证仍在桶中 - // Update predecessor's next pointer (O(1) via prev_in_task) + always_assert(entry.bucket_index != -1); + int32_t task_slot = entry.lifecycle_key & (task_window_size - 1); if (entry.prev_in_task == nullptr) { - // Entry is the head of its task chain, update task_entry_heads - int32_t ring_id = pto2_task_id_ring(entry.producer_task_id); - int32_t local_id = static_cast(pto2_task_id_local(entry.producer_task_id)); - int32_t task_slot = local_id & (task_window_sizes[ring_id] - 1); - task_entry_heads[ring_id][task_slot] = entry.next_in_task; + task_entry_heads[task_slot] = entry.next_in_task; } else { entry.prev_in_task->next_in_task = entry.next_in_task; } - // Update successor's prev pointer if (entry.next_in_task != nullptr) { entry.next_in_task->prev_in_task = entry.prev_in_task; } @@ -483,27 +410,316 @@ struct PTO2TensorMap { // Debug Utilities // ============================================================================= + void print_stats(); + int32_t valid_count(); +}; + +// ============================================================================= +// Fallback TensorMap — global buckets with producer-driven cleanup +// ============================================================================= + +/** + * Fallback TensorMap stores: + * - external tensor history + * - cross-ring INOUT modifier entries + * + * Storage ownership is global (single bucket table), but lifecycle ownership + * remains tied to the producer task. Cleanup therefore uses per-producer-ring + * task lists and does not rely on break-on-stale. + */ +struct PTO2FallbackTensorMap { + PTO2TensorMapEntry** buckets{nullptr}; + int32_t num_buckets{0}; + + PTO2TensorMapEntry* entry_pool{nullptr}; + PTO2TensorMapEntry** free_entry_list{nullptr}; + int32_t pool_size{0}; + int32_t next_entry_idx{0}; + int32_t free_num{0}; + + PTO2TensorMapEntry** lifecycle_entry_heads{nullptr}; + PTO2TaskId* lifecycle_producer_ids{nullptr}; + int32_t* lifecycle_entry_keys{nullptr}; + int32_t lifecycle_window_size{0}; + int32_t next_lifecycle_key{0}; + int32_t last_task_alive{0}; + uint64_t last_insert_task_raw{UINT64_MAX}; + int32_t last_insert_lifecycle_key{-1}; + + bool init(int32_t num_buckets, int32_t pool_size); + void destroy(); + + PTO2TensorMapEntry* new_entry() { + if (free_num > 0) { + PTO2TensorMapEntry* res = free_entry_list[--free_num]; + debug_assert(res->bucket_index == -1); + return res; + } + always_assert(next_entry_idx < pool_size); + PTO2TensorMapEntry* res = &entry_pool[next_entry_idx++]; + debug_assert(res->bucket_index == -1); + return res; + } + + void free_entry(PTO2TensorMapEntry& entry) { + always_assert(entry.bucket_index != -1); + + if (entry.prev_in_bucket == nullptr) { + buckets[entry.bucket_index] = entry.next_in_bucket; + } else { + entry.prev_in_bucket->next_in_bucket = entry.next_in_bucket; + } + + if (entry.next_in_bucket != nullptr) { + entry.next_in_bucket->prev_in_bucket = entry.prev_in_bucket; + } + + free_entry_list[free_num++] = &entry; + entry.bucket_index = -1; + entry.next_in_bucket = nullptr; + entry.prev_in_bucket = nullptr; + entry.next_in_task = nullptr; + entry.prev_in_task = nullptr; + } + + void lookup(const Tensor& tensor, PTO2LookupResult& result, + const PTO2TensorMapRing owner_rings[PTO2_MAX_RING_DEPTH]) { + uint32_t bucket_index = hash(tensor.buffer.addr); + PTO2TensorMapEntry* cur_entry = buckets[bucket_index]; + + while (cur_entry != nullptr) { + PTO2TensorMapEntry* next_entry = cur_entry->next_in_bucket; + if (next_entry) __builtin_prefetch(next_entry, 0, 0); + + // lifecycle_key orders the fallback bucket chain. Once cleanup has + // advanced past this key, all later nodes in the chain are older. + if (cur_entry->lifecycle_key < last_task_alive) { + break; + } + + if (producer_alive(*cur_entry, owner_rings) && + tensor.buffer.addr == cur_entry->buffer_addr) { + auto overlap_status = cur_entry->check_overlap(tensor); + if (overlap_status != OverlapStatus::NO_OVERLAP) { + result.push(cur_entry, overlap_status); + } + } + + cur_entry = next_entry; + } + } + + void insert(const Tensor& tensor, PTO2TaskId producer_task_id, bool with_alloc) { + int32_t lifecycle_key = get_or_create_lifecycle_key(producer_task_id); + int32_t task_slot = lifecycle_key & (lifecycle_window_size - 1); + always_assert(lifecycle_entry_keys[task_slot] == -1 || + lifecycle_entry_keys[task_slot] == lifecycle_key); + + uint32_t bucket_index = hash(tensor.buffer.addr); + PTO2TensorMapEntry* entry = new_entry(); + + entry->copy_from_tensor(tensor); + entry->producer_task_id = producer_task_id; + entry->with_alloc = with_alloc; + entry->owner_ring = TENSOR_RING_ID_NONE; + entry->lifecycle_key = lifecycle_key; + + entry->bucket_index = bucket_index; + entry->next_in_bucket = buckets[bucket_index]; + if (entry->next_in_bucket != nullptr) { + entry->next_in_bucket->prev_in_bucket = entry; + } + buckets[bucket_index] = entry; + entry->prev_in_bucket = nullptr; + + entry->next_in_task = lifecycle_entry_heads[task_slot]; + entry->prev_in_task = nullptr; + if (entry->next_in_task != nullptr) { + entry->next_in_task->prev_in_task = entry; + } + lifecycle_entry_heads[task_slot] = entry; + lifecycle_entry_keys[task_slot] = lifecycle_key; + lifecycle_producer_ids[task_slot] = producer_task_id; + } + + void advance_frontier(const PTO2TensorMapRing owner_rings[PTO2_MAX_RING_DEPTH]) { + while (last_task_alive < next_lifecycle_key) { + int32_t task_slot = last_task_alive & (lifecycle_window_size - 1); + if (lifecycle_entry_keys[task_slot] != last_task_alive) { + break; + } + + PTO2TensorMapEntry* cur_entry = lifecycle_entry_heads[task_slot]; + if (cur_entry == nullptr) { + lifecycle_entry_keys[task_slot] = -1; + lifecycle_producer_ids[task_slot] = PTO2TaskId{}; + last_task_alive++; + continue; + } + + PTO2TaskId producer_task_id = lifecycle_producer_ids[task_slot]; + int32_t producer_ring = static_cast(producer_task_id.ring()); + int32_t producer_local = static_cast(producer_task_id.local()); + if (producer_local >= owner_rings[producer_ring].last_task_alive) { + break; + } + + while (cur_entry != nullptr) { + PTO2TensorMapEntry* next_entry = cur_entry->next_in_task; + debug_assert(cur_entry->lifecycle_key == last_task_alive); + free_entry(*cur_entry); + cur_entry = next_entry; + } + + lifecycle_entry_heads[task_slot] = nullptr; + lifecycle_entry_keys[task_slot] = -1; + lifecycle_producer_ids[task_slot] = PTO2TaskId{}; + last_task_alive++; + } + } + + void remove_entry(PTO2TensorMapEntry& entry) { + remove_from_task(entry); + free_entry(entry); + } + + void remove_from_task(PTO2TensorMapEntry& entry) { + always_assert(entry.bucket_index != -1); + int32_t task_slot = entry.lifecycle_key & (lifecycle_window_size - 1); + if (entry.prev_in_task == nullptr) { + lifecycle_entry_heads[task_slot] = entry.next_in_task; + } else { + entry.prev_in_task->next_in_task = entry.next_in_task; + } + + if (entry.next_in_task != nullptr) { + entry.next_in_task->prev_in_task = entry.prev_in_task; + } + + entry.next_in_task = nullptr; + entry.prev_in_task = nullptr; + } + + uint32_t hash(uint64_t key) { + key = key ^ (key >> 16); + key = key ^ (key >> 32); + return (uint32_t)(key & (num_buckets - 1)); + } + + bool producer_alive(const PTO2TensorMapEntry& entry, + const PTO2TensorMapRing owner_rings[PTO2_MAX_RING_DEPTH]) const { + int32_t producer_ring = static_cast(entry.producer_task_id.ring()); + int32_t producer_local = static_cast(entry.producer_task_id.local()); + always_assert(producer_ring >= 0 && producer_ring < PTO2_MAX_RING_DEPTH); + return producer_local >= owner_rings[producer_ring].last_task_alive; + } + + int32_t get_or_create_lifecycle_key(PTO2TaskId producer_task_id) { + if (last_insert_task_raw == producer_task_id.raw) { + return last_insert_lifecycle_key; + } + + always_assert(next_lifecycle_key - last_task_alive < lifecycle_window_size && + "Fallback TensorMap lifecycle window exhausted"); + + last_insert_task_raw = producer_task_id.raw; + last_insert_lifecycle_key = next_lifecycle_key++; + return last_insert_lifecycle_key; + } + + void print_stats(const PTO2TensorMapRing owner_rings[PTO2_MAX_RING_DEPTH]); + int32_t valid_count(const PTO2TensorMapRing owner_rings[PTO2_MAX_RING_DEPTH]); +}; + +// ============================================================================= +// TensorMap Facade — routes operations to per-ring instances +// ============================================================================= + +/** + * TensorMap facade structure + * + * Routes lookup/insert/cleanup to the appropriate per-ring TensorMapRing + * based on tensor.ring_id. External tensors (ring_id < 0) are skipped. + */ +struct PTO2TensorMap { + PTO2TensorMapRing rings[PTO2_MAX_RING_DEPTH]; + PTO2FallbackTensorMap fallback_map; // Global fallback for cross-ring INOUT and external tensors + PTO2OrchestratorState* orch{nullptr}; + + // ============================================================================= + // TensorMap API + // ============================================================================= + + bool init(int32_t num_buckets, int32_t pool_size, const int32_t task_window_sizes[PTO2_MAX_RING_DEPTH]); + bool init_default(const int32_t task_window_sizes[PTO2_MAX_RING_DEPTH]); + void destroy(); + /** - * Print TensorMap statistics + * Dual-source lookup: query owner ring + fallback, append results to same container. + * - ring_id == TENSOR_RING_ID_NONE: only check fallback (external tensor) + * - ring_id valid: check owner ring, then fallback */ - void print_stats(); + void lookup(const Tensor& tensor, PTO2LookupResult& result) { + result.count = 0; + if (tensor.ring_id == TENSOR_RING_ID_NONE) { + fallback_map.lookup(tensor, result, rings); + } else if (tensor.ring_id < PTO2_MAX_RING_DEPTH) { + rings[tensor.ring_id].lookup(tensor, result); + fallback_map.lookup(tensor, result, rings); + } + } /** - * Get count of valid entries + * Route insert to owner ring or fallback based on tensor ownership and current ring. + * - OUTPUT (with_alloc=true): always owner ring + * - INOUT (with_alloc=false): owner ring if same ring, fallback if cross-ring or external */ - int32_t valid_count(); + void insert(const Tensor& tensor, PTO2TaskId producer_task_id, bool with_alloc, int32_t current_ring_id) { + if (with_alloc) { + if (tensor.ring_id == TENSOR_RING_ID_NONE) { + fallback_map.insert(tensor, producer_task_id, with_alloc); + return; + } + always_assert(tensor.ring_id < PTO2_MAX_RING_DEPTH); + always_assert(static_cast(tensor.ring_id) == current_ring_id && + "OUTPUT tensor ring_id must match submit ring"); + rings[tensor.ring_id].insert(tensor, producer_task_id, with_alloc, + static_cast(producer_task_id.local()), tensor.ring_id); + } else { + // INOUT modifier + if (tensor.ring_id == TENSOR_RING_ID_NONE) { + // External tensor → fallback + fallback_map.insert(tensor, producer_task_id, with_alloc); + } else if (tensor.ring_id < PTO2_MAX_RING_DEPTH + && static_cast(tensor.ring_id) == current_ring_id) { + // Same ring → owner ring + rings[tensor.ring_id].insert(tensor, producer_task_id, with_alloc, + static_cast(producer_task_id.local()), tensor.ring_id); + } else if (tensor.ring_id < PTO2_MAX_RING_DEPTH) { + // Cross-ring INOUT → fallback + fallback_map.insert(tensor, producer_task_id, with_alloc); + } + } + } - // ============================================================================= - // TensorMap Synchronization - // ============================================================================= + /** + * Route remove to owner ring or fallback based on entry's owner_ring field. + */ + void remove_entry(PTO2TensorMapEntry& entry) { + if (entry.owner_ring == TENSOR_RING_ID_NONE) { + fallback_map.remove_entry(entry); + } else if (entry.owner_ring < PTO2_MAX_RING_DEPTH) { + rings[entry.owner_ring].remove_entry(entry); + } + } + + void print_stats(); + int32_t valid_count(); /** - * Sync TensorMap validity threshold from shared memory - * - * Called periodically to refresh the lazy invalidation threshold. - * Also triggers cleanup if threshold has advanced significantly. + * Sync TensorMap validity threshold from shared memory and run cleanup. */ - void sync_tensormap(uint8_t ring_id, int32_t sm_last_task_alive); + void sync_tensormap(); }; #if PTO2_TENSORMAP_PROFILING diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/tensor.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/tensor.h index 10b5b582..e2d6c7b8 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/tensor.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/tensor.h @@ -9,6 +9,7 @@ #include "data_type.h" constexpr int RUNTIME_MAX_TENSOR_DIMS = 5; +constexpr uint8_t TENSOR_RING_ID_NONE = 0xFF; // No ring assigned (external tensor) /** * Buffer Handle @@ -66,6 +67,7 @@ struct alignas(64) Tensor { bool is_all_offset_zero; // True when all offsets[] are zero (skip offset read/write) bool is_raw_eq_shapes; // True when raw_shapes[] == shapes[] (skip raw_shapes read/write) bool manual_dep; // True when dependency is managed manually (skip tensormap lookup/insert) + uint8_t ring_id; // Ring that owns this tensor (TENSOR_RING_ID_NONE = unassigned) uint32_t shapes[RUNTIME_MAX_TENSOR_DIMS]; // Current view shape per dimension uint32_t __padding__; @@ -96,9 +98,10 @@ struct alignas(64) Tensor { int32_t version, bool is_all_offset_zero = false, bool is_raw_eq_shapes = false, - bool manual_dep = false) { + bool manual_dep = false, + uint8_t in_ring_id = TENSOR_RING_ID_NONE) { init(addr, buffer_size_bytes, raw_shapes, shapes, offsets, ndims, dtype, version, - is_all_offset_zero, is_raw_eq_shapes, manual_dep); + is_all_offset_zero, is_raw_eq_shapes, manual_dep, in_ring_id); } // --- Initialization --- @@ -112,7 +115,8 @@ struct alignas(64) Tensor { int32_t in_version, bool in_is_all_offset_zero = false, bool in_is_raw_eq_shapes = false, - bool in_manual_dep = false) { + bool in_manual_dep = false, + uint8_t in_ring_id = TENSOR_RING_ID_NONE) { buffer = {reinterpret_cast(addr), buffer_size_bytes}; ndims = in_ndims; dtype = in_dtype; @@ -120,6 +124,7 @@ struct alignas(64) Tensor { is_all_offset_zero = in_is_all_offset_zero; is_raw_eq_shapes = in_is_raw_eq_shapes; manual_dep = in_manual_dep; + ring_id = in_ring_id; for (uint32_t i = 0; i < in_ndims; i++) { shapes[i] = in_shapes[i]; } @@ -155,6 +160,7 @@ struct alignas(64) Tensor { dtype = other.dtype; version = other.version; manual_dep = in_manual_dep; + ring_id = other.ring_id; // view always diverges shapes from raw_shapes, so is_raw_eq_shapes = false. // Read parent's effective raw_shapes (avoids parent cache line 2 when parent is_raw_eq_shapes). is_raw_eq_shapes = false; @@ -285,6 +291,7 @@ struct alignas(64) Tensor { ss << indent << "dtype: " << get_dtype_name(dtype) << std::endl; ss << indent << "ndims: " << ndims << std::endl; ss << indent << "version: " << version << std::endl; + ss << indent << "ring_id: " << (unsigned)ring_id << std::endl; const uint32_t* rs = get_raw_shapes(); ss << indent << "raw_shapes: ["; @@ -320,44 +327,3 @@ static_assert(sizeof(Tensor) == 128, "Tensor must be exactly 2 cache lines (128 static_assert(offsetof(Tensor, raw_shapes) == 64); using TensorData = Tensor; - -// ============================================================================= -// Factory Helpers -// ============================================================================= -/** - * Create a Tensor for pre-allocated external memory. - */ -static inline Tensor make_tensor_external(void* addr, - const uint32_t shapes[], - uint32_t ndims, - DataType dtype = DataType::FLOAT32, - bool manual_dep = false, - int32_t version = 0) { - static uint32_t zero_offsets[RUNTIME_MAX_TENSOR_DIMS] = {}; - uint64_t total = 1; - for (uint32_t i = 0; i < ndims; i++) { - total *= shapes[i]; - } - return Tensor(addr, total * get_element_size(dtype), shapes, shapes, zero_offsets, ndims, dtype, version, - /*is_all_offset_zero=*/true, /*is_raw_eq_shapes=*/true, manual_dep); -} - -/** - * Create a Tensor for runtime-allocated output (addr=0). - * NO memory allocation: only records dtype, shape, and buffer.size in the Tensor struct. - * The runtime allocates from the heap ring and fills buffer.addr during pto2_submit_task - * when this tensor is passed as OUTPUT param. No buffer content is ever copied. - */ -static inline Tensor make_tensor(const uint32_t shapes[], - uint32_t ndims, - DataType dtype = DataType::FLOAT32, - bool manual_dep = false, - int32_t version = 0) { - static uint32_t zero_offsets[RUNTIME_MAX_TENSOR_DIMS] = {}; - uint64_t total = 1; - for (uint32_t i = 0; i < ndims; i++) { - total *= shapes[i]; - } - return Tensor(0, total * get_element_size(dtype), shapes, shapes, zero_offsets, ndims, dtype, version, - /*is_all_offset_zero=*/true, /*is_raw_eq_shapes=*/true, manual_dep); -} diff --git a/tests/device_tests/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp b/tests/device_tests/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp index 8a13111b..6a2a7952 100644 --- a/tests/device_tests/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp +++ b/tests/device_tests/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp @@ -192,7 +192,7 @@ void aicpu_orchestration_entry(PTO2Runtime* rt, uint64_t* args, int arg_count, i params_up.add_input(oi_new_b); params_up.add_inout(mi_batch); params_up.add_inout(li_batch); - params_up.add_output(oi_batch); + params_up.add_inout(oi_batch); params_up.add_output(out); params_up.add_scalar(is_first); params_up.add_scalar(is_last);