From e20c3b241b1422d8b9e3b2eace5ec89427cd52ca Mon Sep 17 00:00:00 2001 From: Jon Binney Date: Sat, 16 May 2026 18:09:20 -0400 Subject: [PATCH 1/6] Fix numba version to <= 0.61 On 0.65 I'm getting compilation errors for our code. --- deep_quoridor/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deep_quoridor/requirements.txt b/deep_quoridor/requirements.txt index 86d3859c..d39fd374 100644 --- a/deep_quoridor/requirements.txt +++ b/deep_quoridor/requirements.txt @@ -1,7 +1,7 @@ gymnasium maturin matplotlib -numba +numba<=0.61 numpy open_spiel pettingzoo From 99ac34d91d04c7dfebe80151617371cb380fce53 Mon Sep 17 00:00:00 2001 From: Jon Binney Date: Sat, 16 May 2026 18:10:24 -0400 Subject: [PATCH 2/6] Add onnx packages to requirements.txt Needed for rust self-play code. --- deep_quoridor/requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deep_quoridor/requirements.txt b/deep_quoridor/requirements.txt index d39fd374..c1b6c453 100644 --- a/deep_quoridor/requirements.txt +++ b/deep_quoridor/requirements.txt @@ -3,6 +3,8 @@ maturin matplotlib numba<=0.61 numpy +onnx +onnxscript open_spiel pettingzoo prettytable From d91821f4ce0c3c7e0e9f8543c966dcfcddca56ea Mon Sep 17 00:00:00 2001 From: Jon Binney Date: Sat, 16 May 2026 18:47:47 -0400 Subject: [PATCH 3/6] Migrate AlphaZero MCTS pipeline to compact u64 state ActionSelector, Evaluator, MCTS nodes, and game_runner now operate on (u64 data, &QGameMechanics) instead of cloning GameState. MCTS nodes store their data eagerly at creation, removing the lazy get_or_create_game caching path. Adds compact_state_to_resnet_input and rotate_compact_state mirroring the existing GameState equivalents, plus get_action_mask / apply_action_index / is_game_over / winner helpers on QGameMechanics. game_runner only materializes a GameState for observer callbacks. Co-Authored-By: Claude Opus 4.7 (1M context) --- .vscode/settings.json | 2 +- .../rust/src/agents/alphazero/agent.rs | 26 +- .../rust/src/agents/alphazero/evaluator.rs | 55 ++- .../rust/src/agents/alphazero/mcts.rs | 365 ++++++++++-------- deep_quoridor/rust/src/agents/mod.rs | 14 +- deep_quoridor/rust/src/agents/onnx_agent.rs | 9 +- deep_quoridor/rust/src/agents/random_agent.rs | 23 +- .../rust/src/compact/q_game_mechanics.rs | 70 ++++ deep_quoridor/rust/src/game_runner.rs | 181 +++------ deep_quoridor/rust/src/grid_helpers.rs | 107 +++++ deep_quoridor/rust/src/python_consistency.rs | 19 +- deep_quoridor/rust/src/rotation.rs | 54 +++ 12 files changed, 581 insertions(+), 344 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 51bdfc8d..5d4c3c8d 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -29,7 +29,7 @@ "source.organizeImports": "explicit" }, "rust-analyzer.linkedProjects": [ - "./deep_quoridor/rust/Cargo.toml" + "${workspaceFolder}/deep_quoridor/rust/Cargo.toml" ], "python-envs.pythonProjects": [], "chat.tools.terminal.outputLocation": "terminal", diff --git a/deep_quoridor/rust/src/agents/alphazero/agent.rs b/deep_quoridor/rust/src/agents/alphazero/agent.rs index 7774e2be..f3f9371b 100644 --- a/deep_quoridor/rust/src/agents/alphazero/agent.rs +++ b/deep_quoridor/rust/src/agents/alphazero/agent.rs @@ -8,7 +8,7 @@ use anyhow::Result; use rand::Rng; use crate::agents::{ActionSelectionTrace, ActionSelector}; -use crate::game_state::GameState; +use crate::compact::q_game_mechanics::QGameMechanics; use super::evaluator::OnnxEvaluator; use super::mcts::{search, MCTSConfig}; @@ -144,7 +144,8 @@ impl AlphaZeroAgent { impl ActionSelector for AlphaZeroAgent { fn select_action( &mut self, - state: &GameState, + data: u64, + mechanics: &QGameMechanics, action_mask: &[bool], ) -> Result<(usize, Vec)> { // Run MCTS search - only pass visited states when penalization is enabled @@ -156,7 +157,8 @@ impl ActionSelector for AlphaZeroAgent { }; let (children, root_value) = search( &self.config.mcts, - state.clone(), + data, + mechanics, &mut self.evaluator, visited_ref, )?; @@ -166,8 +168,9 @@ impl ActionSelector for AlphaZeroAgent { let action_indices: Vec = children.iter().map(|c| c.action_index).collect(); // Determine effective temperature + let completed_steps = mechanics.repr().get_completed_steps(data); let temperature = if let Some(threshold) = self.config.drop_t_on_step { - if state.completed_steps >= threshold { + if completed_steps >= threshold { 0.0 } else { self.config.temperature @@ -193,16 +196,13 @@ impl ActionSelector for AlphaZeroAgent { } } - // Optionally add state to visited set + // Optionally add the resulting state to the visited set. + // Keyed on raw u64 state including completed_steps; a position seen at + // two different step counts is treated as two states. if self.config.penalize_visited_states { - // Get the hash of the resulting state after taking the action - let action = children - .iter() - .find(|c| c.action_index == selected_idx) - .map(|c| c.action) - .unwrap_or([0, 0, 2]); - let next_state = state.clone_and_step(action); - self.visited_states.insert(next_state.get_fast_hash()); + let mut next_data = data; + mechanics.apply_action_index(&mut next_data, selected_idx); + self.visited_states.insert(next_data); } self.last_selection_trace = Some(ActionSelectionTrace { diff --git a/deep_quoridor/rust/src/agents/alphazero/evaluator.rs b/deep_quoridor/rust/src/agents/alphazero/evaluator.rs index 7b475452..eab7e788 100644 --- a/deep_quoridor/rust/src/agents/alphazero/evaluator.rs +++ b/deep_quoridor/rust/src/agents/alphazero/evaluator.rs @@ -6,15 +6,20 @@ use anyhow::{Context, Result}; use ort::session::Session; use crate::agents::onnx_agent::softmax; -use crate::game_state::GameState; -use crate::grid_helpers::grid_game_state_to_resnet_input; -use crate::rotation::{build_rotated_state, create_rotation_mapping, remap_policy}; +use crate::compact::q_game_mechanics::QGameMechanics; +use crate::grid_helpers::compact_state_to_resnet_input; +use crate::rotation::{create_rotation_mapping, remap_policy, rotate_compact_state}; /// Trait for evaluating game positions. /// /// Returns `(value_for_current_player, masked_softmax_priors)`. pub trait Evaluator { - fn evaluate(&mut self, state: &GameState, action_mask: &[bool]) -> Result<(f32, Vec)>; + fn evaluate( + &mut self, + data: u64, + mechanics: &QGameMechanics, + action_mask: &[bool], + ) -> Result<(f32, Vec)>; } /// ONNX-based evaluator for MCTS. @@ -46,26 +51,39 @@ impl OnnxEvaluator { } impl Evaluator for OnnxEvaluator { - fn evaluate(&mut self, state: &GameState, action_mask: &[bool]) -> Result<(f32, Vec)> { + fn evaluate( + &mut self, + data: u64, + mechanics: &QGameMechanics, + action_mask: &[bool], + ) -> Result<(f32, Vec)> { + let bs = mechanics.repr().board_size() as i32; + let current_player = mechanics.repr().get_current_player(data); + let rotated_to_original = self .rotated_to_original_by_board_size - .entry(state.board_size) - .or_insert_with(|| create_rotation_mapping(state.board_size).1); - let (work_state, work_action_mask, rotated_to_original) = if state.current_player == 1 { - let rotated_state = build_rotated_state(state); - let mask = rotated_state.get_action_mask(); - (rotated_state, mask, Some(rotated_to_original.as_slice())) + .entry(bs) + .or_insert_with(|| create_rotation_mapping(bs).1); + + let (work_data, work_action_mask, rotated_to_original) = if current_player == 1 { + let rotated_data = rotate_compact_state(mechanics, data); + let rotated_mask = mechanics.get_action_mask_immut(rotated_data); + ( + rotated_data, + rotated_mask, + Some(rotated_to_original.as_slice()), + ) } else { - (state.clone(), action_mask.to_vec(), None) + (data, action_mask.to_vec(), None) }; // Build ResNet input tensor - let resnet_input = grid_game_state_to_resnet_input(&work_state); + let resnet_input = compact_state_to_resnet_input(mechanics, work_data); // Convert to flat vec for ORT let shape = resnet_input.shape().to_vec(); - let data: Vec = resnet_input.iter().copied().collect(); - let input_value = ort::value::Value::from_array((shape.as_slice(), data)) + let input_data: Vec = resnet_input.iter().copied().collect(); + let input_value = ort::value::Value::from_array((shape.as_slice(), input_data)) .context("Failed to create ONNX input value")?; // Run inference @@ -98,7 +116,12 @@ impl Evaluator for OnnxEvaluator { } impl Evaluator for UniformMockEvaluator { - fn evaluate(&mut self, _state: &GameState, action_mask: &[bool]) -> Result<(f32, Vec)> { + fn evaluate( + &mut self, + _data: u64, + _mechanics: &QGameMechanics, + action_mask: &[bool], + ) -> Result<(f32, Vec)> { let valid_count = action_mask.iter().filter(|&&valid| valid).count(); let mut priors = vec![0.0f32; action_mask.len()]; if valid_count > 0 { diff --git a/deep_quoridor/rust/src/agents/alphazero/mcts.rs b/deep_quoridor/rust/src/agents/alphazero/mcts.rs index 34948822..27655a80 100644 --- a/deep_quoridor/rust/src/agents/alphazero/mcts.rs +++ b/deep_quoridor/rust/src/agents/alphazero/mcts.rs @@ -6,8 +6,10 @@ use std::collections::HashSet; use rand_distr::{Dirichlet, Distribution}; -use crate::actions::{action_index_to_action, action_to_index}; -use crate::game_state::GameState; +use crate::actions::action_index_to_action; +#[cfg(test)] +use crate::actions::action_to_index; +use crate::compact::q_game_mechanics::QGameMechanics; use super::evaluator::Evaluator; @@ -58,12 +60,12 @@ pub struct ChildInfo { /// A node in the MCTS tree. #[derive(Debug)] pub struct Node { - /// The game state at this node. None for lazily expanded nodes. - pub game: Option, + /// The compact game state at this node. + pub data: u64, /// Parent node index in the arena, None for root. pub parent: Option, - /// Action taken from parent to reach this node. - pub action_taken: Option<[i32; 3]>, + /// Flat policy index for the action taken from parent to reach this node. + pub action_index: Option, /// Child node indices in the arena. pub children: Vec, /// Number of times this node was visited. @@ -80,11 +82,11 @@ pub struct Node { impl Node { /// Create a new root node. - pub fn new_root(game: GameState) -> Self { + pub fn new_root(data: u64) -> Self { Self { - game: Some(game), + data, parent: None, - action_taken: None, + action_index: None, children: Vec::new(), visit_count: 0, value_sum: 0.0, @@ -94,12 +96,12 @@ impl Node { } } - /// Create a new child node (lazy - game state computed on demand). - pub fn new_child(parent: usize, action: [i32; 3], prior: f32) -> Self { + /// Create a new child node with its (already computed) state. + pub fn new_child(parent: usize, action_index: usize, data: u64, prior: f32) -> Self { Self { - game: None, + data, parent: Some(parent), - action_taken: Some(action), + action_index: Some(action_index), children: Vec::new(), visit_count: 0, value_sum: 0.0, @@ -131,15 +133,22 @@ pub struct NodeArena { impl NodeArena { /// Create a new arena with a root node. - pub fn new(root_game: GameState) -> Self { - let root = Node::new_root(root_game); + pub fn new(root_data: u64) -> Self { + let root = Node::new_root(root_data); Self { nodes: vec![root] } } /// Allocate a new child node and return its index. - pub fn alloc_child(&mut self, parent: usize, action: [i32; 3], prior: f32) -> usize { + pub fn alloc_child( + &mut self, + parent: usize, + action_index: usize, + data: u64, + prior: f32, + ) -> usize { let idx = self.nodes.len(); - self.nodes.push(Node::new_child(parent, action, prior)); + self.nodes + .push(Node::new_child(parent, action_index, data, prior)); idx } @@ -153,47 +162,31 @@ impl NodeArena { &mut self.nodes[idx] } - /// Get the game state for a node, computing it lazily if needed. - pub fn get_or_create_game(&mut self, idx: usize) -> &GameState { - // First check if we already have it - if self.nodes[idx].game.is_some() { - return self.nodes[idx].game.as_ref().unwrap(); - } - - // Need to compute it from parent - let parent_idx = self.nodes[idx] - .parent - .expect("Non-root node must have parent"); - let action = self.nodes[idx] - .action_taken - .expect("Child node must have action"); - - // Recursively ensure parent has game state - let parent_game = self.get_or_create_game(parent_idx).clone(); - - // Apply action to get child game state - let child_game = parent_game.clone_and_step(action); - self.nodes[idx].game = Some(child_game); - - self.nodes[idx].game.as_ref().unwrap() - } - /// Get the number of nodes. pub fn len(&self) -> usize { self.nodes.len() } } -/// Expand a node by creating children for all valid actions. -pub fn expand_node(arena: &mut NodeArena, node_idx: usize, priors: &[f32], board_size: i32) { - // Create children only for actions with non-zero prior - for (action_idx, &prior) in priors.iter().enumerate() { - if prior > 1e-10 { - let action = action_index_to_action(board_size, action_idx); - let child_idx = arena.alloc_child(node_idx, action, prior); - arena.get_mut(node_idx).children.push(child_idx); - } - } +/// Expand a node by creating children for all actions with non-zero prior. +pub fn expand_node( + arena: &mut NodeArena, + node_idx: usize, + priors: &[f32], + mechanics: &QGameMechanics, +) { + let parent_data = arena.get(node_idx).data; + let new_children: Vec = priors + .iter() + .enumerate() + .filter(|(_, &p)| p > 1e-10) + .map(|(action_idx, &prior)| { + let mut child_data = parent_data; + mechanics.apply_action_index(&mut child_data, action_idx); + arena.alloc_child(node_idx, action_idx, child_data, prior) + }) + .collect(); + arena.get_mut(node_idx).children.extend(new_children); } /// Select the best child using PUCT formula. @@ -204,7 +197,6 @@ pub fn select_child( node_idx: usize, ucb_c: f32, visited_states: &HashSet, - _board_size: i32, ) -> usize { let node = arena.get(node_idx); let parent_visits = node.visit_count.max(1) as f32; @@ -220,21 +212,10 @@ pub fn select_child( let u = ucb_c * child.prior * (parent_visits.sqrt()) / (1.0 + child.visit_count as f32); let mut ucb = q + u; - // Apply penalty if child state is in visited set (only when enabled) - // NOTE: If the penalized_visited_states flag is false, the agent will pass an empty set, so this check will have no effect. - if !visited_states.is_empty() { - if let Some(ref game) = child.game { - if visited_states.contains(&game.get_fast_hash()) { - ucb -= 1.0; - } - } else if let Some(action) = child.action_taken { - // Compute hash for child without fully expanding game - let parent_game = arena.get(node_idx).game.as_ref().unwrap(); - let child_game = parent_game.clone_and_step(action); - if visited_states.contains(&child_game.get_fast_hash()) { - ucb -= 1.0; - } - } + // Penalty when this child's state is in the visited set. + // Set is empty when penalize_visited_states is false. + if !visited_states.is_empty() && visited_states.contains(&child.data) { + ucb -= 1.0; } if ucb > best_ucb { @@ -313,19 +294,20 @@ pub fn apply_dirichlet_noise(priors: &mut [f32], epsilon: f32, alpha: f32) { /// Run MCTS search and return child information. pub fn search( config: &MCTSConfig, - game: GameState, + root_data: u64, + mechanics: &QGameMechanics, evaluator: &mut E, visited_states: &HashSet, ) -> anyhow::Result<(Vec, f32)> { - let board_size = game.board_size; - let mut arena = NodeArena::new(game.clone()); + let bs = mechanics.repr().board_size() as i32; + let mut arena = NodeArena::new(root_data); // Evaluate root upfront to get root_value and priors. // Root is NOT expanded here; expansion happens inside the first loop iteration // so that the loop structure matches the Python implementation (where iteration 0 // always selects root itself, expands it, and backpropagates through it). - let action_mask = game.get_action_mask(); - let (root_value, mut root_priors) = evaluator.evaluate(&game, &action_mask)?; + let action_mask = mechanics.get_action_mask_immut(root_data); + let (root_value, mut root_priors) = evaluator.evaluate(root_data, mechanics, &action_mask)?; // Apply Dirichlet noise at root if configured if config.noise_epsilon > 0.0 { @@ -344,7 +326,7 @@ pub fn search( // Special case: n=0 means just use priors (expand root once without simulating) if n_iterations == 0 { - expand_node(&mut arena, 0, &root_priors, board_size); + expand_node(&mut arena, 0, &root_priors, mechanics); let root = arena.get(0); let children = root.children.clone(); @@ -364,22 +346,15 @@ pub fn search( let mut current_idx = 0; while !arena.get(current_idx).should_expand() { - current_idx = select_child( - &arena, - current_idx, - config.ucb_c, - visited_states, - board_size, - ); + current_idx = select_child(&arena, current_idx, config.ucb_c, visited_states); } - // Get/create game state for the selected node - let leaf_game = arena.get_or_create_game(current_idx).clone(); + let leaf_data = arena.get(current_idx).data; // Check for terminal state - if leaf_game.is_game_over() { + if mechanics.is_game_over(leaf_data) { // Terminal: backpropagate result - let value = if leaf_game.winner().is_some() { + let value = if mechanics.winner(leaf_data).is_some() { 1.0 } else { 0.0 @@ -390,7 +365,7 @@ pub fn search( // Check max steps if let Some(max) = config.max_steps { - if leaf_game.completed_steps >= max as usize { + if mechanics.repr().get_completed_steps(leaf_data) >= max as usize { backpropagate_result(&mut arena, current_idx, 0.0); continue; } @@ -402,11 +377,11 @@ pub fn search( let (value, leaf_priors) = if let Some(pre_priors) = root_priors_opt.take() { (root_value, pre_priors) } else { - let leaf_mask = leaf_game.get_action_mask(); - evaluator.evaluate(&leaf_game, &leaf_mask)? + let leaf_mask = mechanics.get_action_mask_immut(leaf_data); + evaluator.evaluate(leaf_data, mechanics, &leaf_mask)? }; - expand_node(&mut arena, current_idx, &leaf_priors, board_size); + expand_node(&mut arena, current_idx, &leaf_priors, mechanics); // Backpropagate negative value (from opponent's perspective) backpropagate(&mut arena, current_idx, -value as f64); @@ -426,11 +401,11 @@ pub fn search( .iter() .map(|&child_idx| { let child = arena.get(child_idx); - let action = child.action_taken.unwrap(); - let action_index = action_to_index(board_size, &action); + let ai = child.action_index.expect("child node must have action_index"); + let action = action_index_to_action(bs, ai); ChildInfo { action, - action_index, + action_index: ai, visit_count: child.visit_count, } }) @@ -458,7 +433,8 @@ mod tests { impl Evaluator for MockEvaluator { fn evaluate( &mut self, - _state: &GameState, + _data: u64, + _mechanics: &QGameMechanics, action_mask: &[bool], ) -> Result<(f32, Vec)> { // Return uniform priors over valid actions @@ -476,12 +452,18 @@ mod tests { } } + fn make_mech_state() -> (QGameMechanics, u64) { + let mech = QGameMechanics::new(5, 3, 200); + let data = mech.create_initial_state(); + (mech, data) + } + #[test] fn test_node_creation() { - let state = GameState::new(5, 3); - let node = Node::new_root(state); + let (_, data) = make_mech_state(); + let node = Node::new_root(data); - assert!(node.game.is_some()); + assert_eq!(node.data, data); assert!(node.parent.is_none()); assert!(node.children.is_empty()); assert_eq!(node.visit_count, 0); @@ -491,37 +473,60 @@ mod tests { #[test] fn test_expand_node() { - let state = GameState::new(5, 3); - let mut arena = NodeArena::new(state); - - // Create uniform priors with some zeros - let priors = vec![0.0, 0.5, 0.0, 0.3, 0.2]; + let (mech, data) = make_mech_state(); + let mut arena = NodeArena::new(data); + + // Build sparse priors aligned to the policy layout + let total = crate::actions::policy_size(5); + let mask = mech.get_action_mask_immut(data); + let mut priors = vec![0.0f32; total]; + // Put non-zero prior on three valid actions + let mut count = 0; + for (i, &v) in mask.iter().enumerate() { + if v { + priors[i] = if count == 0 { + 0.5 + } else if count == 1 { + 0.3 + } else if count == 2 { + 0.2 + } else { + 0.0 + }; + count += 1; + if count == 3 { + break; + } + } + } - expand_node(&mut arena, 0, &priors, 5); + expand_node(&mut arena, 0, &priors, &mech); let root = arena.get(0); - // Should have 3 children (for non-zero priors) assert_eq!(root.children.len(), 3); - - // Verify children have correct priors - let child_priors: Vec = root - .children - .iter() - .map(|&idx| arena.get(idx).prior) - .collect(); - assert!((child_priors[0] - 0.5).abs() < 1e-6); - assert!((child_priors[1] - 0.3).abs() < 1e-6); - assert!((child_priors[2] - 0.2).abs() < 1e-6); } #[test] fn test_select_child_ucb() { - let state = GameState::new(5, 3); - let mut arena = NodeArena::new(state); + let (mech, data) = make_mech_state(); + let mut arena = NodeArena::new(data); - // Manually create children with different visit counts and values - let child1 = arena.alloc_child(0, [1, 2, 2], 0.5); - let child2 = arena.alloc_child(0, [0, 2, 2], 0.5); + // Find a couple of valid move action indices to use + let mask = mech.get_action_mask_immut(data); + let valid: Vec = mask + .iter() + .enumerate() + .filter_map(|(i, &v)| if v { Some(i) } else { None }) + .collect(); + assert!(valid.len() >= 2); + + // Manually create two children + let mut d1 = data; + mech.apply_action_index(&mut d1, valid[0]); + let mut d2 = data; + mech.apply_action_index(&mut d2, valid[1]); + let child1 = arena.alloc_child(0, valid[0], d1, 0.5); + let child2 = arena.alloc_child(0, valid[1], d2, 0.5); arena.get_mut(0).children = vec![child1, child2]; arena.get_mut(0).visit_count = 10; @@ -535,28 +540,25 @@ mod tests { arena.get_mut(child2).value_sum = 0.0; let visited = HashSet::new(); - let selected = select_child(&arena, 0, 1.4, &visited, 5); + let selected = select_child(&arena, 0, 1.4, &visited); - // Child 2 should be selected (higher UCB due to fewer visits) assert_eq!(selected, child2); } #[test] fn test_backpropagate() { - let state = GameState::new(5, 3); - let mut arena = NodeArena::new(state); + let (_, data) = make_mech_state(); + let mut arena = NodeArena::new(data); - // Create a chain: root -> child -> grandchild - let child = arena.alloc_child(0, [1, 2, 2], 0.5); - let grandchild = arena.alloc_child(child, [2, 2, 2], 0.5); + // Chain: root -> child -> grandchild (arbitrary action indices, data doesn't matter here) + let child = arena.alloc_child(0, 0, data, 0.5); + let grandchild = arena.alloc_child(child, 0, data, 0.5); arena.get_mut(0).children = vec![child]; arena.get_mut(child).children = vec![grandchild]; - // Backpropagate +1 from grandchild backpropagate(&mut arena, grandchild, 1.0); - // Check values alternate assert_eq!(arena.get(grandchild).visit_count, 1); assert!((arena.get(grandchild).value_sum - 1.0).abs() < 1e-6); @@ -569,13 +571,12 @@ mod tests { #[test] fn test_backpropagate_result_tracks_wins_losses() { - let state = GameState::new(5, 3); - let mut arena = NodeArena::new(state); + let (_, data) = make_mech_state(); + let mut arena = NodeArena::new(data); - let child = arena.alloc_child(0, [1, 2, 2], 0.5); + let child = arena.alloc_child(0, 0, data, 0.5); arena.get_mut(0).children = vec![child]; - // Backpropagate a win backpropagate_result(&mut arena, child, 1.0); assert_eq!(arena.get(child).wins, 1); @@ -586,44 +587,60 @@ mod tests { #[test] fn test_visited_state_penalty() { - let state = GameState::new(5, 0); // No walls for simpler moves - let mut arena = NodeArena::new(state.clone()); - - // Create two children - let child1 = arena.alloc_child(0, [1, 2, 2], 0.5); - let child2 = arena.alloc_child(0, [0, 1, 2], 0.5); + let mech = QGameMechanics::new(5, 0, 200); + let data = mech.create_initial_state(); + let mut arena = NodeArena::new(data); + // Two valid move children + let mask = mech.get_action_mask_immut(data); + let valid: Vec = mask + .iter() + .enumerate() + .filter_map(|(i, &v)| if v { Some(i) } else { None }) + .collect(); + let mut d1 = data; + mech.apply_action_index(&mut d1, valid[0]); + let mut d2 = data; + mech.apply_action_index(&mut d2, valid[1]); + let child1 = arena.alloc_child(0, valid[0], d1, 0.5); + let child2 = arena.alloc_child(0, valid[1], d2, 0.5); arena.get_mut(0).children = vec![child1, child2]; arena.get_mut(0).visit_count = 10; - // Equal visits arena.get_mut(child1).visit_count = 1; arena.get_mut(child2).visit_count = 1; - // Mark child1's state as visited - let child1_game = state.clone_and_step([1, 2, 2]); + // Mark child1's data as visited let mut visited = HashSet::new(); - visited.insert(child1_game.get_fast_hash()); + visited.insert(d1); - // Child2 should be selected since child1 has penalty - let selected = select_child(&arena, 0, 1.4, &visited, 5); + let selected = select_child(&arena, 0, 1.4, &visited); assert_eq!(selected, child2); } #[test] fn test_no_visited_state_penalty_when_set_is_empty() { - // When penalize_visited_states is false, the agent passes an empty set, - // so no penalty should be applied even if a state would match. - let state = GameState::new(5, 0); - let mut arena = NodeArena::new(state.clone()); + let mech = QGameMechanics::new(5, 0, 200); + let data = mech.create_initial_state(); + let mut arena = NodeArena::new(data); - let child1 = arena.alloc_child(0, [1, 2, 2], 0.5); - let child2 = arena.alloc_child(0, [0, 1, 2], 0.5); + let mask = mech.get_action_mask_immut(data); + let valid: Vec = mask + .iter() + .enumerate() + .filter_map(|(i, &v)| if v { Some(i) } else { None }) + .collect(); + let mut d1 = data; + mech.apply_action_index(&mut d1, valid[0]); + let mut d2 = data; + mech.apply_action_index(&mut d2, valid[1]); + let child1 = arena.alloc_child(0, valid[0], d1, 0.5); + let child2 = arena.alloc_child(0, valid[1], d2, 0.5); arena.get_mut(0).children = vec![child1, child2]; arena.get_mut(0).visit_count = 20; - // Give child1 higher Q-value and similar visits so it clearly wins + // Child1 higher Q with same visits arena.get_mut(child1).visit_count = 8; arena.get_mut(child1).value_sum = 6.0; // Q = 0.75 arena.get_mut(child2).visit_count = 8; @@ -631,24 +648,20 @@ mod tests { // Empty visited set (simulates penalize_visited_states=false) let visited = HashSet::new(); - - // Child1 should be selected (no penalty applied, higher Q) - let selected = select_child(&arena, 0, 1.4, &visited, 5); + let selected = select_child(&arena, 0, 1.4, &visited); assert_eq!(selected, child1); - // Now with visited_states containing child1's hash, - // child2 should be selected instead (penalty of -1.0 brings Q from 0.75 to -0.25) - let child1_game = state.clone_and_step([1, 2, 2]); + // Now with child1's data in the visited set, child2 wins let mut visited_with_penalty = HashSet::new(); - visited_with_penalty.insert(child1_game.get_fast_hash()); - - let selected = select_child(&arena, 0, 1.4, &visited_with_penalty, 5); + visited_with_penalty.insert(d1); + let selected = select_child(&arena, 0, 1.4, &visited_with_penalty); assert_eq!(selected, child2); } #[test] fn test_mcts_search_basic() { - let state = GameState::new(5, 0); // No walls for faster search + let mech = QGameMechanics::new(5, 0, 200); + let data = mech.create_initial_state(); let mut evaluator = MockEvaluator::new(0.0); let visited = HashSet::new(); @@ -659,22 +672,21 @@ mod tests { ..Default::default() }; - let result = search(&config, state, &mut evaluator, &visited); + let result = search(&config, data, &mech, &mut evaluator, &visited); assert!(result.is_ok()); let (children, _value) = result.unwrap(); - // Should have some children assert!(!children.is_empty()); - // Visit counts should be > 0 let total_visits: u32 = children.iter().map(|c| c.visit_count).sum(); assert!(total_visits > 0); } #[test] fn test_mcts_n_zero_uses_priors() { - let state = GameState::new(5, 3); + let mech = QGameMechanics::new(5, 3, 200); + let data = mech.create_initial_state(); let mut evaluator = MockEvaluator::new(0.0); let visited = HashSet::new(); @@ -685,12 +697,11 @@ mod tests { ..Default::default() }; - let result = search(&config, state, &mut evaluator, &visited); + let result = search(&config, data, &mech, &mut evaluator, &visited); assert!(result.is_ok()); let (children, _) = result.unwrap(); - // Visit counts should be proportional to priors (×1000) for child in &children { assert!(child.visit_count > 0); } @@ -703,20 +714,36 @@ mod tests { apply_dirichlet_noise(&mut priors, 0.25, 0.5); - // Priors should have changed let changed = priors .iter() .zip(original.iter()) .any(|(p, o)| (p - o).abs() > 1e-6); assert!(changed); - // Only non-zero priors should be affected assert!(priors[0] < 1e-6); assert!(priors[2] < 1e-6); assert!(priors[5] < 1e-6); - // Non-zero priors should sum to ~1 let sum: f32 = priors.iter().sum(); assert!((sum - 1.0).abs() < 0.1); } + + #[test] + fn test_child_action_index_roundtrips_to_action() { + let bs = 5; + let mech = QGameMechanics::new(5, 3, 200); + let data = mech.create_initial_state(); + let mut evaluator = MockEvaluator::new(0.0); + let visited = HashSet::new(); + let config = MCTSConfig { + n: Some(2), + ucb_c: 1.4, + noise_epsilon: 0.0, + ..Default::default() + }; + let (children, _) = search(&config, data, &mech, &mut evaluator, &visited).unwrap(); + for c in &children { + assert_eq!(action_to_index(bs, &c.action), c.action_index); + } + } } diff --git a/deep_quoridor/rust/src/agents/mod.rs b/deep_quoridor/rust/src/agents/mod.rs index 5033876f..a6797a05 100644 --- a/deep_quoridor/rust/src/agents/mod.rs +++ b/deep_quoridor/rust/src/agents/mod.rs @@ -2,7 +2,7 @@ //! //! All agents implement the [`ActionSelector`] trait. -use crate::game_state::GameState; +use crate::compact::q_game_mechanics::QGameMechanics; #[cfg(feature = "binary")] pub mod onnx_agent; @@ -16,20 +16,20 @@ pub struct ActionSelectionTrace { pub root_value: Option, } -/// Trait for agents that select actions given a game state. +/// Trait for agents that select actions given a compact game state. /// -/// The provided state may already be rotated for Player 1. +/// The provided state may already be rotated for Player 1 — agents should treat +/// `data` as the canonical state in whatever frame is supplied. pub trait ActionSelector { - /// Select an action given the game state. - /// - /// Arguments are in the coordinate frame presented (possibly rotated). + /// Select an action given the compact game state. /// /// Returns `(action_index, policy_probabilities)` where `action_index` is /// a flat index into the policy vector and `policy_probabilities` is the /// full softmax output (or a uniform/mask-based distribution for simpler agents). fn select_action( &mut self, - state: &GameState, + data: u64, + mechanics: &QGameMechanics, action_mask: &[bool], ) -> anyhow::Result<(usize, Vec)>; diff --git a/deep_quoridor/rust/src/agents/onnx_agent.rs b/deep_quoridor/rust/src/agents/onnx_agent.rs index f976ffb0..117eae1c 100644 --- a/deep_quoridor/rust/src/agents/onnx_agent.rs +++ b/deep_quoridor/rust/src/agents/onnx_agent.rs @@ -6,8 +6,8 @@ use anyhow::{Context, Result}; use ort::session::Session; use crate::agents::ActionSelector; -use crate::game_state::GameState; -use crate::grid_helpers::grid_game_state_to_resnet_input; +use crate::compact::q_game_mechanics::QGameMechanics; +use crate::grid_helpers::compact_state_to_resnet_input; /// Compute softmax of a slice of logits. pub fn softmax(logits: &[f32]) -> Vec { @@ -36,11 +36,12 @@ impl OnnxAgent { impl ActionSelector for OnnxAgent { fn select_action( &mut self, - state: &GameState, + data: u64, + mechanics: &QGameMechanics, action_mask: &[bool], ) -> Result<(usize, Vec)> { // Build ResNet input tensor - let resnet_input = grid_game_state_to_resnet_input(state); + let resnet_input = compact_state_to_resnet_input(mechanics, data); // Convert to flat vec for ORT let shape = resnet_input.shape().to_vec(); diff --git a/deep_quoridor/rust/src/agents/random_agent.rs b/deep_quoridor/rust/src/agents/random_agent.rs index 28bad030..46a08b1b 100644 --- a/deep_quoridor/rust/src/agents/random_agent.rs +++ b/deep_quoridor/rust/src/agents/random_agent.rs @@ -3,7 +3,7 @@ use rand::Rng; use crate::agents::ActionSelector; -use crate::game_state::GameState; +use crate::compact::q_game_mechanics::QGameMechanics; /// An agent that selects a random valid action. pub struct RandomAgent { @@ -27,7 +27,8 @@ impl Default for RandomAgent { impl ActionSelector for RandomAgent { fn select_action( &mut self, - _state: &GameState, + _data: u64, + _mechanics: &QGameMechanics, action_mask: &[bool], ) -> anyhow::Result<(usize, Vec)> { // Collect valid action indices @@ -57,14 +58,20 @@ impl ActionSelector for RandomAgent { mod tests { use super::*; + fn make_mech() -> (QGameMechanics, u64) { + let mech = QGameMechanics::new(5, 3, 200); + let data = mech.create_initial_state(); + (mech, data) + } + #[test] fn test_random_agent_picks_valid_action() { let mut agent = RandomAgent::new(); - let state = GameState::new(5, 3); + let (mech, data) = make_mech(); let mask = vec![false, false, true, false, true, true]; for _ in 0..50 { - let (idx, _) = agent.select_action(&state, &mask).unwrap(); + let (idx, _) = agent.select_action(data, &mech, &mask).unwrap(); assert!( mask[idx], "RandomAgent picked an invalid action index {}", @@ -76,10 +83,10 @@ mod tests { #[test] fn test_random_agent_policy_sums_to_one() { let mut agent = RandomAgent::new(); - let state = GameState::new(5, 3); + let (mech, data) = make_mech(); let mask = vec![false, true, true, false, true]; - let (_, policy) = agent.select_action(&state, &mask).unwrap(); + let (_, policy) = agent.select_action(data, &mech, &mask).unwrap(); let sum: f32 = policy.iter().sum(); assert!((sum - 1.0).abs() < 1e-6); @@ -91,10 +98,10 @@ mod tests { #[test] fn test_random_agent_no_valid_actions_fails() { let mut agent = RandomAgent::new(); - let state = GameState::new(5, 3); + let (mech, data) = make_mech(); let mask = vec![false, false, false]; - let result = agent.select_action(&state, &mask); + let result = agent.select_action(data, &mech, &mask); assert!(result.is_err()); } } diff --git a/deep_quoridor/rust/src/compact/q_game_mechanics.rs b/deep_quoridor/rust/src/compact/q_game_mechanics.rs index 2cc1279d..8a4be1a3 100644 --- a/deep_quoridor/rust/src/compact/q_game_mechanics.rs +++ b/deep_quoridor/rust/src/compact/q_game_mechanics.rs @@ -588,6 +588,76 @@ impl QGameMechanics { pub fn display(&self, data: u64) -> String { self.repr.display(data) } + + /// Compute the full action mask for the current player. + /// + /// `data` is modified during the call (wall validation places-then-removes + /// walls in place) but restored before return; observers see no change. + pub fn get_action_mask(&self, data: &mut u64) -> Vec { + let bs = self.repr.board_size(); + let board_size_i = bs as i32; + let total = crate::actions::policy_size(board_size_i); + let mut mask = vec![false; total]; + + let num_moves = bs * bs; + for (r, c) in self.get_valid_moves(*data) { + mask[r * bs + c] = true; + } + + let ws = bs - 1; + let num_walls = ws * ws; + for (r, c, orientation) in self.get_valid_wall_placements(data) { + let wall_off = num_moves + orientation * num_walls + r * ws + c; + mask[wall_off] = true; + } + mask + } + + /// Immutable wrapper around `get_action_mask` for callers that hold `u64` by value. + pub fn get_action_mask_immut(&self, data: u64) -> Vec { + let mut d = data; + self.get_action_mask(&mut d) + } + + /// Apply a flat policy action index to `data`, mirroring `GameState::step`. + /// + /// Decodes the action, executes the appropriate move/wall placement, then + /// switches to the next player (which also increments `completed_steps`). + pub fn apply_action_index(&self, data: &mut u64, action_idx: usize) { + let bs = self.repr.board_size() as i32; + let action = crate::actions::action_index_to_action(bs, action_idx); + let (r, c, t) = (action[0] as usize, action[1] as usize, action[2]); + let player = self.repr.get_current_player(*data); + match t { + crate::actions::ACTION_MOVE => { + self.execute_move(data, player, r, c); + } + crate::actions::ACTION_WALL_VERTICAL => { + self.execute_wall_placement(data, player, r, c, WALL_VERTICAL); + } + crate::actions::ACTION_WALL_HORIZONTAL => { + self.execute_wall_placement(data, player, r, c, WALL_HORIZONTAL); + } + _ => panic!("Invalid action type: {}", t), + } + self.switch_player(data); + } + + /// Returns true if either player has reached their goal row. + pub fn is_game_over(&self, data: u64) -> bool { + self.check_win(data, 0) || self.check_win(data, 1) + } + + /// Returns `Some(player)` if a player has won, `None` otherwise. + pub fn winner(&self, data: u64) -> Option { + if self.check_win(data, 0) { + Some(0) + } else if self.check_win(data, 1) { + Some(1) + } else { + None + } + } } #[cfg(test)] diff --git a/deep_quoridor/rust/src/game_runner.rs b/deep_quoridor/rust/src/game_runner.rs index 0b931eee..b737c363 100644 --- a/deep_quoridor/rust/src/game_runner.rs +++ b/deep_quoridor/rust/src/game_runner.rs @@ -11,10 +11,10 @@ use crate::actions::{ action_index_to_action, ACTION_MOVE, ACTION_WALL_HORIZONTAL, ACTION_WALL_VERTICAL, }; use crate::agents::ActionSelector; +use crate::compact::q_game_mechanics::QGameMechanics; use crate::game_state::GameState; -use crate::grid::CELL_WALL; -use crate::grid_helpers::grid_game_state_to_resnet_input; -use crate::rotation::{build_rotated_state, create_rotation_mapping, remap_policy}; +use crate::grid_helpers::compact_state_to_resnet_input; +use crate::rotation::{create_rotation_mapping, remap_policy, rotate_compact_state}; pub trait PlayGameObserver { fn on_state_snapshot(&mut self, step: usize, state: &GameState, action_mask: &[bool]); @@ -38,86 +38,27 @@ fn format_action(_board_size: i32, row: i32, col: i32, action_type: i32) -> Stri } } -/// Render the board state as a human-readable string. +/// Build a transient `GameState` from a compact `u64` state. /// -/// Shows player positions as `1` and `2`, walls as `|` (vertical) and `-` -/// (horizontal), and empty cells as `.`. -/// -/// The board is always shown in the original (un-rotated) orientation. -pub fn display_board( - grid: &ndarray::ArrayView2, - player_positions: &ndarray::ArrayView2, - walls_remaining: &ndarray::ArrayView1, +/// Used only to feed `PlayGameObserver::on_state_snapshot`, which keeps its +/// `&GameState` signature so the cross-language trace observer in +/// `python_consistency.rs` stays untouched. Allocates per call (per step); only +/// called when an observer is attached. Production self-play does not attach an +/// observer. +fn compact_to_game_state( + mechanics: &QGameMechanics, + data: u64, board_size: i32, -) -> String { - let mut out = String::new(); - let bs = board_size as usize; - - // Column header - out.push_str(" "); - for c in 0..bs { - out.push_str(&format!(" {} ", c)); - } - out.push('\n'); - - let p0_row = player_positions[[0, 0]] as usize; - let p0_col = player_positions[[0, 1]] as usize; - let p1_row = player_positions[[1, 0]] as usize; - let p1_col = player_positions[[1, 1]] as usize; - - for row in 0..bs { - // --- cell row --- - out.push_str(&format!("{:>3} ", row)); - for col in 0..bs { - // cell content - if row == p0_row && col == p0_col { - out.push('1'); - } else if row == p1_row && col == p1_col { - out.push('2'); - } else { - out.push('.'); - } - - // vertical wall to the right - if col < bs - 1 { - // Grid coord of the gap between (row,col) and (row,col+1) - let gr = (row * 2 + 2) as usize; - let gc = (col * 2 + 3) as usize; - if grid[[gr, gc]] == CELL_WALL { - out.push_str(" | "); - } else { - out.push_str(" "); - } - } - } - // Metadata on the right of first two rows - match row { - 0 => out.push_str(&format!(" P1 walls: {}", walls_remaining[0])), - 1 => out.push_str(&format!(" P2 walls: {}", walls_remaining[1])), - _ => {} - } - out.push('\n'); - - // --- horizontal wall row between this row and the next --- - if row < bs - 1 { - out.push_str(" "); - for col in 0..bs { - // Grid coord of the gap between (row,col) and (row+1,col) - let gr = (row * 2 + 3) as usize; - let gc = (col * 2 + 2) as usize; - if grid[[gr, gc]] == CELL_WALL { - out.push('-'); - } else { - out.push(' '); - } - if col < bs - 1 { - out.push_str(" "); - } - } - out.push('\n'); - } - } - out + max_walls: i32, +) -> GameState { + let repr = mechanics.repr(); + let mut state = GameState::new(board_size, max_walls); + state.grid = repr.to_grid(data); + state.player_positions = repr.to_player_positions(data); + state.walls_remaining = repr.to_walls_remaining(data); + state.current_player = repr.get_current_player(data) as i32; + state.completed_steps = repr.get_completed_steps(data); + state } /// One turn's training data, stored in "current-player-faces-downward" coords. @@ -144,15 +85,14 @@ pub struct GameResult { pub replay_items: Vec, } -/// Play a complete game between two agents. +/// Play a complete game between two agents using the compact `u64` state. /// /// `agent_p1` controls player 0 and `agent_p2` controls player 1. /// Player 0 moves first. Action selection runs in original orientation; any /// current-player-downward rotation is handled inside evaluator codepaths. /// /// When `trace` is `true`, each step prints whose turn it is, the action -/// chosen, and the resulting board state in the original (un-rotated) -/// orientation. +/// chosen, and the resulting board state via `QGameMechanics::display`. pub fn play_game( agent_p1: &mut dyn ActionSelector, agent_p2: &mut dyn ActionSelector, @@ -162,7 +102,12 @@ pub fn play_game( trace: bool, mut observer: Option<&mut dyn PlayGameObserver>, ) -> anyhow::Result { - let mut state = GameState::new(board_size, max_walls); + let mechanics = QGameMechanics::new( + board_size as usize, + max_walls as usize, + max_steps as usize, + ); + let mut data = mechanics.create_initial_state(); let (original_to_rotated, _) = create_rotation_mapping(board_size); let mut replay_items: Vec = Vec::new(); @@ -171,25 +116,22 @@ pub fn play_game( let mut emitted_terminal_snapshot = false; for step in 0..max_steps { - let current_player = state.current_player; - - // Match Python: run action selection in original orientation. - let work_state = state.clone(); - let mask = work_state.get_action_mask(); + let current_player = mechanics.repr().get_current_player(data) as i32; + let mask = mechanics.get_action_mask_immut(data); if let Some(obs) = observer.as_deref_mut() { + let state = compact_to_game_state(&mechanics, data, board_size, max_walls); obs.on_state_snapshot(step as usize, &state, &mask); } // Check for no valid actions (shouldn't happen in Quoridor, but be safe) if !mask.iter().any(|&m| m) { - // Truncate emitted_terminal_snapshot = true; break; } - // Build ResNet input from the working state - let resnet_input = grid_game_state_to_resnet_input(&work_state); + // Build ResNet input from the working state (for current-player frame storage). + let resnet_input = compact_state_to_resnet_input(&mechanics, data); // Ask the appropriate agent for action let agent: &mut dyn ActionSelector = if current_player == 0 { @@ -197,7 +139,7 @@ pub fn play_game( } else { agent_p2 }; - let (action_idx, policy) = agent.select_action(&work_state, &mask)?; + let (action_idx, policy) = agent.select_action(data, &mechanics, &mask)?; let root_value = agent .last_selection_trace() .and_then(|trace| trace.root_value); @@ -208,12 +150,12 @@ pub fn play_game( // Match Python storage semantics: replay is stored in current-player-downward frame. let (stored_input_3d, stored_policy, stored_mask) = if current_player == 1 { - let rotated_state = build_rotated_state(&state); - let rotated_input = grid_game_state_to_resnet_input(&rotated_state) + let rotated_data = rotate_compact_state(&mechanics, data); + let rotated_input = compact_state_to_resnet_input(&mechanics, rotated_data) .index_axis(ndarray::Axis(0), 0) .to_owned(); let rotated_policy = remap_policy(&policy, &original_to_rotated); - let rotated_mask = rotated_state.get_action_mask(); + let rotated_mask = mechanics.get_action_mask_immut(rotated_data); (rotated_input, rotated_policy, rotated_mask) } else { ( @@ -231,35 +173,28 @@ pub fn play_game( player: current_player, }); - // Decode action index → (row, col, type) in working frame - let action_triple = action_index_to_action(board_size, action_idx); - - let (a_row, a_col, a_type) = (action_triple[0], action_triple[1], action_triple[2]); - - // Apply action on the ORIGINAL game state - state.step([a_row, a_col, a_type]); + // Apply action on the canonical u64 state + mechanics.apply_action_index(&mut data, action_idx); if trace { let player_label = if current_player == 0 { "P1" } else { "P2" }; + let action_triple = action_index_to_action(board_size, action_idx); println!( "--- Step {} | {} ---\n{}", step + 1, player_label, - format_action(board_size, a_row, a_col, a_type), - ); - print!( - "{}\n", - display_board( - &state.grid(), - &state.player_positions(), - &state.walls_remaining(), - board_size + format_action( + board_size, + action_triple[0], + action_triple[1], + action_triple[2] ), ); + print!("{}\n", mechanics.display(data)); } - // Check win (current_player already switched after step, so check previous player) - if state.check_win(current_player) { + // Check win (current_player already switched after apply, so check previous player) + if mechanics.check_win(data, current_player as usize) { winner = Some(current_player); // Backfill values: +1 for winner, -1 for loser for item in replay_items.iter_mut() { @@ -277,11 +212,12 @@ pub fn play_game( } } - if winner.is_none() && !emitted_terminal_snapshot && state.completed_steps >= max_steps as usize - { - let mask = state.get_action_mask(); + let completed_steps = mechanics.repr().get_completed_steps(data); + if winner.is_none() && !emitted_terminal_snapshot && completed_steps >= max_steps as usize { + let mask = mechanics.get_action_mask_immut(data); if let Some(obs) = observer.as_deref_mut() { - obs.on_state_snapshot(state.completed_steps, &state, &mask); + let state = compact_to_game_state(&mechanics, data, board_size, max_walls); + obs.on_state_snapshot(completed_steps, &state, &mask); } } @@ -304,7 +240,8 @@ mod tests { impl ActionSelector for FirstValidAgent { fn select_action( &mut self, - _state: &GameState, + _data: u64, + _mechanics: &QGameMechanics, action_mask: &[bool], ) -> anyhow::Result<(usize, Vec)> { let idx = action_mask @@ -323,7 +260,6 @@ mod tests { let mut p2 = FirstValidAgent; let result = play_game(&mut p1, &mut p2, 5, 3, 200, false, None).unwrap(); - // Game should complete within 200 steps on a 5×5 board assert!(result.num_turns > 0); assert!(!result.replay_items.is_empty()); } @@ -334,8 +270,6 @@ mod tests { let mut p2 = FirstValidAgent; let result = play_game(&mut p1, &mut p2, 5, 0, 200, false, None).unwrap(); - // With 0 walls the game should end quickly via moves only - // Players should alternate for (i, item) in result.replay_items.iter().enumerate() { assert_eq!(item.player, (i as i32) % 2); } @@ -362,7 +296,6 @@ mod tests { fn test_play_game_truncation_values() { let mut p1 = FirstValidAgent; let mut p2 = FirstValidAgent; - // Very short max_steps to force truncation let result = play_game(&mut p1, &mut p2, 5, 3, 2, false, None).unwrap(); if result.winner.is_none() { diff --git a/deep_quoridor/rust/src/grid_helpers.rs b/deep_quoridor/rust/src/grid_helpers.rs index 91dce6ac..b240e92b 100644 --- a/deep_quoridor/rust/src/grid_helpers.rs +++ b/deep_quoridor/rust/src/grid_helpers.rs @@ -1,5 +1,7 @@ #![allow(dead_code)] +use crate::compact::q_bit_repr::{WALL_HORIZONTAL, WALL_VERTICAL}; +use crate::compact::q_game_mechanics::QGameMechanics; use crate::game_state::GameState; use crate::grid::CELL_WALL; @@ -63,13 +65,118 @@ pub fn grid_game_state_to_resnet_input(state: &GameState) -> ndarray::Array4 ndarray::Array4 { + let repr = mechanics.repr(); + let bs = repr.board_size(); + let grid_size = bs * 2 + 3; + let current_player = repr.get_current_player(data); + let opponent = 1 - current_player; + + let mut input = ndarray::Array4::::zeros((1, 5, grid_size, grid_size)); + + // Channel 0: border walls (rows/cols 0,1 and last two) + for i in 0..grid_size { + for j in 0..2 { + input[[0, 0, j, i]] = 1.0; + input[[0, 0, grid_size - 1 - j, i]] = 1.0; + input[[0, 0, i, j]] = 1.0; + input[[0, 0, i, grid_size - 1 - j]] = 1.0; + } + } + // Channel 0: placed walls. Mirrors set_wall_cells layout: + // Vertical at (r,c) → grid cells (r*2+2, c*2+3), (r*2+3, c*2+3), (r*2+4, c*2+3) + // Horizontal at (r,c) → grid cells (r*2+3, c*2+2), (r*2+3, c*2+3), (r*2+3, c*2+4) + for r in 0..bs - 1 { + for c in 0..bs - 1 { + if repr.get_wall(data, r, c, WALL_VERTICAL) { + let gc = c * 2 + 3; + for dr in 0..3 { + input[[0, 0, r * 2 + 2 + dr, gc]] = 1.0; + } + } + if repr.get_wall(data, r, c, WALL_HORIZONTAL) { + let gr = r * 2 + 3; + for dc in 0..3 { + input[[0, 0, gr, c * 2 + 2 + dc]] = 1.0; + } + } + } + } + + // Channels 1/2: current/opponent position as 1-hot at (r*2+2, c*2+2). + let (cr, cc) = repr.get_player_position(data, current_player); + input[[0, 1, cr * 2 + 2, cc * 2 + 2]] = 1.0; + let (or_, oc) = repr.get_player_position(data, opponent); + input[[0, 2, or_ * 2 + 2, oc * 2 + 2]] = 1.0; + + // Channels 3/4: walls remaining broadcast. + let my_w = repr.get_walls_remaining(data, current_player) as f32; + let opp_w = repr.get_walls_remaining(data, opponent) as f32; + input.slice_mut(ndarray::s![0, 3, .., ..]).fill(my_w); + input.slice_mut(ndarray::s![0, 4, .., ..]).fill(opp_w); + + input +} + #[cfg(test)] mod tests { use super::*; + use crate::actions::ACTION_MOVE; use crate::game_state::{create_initial_state, GameState}; use crate::grid::{set_wall_cells, CELL_FREE, CELL_WALL}; use ndarray::Array1; + #[test] + fn test_compact_resnet_matches_gamestate_resnet_initial() { + // QBitRepr packs state into a u64; only smaller boards fit. + for bs in [3, 5] { + let state = GameState::new(bs, 3); + let mech = QGameMechanics::new(bs as usize, 3, 200); + let data = mech.create_initial_state(); + let a = grid_game_state_to_resnet_input(&state); + let b = compact_state_to_resnet_input(&mech, data); + assert_eq!(a, b, "initial state mismatch bs={bs}"); + } + } + + #[test] + fn test_compact_resnet_matches_after_stepping() { + // Run a small deterministic action sequence on both representations. + // No walls → only move actions, which are trivially valid. + let bs = 5; + let mut state = GameState::new(bs, 0); + let mech = QGameMechanics::new(bs as usize, 0, 200); + let mut data = mech.create_initial_state(); + + let actions: Vec<[i32; 3]> = vec![ + [1, 2, ACTION_MOVE], + [3, 2, ACTION_MOVE], + [1, 1, ACTION_MOVE], + [3, 1, ACTION_MOVE], + ]; + + let a0 = grid_game_state_to_resnet_input(&state); + let b0 = compact_state_to_resnet_input(&mech, data); + assert_eq!(a0, b0, "step 0"); + + for (i, action) in actions.iter().enumerate() { + state.step(*action); + let action_idx = crate::actions::action_to_index(bs, action); + mech.apply_action_index(&mut data, action_idx); + let a = grid_game_state_to_resnet_input(&state); + let b = compact_state_to_resnet_input(&mech, data); + assert_eq!(a, b, "step {}", i + 1); + } + } + #[test] fn test_resnet_input_shape() { let state = GameState::new(5, 3); diff --git a/deep_quoridor/rust/src/python_consistency.rs b/deep_quoridor/rust/src/python_consistency.rs index 9b5ef429..87d5b7eb 100644 --- a/deep_quoridor/rust/src/python_consistency.rs +++ b/deep_quoridor/rust/src/python_consistency.rs @@ -515,9 +515,24 @@ fn generate_rust_mcts_trace( // Run MCTS on the original (unrotated) state, matching Python's behaviour where // the evaluator handles rotation internally and MCTS always operates in the - // original action-index space for both players. + // original action-index space for both players. Convert the GameState to the + // compact (u64, mechanics) form that MCTS now uses. + let mechanics = crate::compact::q_game_mechanics::QGameMechanics::new( + board_size as usize, + max_walls as usize, + max_steps as usize, + ); + let mut data = mechanics.repr().create_data(); + mechanics.repr().from_game_state( + &mut data, + &state.grid(), + &state.player_positions(), + &state.walls_remaining(), + state.current_player, + state.completed_steps as i32, + ); let (children, root_value): (Vec, f32) = - search(&config, state.clone(), &mut evaluator, &visited_states) + search(&config, data, &mechanics, &mut evaluator, &visited_states) .expect("MCTS search should succeed"); let visit_counts: Vec = children.iter().map(|c| c.visit_count).collect(); diff --git a/deep_quoridor/rust/src/rotation.rs b/deep_quoridor/rust/src/rotation.rs index 498ee908..62e137fc 100644 --- a/deep_quoridor/rust/src/rotation.rs +++ b/deep_quoridor/rust/src/rotation.rs @@ -10,6 +10,8 @@ use ndarray::{Array1, Array2, ArrayView1, ArrayView2}; use crate::actions::{action_index_to_action, action_to_index, policy_size, ACTION_MOVE}; +use crate::compact::q_bit_repr::{WALL_HORIZONTAL, WALL_VERTICAL}; +use crate::compact::q_game_mechanics::QGameMechanics; use crate::game_state::GameState; /// Rotate a 2D grid 180° — equivalent to `np.rot90(grid, k=2)`. @@ -115,6 +117,40 @@ pub fn remap_mask(mask: &[bool], mapping: &[usize]) -> Vec { out } +/// Build a 180°-rotated u64 game state. +/// +/// Mirrors `build_rotated_state` semantics for the compact representation: +/// - Player positions flipped: `(r,c) -> (bs-1-r, bs-1-c)` for both players. +/// - All wall positions flipped: `(r,c,o) -> (ws-1-r, ws-1-c, o)` where `ws = bs-1`. +/// Orientation is preserved under 180° rotation. +/// - Walls remaining, current player, completed steps: unchanged. +pub fn rotate_compact_state(mechanics: &QGameMechanics, data: u64) -> u64 { + let repr = mechanics.repr(); + let bs = repr.board_size(); + let ws = bs - 1; + let mut out = repr.create_data(); + + for p in 0..2 { + let (r, c) = repr.get_player_position(data, p); + repr.set_player_position(&mut out, p, bs - 1 - r, bs - 1 - c); + } + for r in 0..ws { + for c in 0..ws { + for o in [WALL_VERTICAL, WALL_HORIZONTAL] { + if repr.get_wall(data, r, c, o) { + repr.set_wall(&mut out, ws - 1 - r, ws - 1 - c, o, true); + } + } + } + } + repr.set_walls_remaining(&mut out, 0, repr.get_walls_remaining(data, 0)); + repr.set_walls_remaining(&mut out, 1, repr.get_walls_remaining(data, 1)); + repr.set_current_player(&mut out, repr.get_current_player(data)); + repr.set_completed_steps(&mut out, repr.get_completed_steps(data)); + + out +} + #[cfg(test)] mod tests { use super::*; @@ -218,6 +254,24 @@ mod tests { } } + #[test] + fn test_rotate_compact_state_matches_build_rotated_state_initial() { + // QBitRepr packs state into a u64; only smaller boards fit. + for bs in [3, 5] { + let state = GameState::new(bs, 3); + let rotated_state = build_rotated_state(&state); + + let mech = QGameMechanics::new(bs as usize, 3, 200); + let data = mech.create_initial_state(); + let rotated_data = rotate_compact_state(&mech, data); + + // Compare via ResNet input tensors. + let a = crate::grid_helpers::grid_game_state_to_resnet_input(&rotated_state); + let b = crate::grid_helpers::compact_state_to_resnet_input(&mech, rotated_data); + assert_eq!(a, b, "rotation mismatch on initial state for bs={bs}"); + } + } + #[test] fn test_remap_policy_roundtrip() { let bs = 5; From 536ce0ae56486185a7b818be343de046bf29eef4 Mon Sep 17 00:00:00 2001 From: Jon Binney Date: Sun, 17 May 2026 14:21:48 -0400 Subject: [PATCH 4/6] Fix rotated action mask to remap from original instead of re-validating QGameMechanics owns goal_rows and does not flip them under rotation, so get_action_mask_immut on rotated data treats walls that block the rotated player's path as legal. Use remap_mask on the original mask instead, and add a test that pins this contract. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../rust/src/agents/alphazero/evaluator.rs | 33 ++++++++++--------- deep_quoridor/rust/src/game_runner.rs | 8 +++-- deep_quoridor/rust/src/rotation.rs | 33 +++++++++++++++++++ 3 files changed, 57 insertions(+), 17 deletions(-) diff --git a/deep_quoridor/rust/src/agents/alphazero/evaluator.rs b/deep_quoridor/rust/src/agents/alphazero/evaluator.rs index eab7e788..5c3c08a7 100644 --- a/deep_quoridor/rust/src/agents/alphazero/evaluator.rs +++ b/deep_quoridor/rust/src/agents/alphazero/evaluator.rs @@ -8,7 +8,7 @@ use ort::session::Session; use crate::agents::onnx_agent::softmax; use crate::compact::q_game_mechanics::QGameMechanics; use crate::grid_helpers::compact_state_to_resnet_input; -use crate::rotation::{create_rotation_mapping, remap_policy, rotate_compact_state}; +use crate::rotation::{create_rotation_mapping, remap_mask, remap_policy, rotate_compact_state}; /// Trait for evaluating game positions. /// @@ -28,7 +28,7 @@ pub trait Evaluator { /// returning both a value estimate and policy priors. pub struct OnnxEvaluator { session: Session, - rotated_to_original_by_board_size: HashMap>, + rotation_mappings_by_board_size: HashMap, Vec)>, } /// Deterministic evaluator for cross-language consistency tests. @@ -45,7 +45,7 @@ impl OnnxEvaluator { .context("Failed to load ONNX model")?; Ok(Self { session, - rotated_to_original_by_board_size: HashMap::new(), + rotation_mappings_by_board_size: HashMap::new(), }) } } @@ -60,19 +60,22 @@ impl Evaluator for OnnxEvaluator { let bs = mechanics.repr().board_size() as i32; let current_player = mechanics.repr().get_current_player(data); - let rotated_to_original = self - .rotated_to_original_by_board_size + let mappings = self + .rotation_mappings_by_board_size .entry(bs) - .or_insert_with(|| create_rotation_mapping(bs).1); - - let (work_data, work_action_mask, rotated_to_original) = if current_player == 1 { + .or_insert_with(|| create_rotation_mapping(bs)); + let (orig_to_rot, rot_to_orig) = (&mappings.0, &mappings.1); + + // For player 1, the network always sees the board rotated 180° so the + // current player faces downward. The compact rotation matches the tensor + // of `build_rotated_state`, but `QGameMechanics::goal_rows` is owned by + // the mechanics and does not flip — so wall-mask validation on rotated + // data treats walls that block the rotated player's path as legal. Remap + // the (correct) original mask into rotated index space instead. + let (work_data, work_action_mask, rot_to_orig_slice) = if current_player == 1 { let rotated_data = rotate_compact_state(mechanics, data); - let rotated_mask = mechanics.get_action_mask_immut(rotated_data); - ( - rotated_data, - rotated_mask, - Some(rotated_to_original.as_slice()), - ) + let rotated_mask = remap_mask(action_mask, orig_to_rot); + (rotated_data, rotated_mask, Some(rot_to_orig.as_slice())) } else { (data, action_mask.to_vec(), None) }; @@ -105,7 +108,7 @@ impl Evaluator for OnnxEvaluator { // Apply masked softmax to get priors let priors_work = masked_softmax(policy_logits.1, &work_action_mask); - let priors = if let Some(rot_to_orig) = rotated_to_original { + let priors = if let Some(rot_to_orig) = rot_to_orig_slice { remap_policy(&priors_work, rot_to_orig) } else { priors_work diff --git a/deep_quoridor/rust/src/game_runner.rs b/deep_quoridor/rust/src/game_runner.rs index b737c363..ab073709 100644 --- a/deep_quoridor/rust/src/game_runner.rs +++ b/deep_quoridor/rust/src/game_runner.rs @@ -14,7 +14,7 @@ use crate::agents::ActionSelector; use crate::compact::q_game_mechanics::QGameMechanics; use crate::game_state::GameState; use crate::grid_helpers::compact_state_to_resnet_input; -use crate::rotation::{create_rotation_mapping, remap_policy, rotate_compact_state}; +use crate::rotation::{create_rotation_mapping, remap_mask, remap_policy, rotate_compact_state}; pub trait PlayGameObserver { fn on_state_snapshot(&mut self, step: usize, state: &GameState, action_mask: &[bool]); @@ -155,7 +155,11 @@ pub fn play_game( .index_axis(ndarray::Axis(0), 0) .to_owned(); let rotated_policy = remap_policy(&policy, &original_to_rotated); - let rotated_mask = mechanics.get_action_mask_immut(rotated_data); + // `QGameMechanics::goal_rows` does not flip under rotation, so the + // mask must come from remapping the original mask, not from + // re-validating on rotated data (which would treat walls that + // block the rotated player's path as legal). + let rotated_mask = remap_mask(&mask, &original_to_rotated); (rotated_input, rotated_policy, rotated_mask) } else { ( diff --git a/deep_quoridor/rust/src/rotation.rs b/deep_quoridor/rust/src/rotation.rs index 62e137fc..543688e7 100644 --- a/deep_quoridor/rust/src/rotation.rs +++ b/deep_quoridor/rust/src/rotation.rs @@ -272,6 +272,39 @@ mod tests { } } + /// Pins the contract that callers must NOT compute the rotated action mask + /// via `mechanics.get_action_mask_immut(rotate_compact_state(data))`. + /// `QGameMechanics` owns `goal_rows` and does not flip them under rotation, + /// so wall placements that would block the rotated player's path are + /// silently treated as legal. The correct rotated mask is + /// `remap_mask(original_mask, original_to_rotated)`. + #[test] + fn test_remap_mask_matches_gamestate_rotated_mask_with_walls() { + use crate::compact::q_game_mechanics::QGameMechanics; + + let bs = 5; + let mut state = GameState::new(bs, 2); + let mech = QGameMechanics::new(bs as usize, 2, 200); + let mut data = mech.create_initial_state(); + + // A sequence that places two vertical walls and leaves it as player 1's + // turn — the configuration where the goal-row bug surfaces. + let action_indices: [usize; 5] = [1, 23, 25, 27, 2]; + for &ai in &action_indices { + let action = action_index_to_action(bs, ai); + state.step(action); + mech.apply_action_index(&mut data, ai); + } + assert_eq!(state.current_player, 1); + + let original_mask = mech.get_action_mask_immut(data); + let (orig_to_rot, _) = create_rotation_mapping(bs); + let remapped = remap_mask(&original_mask, &orig_to_rot); + + let gs_rotated_mask = build_rotated_state(&state).get_action_mask(); + assert_eq!(remapped, gs_rotated_mask); + } + #[test] fn test_remap_policy_roundtrip() { let bs = 5; From 6b05fd701549a5706e345bd6308958e4c18ac4d0 Mon Sep 17 00:00:00 2001 From: Jon Binney Date: Sun, 17 May 2026 16:18:43 -0400 Subject: [PATCH 5/6] Widen compact state to 24-byte CompactState (u128 walls + u64 scalars) The single-u64 packed state could not hold a 9x9 board (128 wall bits alone exceed u64). Split the layout so the wall bitmap lives in a u128 and the scalar fields in a u64; every accessor stays within one primitive. Policy DB schema switches to FixedSizeBinary(24) with lex byte ordering; PyO3 functions accept/return state as 24-byte bytes buffers. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../rust/src/agents/alphazero/agent.rs | 9 +- .../rust/src/agents/alphazero/evaluator.rs | 7 +- .../rust/src/agents/alphazero/mcts.rs | 21 +- deep_quoridor/rust/src/agents/mod.rs | 3 +- deep_quoridor/rust/src/agents/onnx_agent.rs | 3 +- deep_quoridor/rust/src/agents/random_agent.rs | 5 +- deep_quoridor/rust/src/compact/policy_db.rs | 173 +++++++---- deep_quoridor/rust/src/compact/q_bit_repr.rs | 276 +++++++++++++----- .../src/compact/q_bit_repr_conversions.rs | 10 +- .../rust/src/compact/q_game_mechanics.rs | 48 +-- deep_quoridor/rust/src/compact/q_minimax.rs | 16 +- deep_quoridor/rust/src/game_runner.rs | 11 +- deep_quoridor/rust/src/grid_helpers.rs | 9 +- deep_quoridor/rust/src/lib.rs | 98 +++++-- deep_quoridor/rust/src/rotation.rs | 9 +- 15 files changed, 467 insertions(+), 231 deletions(-) diff --git a/deep_quoridor/rust/src/agents/alphazero/agent.rs b/deep_quoridor/rust/src/agents/alphazero/agent.rs index f3f9371b..1609d3b1 100644 --- a/deep_quoridor/rust/src/agents/alphazero/agent.rs +++ b/deep_quoridor/rust/src/agents/alphazero/agent.rs @@ -8,6 +8,7 @@ use anyhow::Result; use rand::Rng; use crate::agents::{ActionSelectionTrace, ActionSelector}; +use crate::compact::q_bit_repr::CompactState; use crate::compact::q_game_mechanics::QGameMechanics; use super::evaluator::OnnxEvaluator; @@ -118,7 +119,7 @@ pub fn apply_temperature_and_sample( pub struct AlphaZeroAgent { evaluator: OnnxEvaluator, config: AlphaZeroAgentConfig, - visited_states: HashSet, + visited_states: HashSet, last_selection_trace: Option, } @@ -144,7 +145,7 @@ impl AlphaZeroAgent { impl ActionSelector for AlphaZeroAgent { fn select_action( &mut self, - data: u64, + data: CompactState, mechanics: &QGameMechanics, action_mask: &[bool], ) -> Result<(usize, Vec)> { @@ -197,8 +198,8 @@ impl ActionSelector for AlphaZeroAgent { } // Optionally add the resulting state to the visited set. - // Keyed on raw u64 state including completed_steps; a position seen at - // two different step counts is treated as two states. + // Keyed on the compact state including completed_steps; a position seen + // at two different step counts is treated as two states. if self.config.penalize_visited_states { let mut next_data = data; mechanics.apply_action_index(&mut next_data, selected_idx); diff --git a/deep_quoridor/rust/src/agents/alphazero/evaluator.rs b/deep_quoridor/rust/src/agents/alphazero/evaluator.rs index 5c3c08a7..aa22e313 100644 --- a/deep_quoridor/rust/src/agents/alphazero/evaluator.rs +++ b/deep_quoridor/rust/src/agents/alphazero/evaluator.rs @@ -6,6 +6,7 @@ use anyhow::{Context, Result}; use ort::session::Session; use crate::agents::onnx_agent::softmax; +use crate::compact::q_bit_repr::CompactState; use crate::compact::q_game_mechanics::QGameMechanics; use crate::grid_helpers::compact_state_to_resnet_input; use crate::rotation::{create_rotation_mapping, remap_mask, remap_policy, rotate_compact_state}; @@ -16,7 +17,7 @@ use crate::rotation::{create_rotation_mapping, remap_mask, remap_policy, rotate_ pub trait Evaluator { fn evaluate( &mut self, - data: u64, + data: CompactState, mechanics: &QGameMechanics, action_mask: &[bool], ) -> Result<(f32, Vec)>; @@ -53,7 +54,7 @@ impl OnnxEvaluator { impl Evaluator for OnnxEvaluator { fn evaluate( &mut self, - data: u64, + data: CompactState, mechanics: &QGameMechanics, action_mask: &[bool], ) -> Result<(f32, Vec)> { @@ -121,7 +122,7 @@ impl Evaluator for OnnxEvaluator { impl Evaluator for UniformMockEvaluator { fn evaluate( &mut self, - _data: u64, + _data: CompactState, _mechanics: &QGameMechanics, action_mask: &[bool], ) -> Result<(f32, Vec)> { diff --git a/deep_quoridor/rust/src/agents/alphazero/mcts.rs b/deep_quoridor/rust/src/agents/alphazero/mcts.rs index 27655a80..aea711ef 100644 --- a/deep_quoridor/rust/src/agents/alphazero/mcts.rs +++ b/deep_quoridor/rust/src/agents/alphazero/mcts.rs @@ -9,6 +9,7 @@ use rand_distr::{Dirichlet, Distribution}; use crate::actions::action_index_to_action; #[cfg(test)] use crate::actions::action_to_index; +use crate::compact::q_bit_repr::CompactState; use crate::compact::q_game_mechanics::QGameMechanics; use super::evaluator::Evaluator; @@ -61,7 +62,7 @@ pub struct ChildInfo { #[derive(Debug)] pub struct Node { /// The compact game state at this node. - pub data: u64, + pub data: CompactState, /// Parent node index in the arena, None for root. pub parent: Option, /// Flat policy index for the action taken from parent to reach this node. @@ -82,7 +83,7 @@ pub struct Node { impl Node { /// Create a new root node. - pub fn new_root(data: u64) -> Self { + pub fn new_root(data: CompactState) -> Self { Self { data, parent: None, @@ -97,7 +98,7 @@ impl Node { } /// Create a new child node with its (already computed) state. - pub fn new_child(parent: usize, action_index: usize, data: u64, prior: f32) -> Self { + pub fn new_child(parent: usize, action_index: usize, data: CompactState, prior: f32) -> Self { Self { data, parent: Some(parent), @@ -133,7 +134,7 @@ pub struct NodeArena { impl NodeArena { /// Create a new arena with a root node. - pub fn new(root_data: u64) -> Self { + pub fn new(root_data: CompactState) -> Self { let root = Node::new_root(root_data); Self { nodes: vec![root] } } @@ -143,7 +144,7 @@ impl NodeArena { &mut self, parent: usize, action_index: usize, - data: u64, + data: CompactState, prior: f32, ) -> usize { let idx = self.nodes.len(); @@ -196,7 +197,7 @@ pub fn select_child( arena: &NodeArena, node_idx: usize, ucb_c: f32, - visited_states: &HashSet, + visited_states: &HashSet, ) -> usize { let node = arena.get(node_idx); let parent_visits = node.visit_count.max(1) as f32; @@ -294,10 +295,10 @@ pub fn apply_dirichlet_noise(priors: &mut [f32], epsilon: f32, alpha: f32) { /// Run MCTS search and return child information. pub fn search( config: &MCTSConfig, - root_data: u64, + root_data: CompactState, mechanics: &QGameMechanics, evaluator: &mut E, - visited_states: &HashSet, + visited_states: &HashSet, ) -> anyhow::Result<(Vec, f32)> { let bs = mechanics.repr().board_size() as i32; let mut arena = NodeArena::new(root_data); @@ -433,7 +434,7 @@ mod tests { impl Evaluator for MockEvaluator { fn evaluate( &mut self, - _data: u64, + _data: CompactState, _mechanics: &QGameMechanics, action_mask: &[bool], ) -> Result<(f32, Vec)> { @@ -452,7 +453,7 @@ mod tests { } } - fn make_mech_state() -> (QGameMechanics, u64) { + fn make_mech_state() -> (QGameMechanics, CompactState) { let mech = QGameMechanics::new(5, 3, 200); let data = mech.create_initial_state(); (mech, data) diff --git a/deep_quoridor/rust/src/agents/mod.rs b/deep_quoridor/rust/src/agents/mod.rs index a6797a05..6e6db4f4 100644 --- a/deep_quoridor/rust/src/agents/mod.rs +++ b/deep_quoridor/rust/src/agents/mod.rs @@ -2,6 +2,7 @@ //! //! All agents implement the [`ActionSelector`] trait. +use crate::compact::q_bit_repr::CompactState; use crate::compact::q_game_mechanics::QGameMechanics; #[cfg(feature = "binary")] @@ -28,7 +29,7 @@ pub trait ActionSelector { /// full softmax output (or a uniform/mask-based distribution for simpler agents). fn select_action( &mut self, - data: u64, + data: CompactState, mechanics: &QGameMechanics, action_mask: &[bool], ) -> anyhow::Result<(usize, Vec)>; diff --git a/deep_quoridor/rust/src/agents/onnx_agent.rs b/deep_quoridor/rust/src/agents/onnx_agent.rs index 117eae1c..17280729 100644 --- a/deep_quoridor/rust/src/agents/onnx_agent.rs +++ b/deep_quoridor/rust/src/agents/onnx_agent.rs @@ -6,6 +6,7 @@ use anyhow::{Context, Result}; use ort::session::Session; use crate::agents::ActionSelector; +use crate::compact::q_bit_repr::CompactState; use crate::compact::q_game_mechanics::QGameMechanics; use crate::grid_helpers::compact_state_to_resnet_input; @@ -36,7 +37,7 @@ impl OnnxAgent { impl ActionSelector for OnnxAgent { fn select_action( &mut self, - data: u64, + data: CompactState, mechanics: &QGameMechanics, action_mask: &[bool], ) -> Result<(usize, Vec)> { diff --git a/deep_quoridor/rust/src/agents/random_agent.rs b/deep_quoridor/rust/src/agents/random_agent.rs index 46a08b1b..0ab41ede 100644 --- a/deep_quoridor/rust/src/agents/random_agent.rs +++ b/deep_quoridor/rust/src/agents/random_agent.rs @@ -3,6 +3,7 @@ use rand::Rng; use crate::agents::ActionSelector; +use crate::compact::q_bit_repr::CompactState; use crate::compact::q_game_mechanics::QGameMechanics; /// An agent that selects a random valid action. @@ -27,7 +28,7 @@ impl Default for RandomAgent { impl ActionSelector for RandomAgent { fn select_action( &mut self, - _data: u64, + _data: CompactState, _mechanics: &QGameMechanics, action_mask: &[bool], ) -> anyhow::Result<(usize, Vec)> { @@ -58,7 +59,7 @@ impl ActionSelector for RandomAgent { mod tests { use super::*; - fn make_mech() -> (QGameMechanics, u64) { + fn make_mech() -> (QGameMechanics, CompactState) { let mech = QGameMechanics::new(5, 3, 200); let data = mech.create_initial_state(); (mech, data) diff --git a/deep_quoridor/rust/src/compact/policy_db.rs b/deep_quoridor/rust/src/compact/policy_db.rs index e5cfa160..52d919de 100644 --- a/deep_quoridor/rust/src/compact/policy_db.rs +++ b/deep_quoridor/rust/src/compact/policy_db.rs @@ -4,15 +4,16 @@ /// with a transposition table. /// /// Storage: Parquet file with two columns (`state`, `value`), sorted by -/// `state`. The footer holds key/value metadata (board_size, max_walls, -/// max_steps, num_states). Row-group min/max statistics on `state` are -/// used to prune lookups; we never load the entire DB into memory. +/// `state` bytes (lex order). The footer holds key/value metadata +/// (board_size, max_walls, max_steps, num_states). Row-group min/max +/// statistics on `state` are used to prune lookups; we never load the +/// entire DB into memory. use std::collections::HashMap; use std::fs::File; use std::path::Path; use std::sync::Arc; -use arrow::array::{Array, Int64Array, Int8Array, RecordBatch}; +use arrow::array::{Array, FixedSizeBinaryArray, Int8Array, RecordBatch}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use dashmap::DashMap; use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ParquetRecordBatchReaderBuilder}; @@ -25,10 +26,14 @@ use rand::rngs::StdRng; use rand::seq::SliceRandom; use rand::SeedableRng; +use super::q_bit_repr::CompactState; use super::q_game_mechanics::QGameMechanics; +/// Width of the on-disk / in-memory byte representation of a state. +pub const STATE_BYTES: usize = 24; + /// Transposition table type alias. -pub type TranspositionTable = DashMap; +pub type TranspositionTable = DashMap; /// Number of rows per Parquet row group. Larger means fewer groups (less /// per-group overhead) but larger decode units. 1M is a good general default. @@ -40,22 +45,20 @@ const ROW_GROUP_SIZE: usize = 1_000_000; const WRITE_CHUNK_SIZE: usize = 65_536; /// Per-row-group statistics cached at open time, used to prune lookups. -/// State values are stored in the Parquet file as `Int64` (the same 64 -/// bits as the underlying `u64`). All comparisons here are signed `i64`, -/// matching the Parquet sort/statistics ordering. Callers convert to/from -/// `u64` via `as` casts at the boundary. +/// State values are stored as `FixedSizeBinary(24)`; comparisons use +/// lexicographic byte order, matching the Parquet sort/statistics ordering. #[derive(Clone, Debug)] struct RowGroupStats { idx: usize, - min: i64, - max: i64, + min: [u8; STATE_BYTES], + max: [u8; STATE_BYTES], num_rows: u64, } /// Decoded contents of one row group, kept around when we want to reuse /// it across nearby lookups in the same call. struct DecodedRowGroup { - states: Vec, + states: Vec, values: Vec, } @@ -72,15 +75,15 @@ enum Storage { cum_rows: Vec, }, Eager { - /// File order (= sorted by state-as-i64). Indexable by `rowid - 1`. - ordered: Vec<(u64, i8)>, + /// File order (= sorted by state bytes lex). Indexable by `rowid - 1`. + ordered: Vec<(CompactState, i8)>, /// O(1) point lookups for `lookup_values_by_state`. - table: HashMap, + table: HashMap, }, } /// Parquet-backed policy database for storing and querying pre-computed -/// minimax values. Single file per DB; sorted by `state` for efficient +/// minimax values. Single file per DB; sorted by `state` bytes for efficient /// row-group pruning on point lookups. pub struct PolicyDb { storage: Storage, @@ -93,7 +96,11 @@ pub struct PolicyDb { fn schema() -> SchemaRef { Arc::new(Schema::new(vec![ - Field::new("state", DataType::Int64, false), + Field::new( + "state", + DataType::FixedSizeBinary(STATE_BYTES as i32), + false, + ), Field::new("value", DataType::Int8, false), ])) } @@ -105,11 +112,29 @@ fn parse_meta(kv: Option<&Vec>, key: &str) -> Option { .and_then(|v| v.parse().ok()) } +fn bytes_to_state(b: &[u8]) -> Result> { + if b.len() != STATE_BYTES { + return Err(format!("state column row width is {}, expected {STATE_BYTES}", b.len()).into()); + } + let mut a = [0u8; STATE_BYTES]; + a.copy_from_slice(b); + Ok(CompactState::from_bytes(a)) +} + +fn slice_to_arr(b: &[u8]) -> Result<[u8; STATE_BYTES], Box> { + if b.len() != STATE_BYTES { + return Err(format!("stat byte width is {}, expected {STATE_BYTES}", b.len()).into()); + } + let mut a = [0u8; STATE_BYTES]; + a.copy_from_slice(b); + Ok(a) +} + impl PolicyDb { /// Open an existing policy database for reading. /// /// When `lazy` is `false` (the default for callers), the entire dataset - /// is decoded into a `HashMap` at open time so subsequent + /// is decoded into a `HashMap` at open time so subsequent /// state lookups are O(1). When `lazy` is `true`, lookups walk Parquet /// row groups on demand using the file's min/max statistics; useful for /// DBs too large to fit in memory. @@ -142,9 +167,17 @@ impl PolicyDb { .statistics() .ok_or_else(|| format!("row group {i} missing statistics on state column"))?; let (min, max) = match stats { - Statistics::Int64(s) => ( - *s.min_opt().ok_or("missing min stat on state")?, - *s.max_opt().ok_or("missing max stat on state")?, + Statistics::FixedLenByteArray(s) => ( + slice_to_arr( + s.min_opt() + .ok_or("missing min stat on state")? + .as_ref(), + )?, + slice_to_arr( + s.max_opt() + .ok_or("missing max stat on state")? + .as_ref(), + )?, ), _ => { return Err( @@ -180,16 +213,16 @@ impl PolicyDb { let reader = ParquetRecordBatchReaderBuilder::new_with_metadata(file, metadata.clone()) .build()?; - let mut state_buf: Vec = Vec::with_capacity(num_states); + let mut state_buf: Vec = Vec::with_capacity(num_states); let mut value_buf: Vec = Vec::with_capacity(num_states); for batch in reader { let batch = batch?; append_batch(&batch, &mut state_buf, &mut value_buf)?; } - let mut ordered: Vec<(u64, i8)> = Vec::with_capacity(state_buf.len()); + let mut ordered: Vec<(CompactState, i8)> = Vec::with_capacity(state_buf.len()); for (s, v) in state_buf.into_iter().zip(value_buf.into_iter()) { - ordered.push((s as u64, v)); + ordered.push((s, v)); } let mut table = HashMap::with_capacity(ordered.len()); for &(s, v) in &ordered { @@ -259,7 +292,7 @@ impl PolicyDb { /// Values are returned from the acting player's perspective. pub fn lookup_action_values( &self, - data: u64, + data: CompactState, ) -> Result, Vec)>, Box> { let mechanics = &self.mechanics; let cp = mechanics.repr().get_current_player(data); @@ -315,7 +348,7 @@ impl PolicyDb { } // Batch-look up all non-terminal children in one sorted sweep. - let need_lookup: Vec = child_states + let need_lookup: Vec = child_states .iter() .zip(terminal_p0.iter()) .filter(|(_, t)| t.is_none()) @@ -323,7 +356,7 @@ impl PolicyDb { .collect(); let lookup_pairs = self.lookup_values_by_state(&need_lookup)?; - let lookup_map: HashMap = lookup_pairs.into_iter().collect(); + let lookup_map: HashMap = lookup_pairs.into_iter().collect(); let mut values = Vec::with_capacity(actions.len()); let mut any_found = false; @@ -360,7 +393,7 @@ impl PolicyDb { pub fn fetch_states_by_rowid( &self, rowids: &[i64], - ) -> Result, Box> { + ) -> Result, Box> { if rowids.is_empty() { return Ok(Vec::new()); } @@ -425,7 +458,7 @@ impl PolicyDb { current_rg = Some(rg_idx); } let dec = decoded.as_ref().unwrap(); - let s = dec.states[row_in_group] as u64; + let s = dec.states[row_in_group]; let v = dec.values[row_in_group] as i32; results.push((s, v)); } @@ -443,8 +476,8 @@ impl PolicyDb { /// decoding each at most once. pub fn lookup_values_by_state( &self, - states: &[u64], - ) -> Result, Box> { + states: &[CompactState], + ) -> Result, Box> { if states.is_empty() { return Ok(Vec::new()); } @@ -460,7 +493,8 @@ impl PolicyDb { row_groups, .. } => { - let mut sorted: Vec = states.iter().map(|&s| s as i64).collect(); + let mut sorted: Vec<[u8; STATE_BYTES]> = + states.iter().map(|s| s.to_bytes()).collect(); sorted.sort_unstable(); sorted.dedup(); @@ -481,11 +515,17 @@ impl PolicyDb { // q is in [min, max]; decode this row group once and // resolve every query whose value falls in the range. let decoded = Self::decode_row_group(path, metadata, rg)?; + // Build a sorted bytes view of the decoded states for binary search. + let decoded_bytes: Vec<[u8; STATE_BYTES]> = + decoded.states.iter().map(|s| s.to_bytes()).collect(); while q_idx < sorted.len() && sorted[q_idx] <= rg.max { let q2 = sorted[q_idx]; if q2 >= rg.min { - if let Ok(pos) = decoded.states.binary_search(&q2) { - results.push((q2 as u64, decoded.values[pos] as i32)); + if let Ok(pos) = decoded_bytes.binary_search(&q2) { + results.push(( + CompactState::from_bytes(q2), + decoded.values[pos] as i32, + )); } } q_idx += 1; @@ -512,15 +552,14 @@ impl PolicyDb { max_walls: usize, step_interval: usize, ) -> Result> { - // Drain DashMap, apply step_interval filter, sort by state. - // Sort uses i64 so the on-disk order matches the read-time - // statistics ordering (Parquet Int64 stats are signed). - let mut rows: Vec<(i64, i8)> = entries + // Drain DashMap, apply step_interval filter, sort by state bytes + // (matches Parquet FixedSizeBinary statistics ordering). + let mut rows: Vec<([u8; STATE_BYTES], i8)> = entries .into_iter() .filter_map(|(s, v)| { let steps = mechanics.repr().get_completed_steps(s); if steps % step_interval == 0 { - Some((s as i64, v)) + Some((s.to_bytes(), v)) } else { None } @@ -560,14 +599,13 @@ impl PolicyDb { // Stream rows in moderate chunks so we don't allocate one giant batch. for chunk in rows.chunks(WRITE_CHUNK_SIZE) { - let states: Vec = chunk.iter().map(|(s, _)| *s).collect(); + let states_arr = Arc::new(FixedSizeBinaryArray::try_from_iter( + chunk.iter().map(|(s, _)| s.to_vec()), + )?); let values: Vec = chunk.iter().map(|(_, v)| *v).collect(); let batch = RecordBatch::try_new( schema.clone(), - vec![ - Arc::new(Int64Array::from(states)), - Arc::new(Int8Array::from(values)), - ], + vec![states_arr, Arc::new(Int8Array::from(values))], )?; writer.write(&batch)?; } @@ -580,26 +618,39 @@ impl PolicyDb { /// Append a record batch's two columns to the running state/value vectors. fn append_batch( batch: &RecordBatch, - states: &mut Vec, + states: &mut Vec, values: &mut Vec, ) -> Result<(), Box> { let s_col = batch .column(0) .as_any() - .downcast_ref::() - .ok_or("state column is not Int64")?; + .downcast_ref::() + .ok_or("state column is not FixedSizeBinary")?; let v_col = batch .column(1) .as_any() .downcast_ref::() .ok_or("value column is not Int8")?; - states.extend_from_slice(s_col.values()); + if s_col.value_length() != STATE_BYTES as i32 { + return Err(format!( + "state column width {} != expected {}", + s_col.value_length(), + STATE_BYTES + ) + .into()); + } + for i in 0..s_col.len() { + states.push(bytes_to_state(s_col.value(i))?); + } values.extend_from_slice(v_col.values()); Ok(()) } /// Get all valid actions (moves + wall placements) for the current player. -fn get_all_actions(mechanics: &QGameMechanics, data: &mut u64) -> Vec<(u8, u8, u8)> { +fn get_all_actions( + mechanics: &QGameMechanics, + data: &mut CompactState, +) -> Vec<(u8, u8, u8)> { let moves = mechanics.get_valid_moves(*data); let mut actions: Vec<(u8, u8, u8)> = moves .into_iter() @@ -625,7 +676,7 @@ fn get_all_actions(mechanics: &QGameMechanics, data: &mut u64) -> Vec<(u8, u8, u /// transposition table for later export to a policy database. pub fn minimax( mechanics: &QGameMechanics, - data: &mut u64, + data: &mut CompactState, transposition_table: &TranspositionTable, ) -> i8 { minimax_inner(mechanics, data, transposition_table, None) @@ -633,7 +684,7 @@ pub fn minimax( fn minimax_inner( mechanics: &QGameMechanics, - data: &mut u64, + data: &mut CompactState, transposition_table: &TranspositionTable, mut rng: Option<&mut StdRng>, ) -> i8 { @@ -716,7 +767,7 @@ fn minimax_inner( /// transposition table. Returns the root value. pub fn minimax_lazy_smp( mechanics: &QGameMechanics, - data: &mut u64, + data: &mut CompactState, transposition_table: &TranspositionTable, num_threads: usize, ) -> i8 { @@ -807,7 +858,8 @@ mod tests { assert!(!table.is_empty()); // Snapshot expected entries before write() drains the DashMap. - let expected: Vec<(u64, i8)> = table.iter().map(|kv| (*kv.key(), *kv.value())).collect(); + let expected: Vec<(CompactState, i8)> = + table.iter().map(|kv| (*kv.key(), *kv.value())).collect(); let dir = tempdir().unwrap(); let path = dir.path().join("test.parquet"); @@ -826,10 +878,10 @@ mod tests { assert_eq!(db.count_states().unwrap(), expected.len()); // Round-trip every entry by state lookup. - let states: Vec = expected.iter().map(|(s, _)| *s).collect(); + let states: Vec = expected.iter().map(|(s, _)| *s).collect(); let pairs = db.lookup_values_by_state(&states).unwrap(); assert_eq!(pairs.len(), expected.len()); - let got_map: HashMap = pairs.into_iter().collect(); + let got_map: HashMap = pairs.into_iter().collect(); for (s, v) in &expected { assert_eq!(got_map.get(s), Some(&(*v as i32))); } @@ -838,7 +890,7 @@ mod tests { let all_ids: Vec = (1..=expected.len() as i64).collect(); let by_id = db.fetch_states_by_rowid(&all_ids).unwrap(); assert_eq!(by_id.len(), expected.len()); - let by_id_map: HashMap = by_id.into_iter().collect(); + let by_id_map: HashMap = by_id.into_iter().collect(); for (s, v) in &expected { assert_eq!(by_id_map.get(s), Some(&(*v as i32))); } @@ -888,7 +940,8 @@ mod tests { let mut root = mechanics.create_initial_state(); let table = TranspositionTable::new(); minimax(&mechanics, &mut root, &table); - let expected: Vec<(u64, i8)> = table.iter().map(|kv| (*kv.key(), *kv.value())).collect(); + let expected: Vec<(CompactState, i8)> = + table.iter().map(|kv| (*kv.key(), *kv.value())).collect(); let dir = tempdir().unwrap(); let path = dir.path().join("test.parquet"); @@ -907,11 +960,11 @@ mod tests { // Same lookup_values_by_state results (sort both since order isn't // guaranteed across modes). - let states: Vec = expected.iter().map(|(s, _)| *s).collect(); + let states: Vec = expected.iter().map(|(s, _)| *s).collect(); let mut e_pairs = eager.lookup_values_by_state(&states).unwrap(); let mut l_pairs = lazy.lookup_values_by_state(&states).unwrap(); - e_pairs.sort_unstable_by_key(|(s, _)| *s); - l_pairs.sort_unstable_by_key(|(s, _)| *s); + e_pairs.sort_unstable_by_key(|(s, _)| s.to_bytes()); + l_pairs.sort_unstable_by_key(|(s, _)| s.to_bytes()); assert_eq!(e_pairs, l_pairs); assert_eq!(e_pairs.len(), expected.len()); diff --git a/deep_quoridor/rust/src/compact/q_bit_repr.rs b/deep_quoridor/rust/src/compact/q_bit_repr.rs index a0435539..f7515fc7 100644 --- a/deep_quoridor/rust/src/compact/q_bit_repr.rs +++ b/deep_quoridor/rust/src/compact/q_bit_repr.rs @@ -1,9 +1,11 @@ /// Compact bit-packed representation accessor for game states. /// -/// This struct doesn't store the game state data itself - it only stores the -/// parameters and computed offsets needed to interpret a u64 as a packed game state. -/// The data is passed to each method by value (read) or mutable reference (write). -/// Boards requiring more than 64 bits are not supported. +/// The state is split into two primitives so each field lives in one word: +/// - `walls: u128` — wall bitmap, supports up to 128 wall positions (9x9 board). +/// - `scalars: u64` — positions, walls-remaining, current player, completed steps. +/// +/// `QBitRepr` itself stores no game state — it only stores parameters and computed +/// offsets needed to interpret a `CompactState`. // Wall orientations pub const WALL_VERTICAL: usize = 0; @@ -18,6 +20,40 @@ const fn bits_needed(max: usize) -> usize { } } +/// Bit-packed Quoridor game state. +/// +/// Layout: +/// `walls` — bit `i` set iff wall at `wall_index_to_position(i)` is placed. +/// `scalars` — packed: p1_pos | p2_pos | p1_walls_remaining | p2_walls_remaining +/// | current_player | completed_steps +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Default)] +pub struct CompactState { + pub walls: u128, + pub scalars: u64, +} + +impl CompactState { + /// Serialize as 24 little-endian bytes: walls (16) ++ scalars (8). + pub fn to_bytes(self) -> [u8; 24] { + let mut out = [0u8; 24]; + out[..16].copy_from_slice(&self.walls.to_le_bytes()); + out[16..].copy_from_slice(&self.scalars.to_le_bytes()); + out + } + + /// Deserialize from 24 little-endian bytes. + pub fn from_bytes(b: [u8; 24]) -> Self { + let mut walls = [0u8; 16]; + walls.copy_from_slice(&b[..16]); + let mut scalars = [0u8; 8]; + scalars.copy_from_slice(&b[16..]); + Self { + walls: u128::from_le_bytes(walls), + scalars: u64::from_le_bytes(scalars), + } + } +} + /// Accessor for bit-packed game states. /// Stores parameters and computed offsets, but not the actual game state data. #[derive(Clone, Debug)] @@ -33,11 +69,10 @@ pub struct QBitRepr { position_bits: usize, walls_remaining_bits: usize, completed_steps_bits: usize, - total_bits: usize, + total_scalar_bits: usize, - // Bit offsets for each field - walls_offset: usize, - player_pos_offsets: [usize; 2], // Offset for each player's position + // Bit offsets within `scalars` + player_pos_offsets: [usize; 2], walls_remaining_offsets: [usize; 2], current_player_offset: usize, completed_steps_offset: usize, @@ -52,8 +87,12 @@ impl QBitRepr { let walls_remaining_bits = bits_needed(max_walls); let completed_steps_bits = bits_needed(max_steps); - let walls_offset = 0; - let p1_pos_offset = walls_offset + num_wall_positions; + assert!( + num_wall_positions <= 128, + "QBitRepr: wall bitmap requires {num_wall_positions} bits, which exceeds u128 capacity" + ); + + let p1_pos_offset = 0; let p2_pos_offset = p1_pos_offset + position_bits; let player_pos_offsets = [p1_pos_offset, p2_pos_offset]; let p1_walls_remaining_offset = p2_pos_offset + position_bits; @@ -61,11 +100,11 @@ impl QBitRepr { let walls_remaining_offsets = [p1_walls_remaining_offset, p2_walls_remaining_offset]; let current_player_offset = p2_walls_remaining_offset + walls_remaining_bits; let completed_steps_offset = current_player_offset + 1; + let total_scalar_bits = completed_steps_offset + completed_steps_bits; - let total_bits = completed_steps_offset + completed_steps_bits; assert!( - total_bits <= 64, - "QBitRepr: state requires {total_bits} bits, which exceeds u64 capacity" + total_scalar_bits <= 64, + "QBitRepr: scalar fields require {total_scalar_bits} bits, which exceeds u64 capacity" ); Self { @@ -77,8 +116,7 @@ impl QBitRepr { position_bits, walls_remaining_bits, completed_steps_bits, - total_bits, - walls_offset, + total_scalar_bits, player_pos_offsets, walls_remaining_offsets, current_player_offset, @@ -86,9 +124,10 @@ impl QBitRepr { } } + /// Total bits used across walls and scalars (for inspection / tests). #[allow(dead_code)] pub fn size_bits(&self) -> usize { - self.total_bits + self.num_wall_positions + self.total_scalar_bits } /// Get the number of wall positions @@ -103,96 +142,128 @@ impl QBitRepr { } /// Create a zero-initialized packed state value. - pub fn create_data(&self) -> u64 { - 0u64 + pub fn create_data(&self) -> CompactState { + CompactState::default() } - /// Get a bit at the specified position in the data + /// Read a bit from the wall bitmap. #[inline] - fn get_bit(&self, data: u64, bit_index: usize) -> bool { - debug_assert!(bit_index < self.total_bits); - (data >> bit_index) & 1 == 1 + fn get_wall_bit(&self, state: CompactState, wall_idx: usize) -> bool { + debug_assert!(wall_idx < self.num_wall_positions); + (state.walls >> wall_idx) & 1 == 1 } - /// Set a bit at the specified position in the data + /// Write a bit to the wall bitmap. #[inline] - fn set_bit(&self, data: &mut u64, bit_index: usize, value: bool) { - debug_assert!(bit_index < self.total_bits); + fn set_wall_bit(&self, state: &mut CompactState, wall_idx: usize, value: bool) { + debug_assert!(wall_idx < self.num_wall_positions); if value { - *data |= 1u64 << bit_index; + state.walls |= 1u128 << wall_idx; } else { - *data &= !(1u64 << bit_index); + state.walls &= !(1u128 << wall_idx); } } - /// Get an integer value starting at bit_offset with num_bits bits + /// Read a multi-bit integer from the scalar field. #[inline] - fn get_bits(&self, data: u64, bit_offset: usize, num_bits: usize) -> usize { - ((data >> bit_offset) & ((1u64 << num_bits) - 1)) as usize + fn get_scalar_bits(&self, state: CompactState, bit_offset: usize, num_bits: usize) -> usize { + debug_assert!(bit_offset + num_bits <= self.total_scalar_bits); + ((state.scalars >> bit_offset) & ((1u64 << num_bits) - 1)) as usize } - /// Set an integer value starting at bit_offset with num_bits bits + /// Write a multi-bit integer to the scalar field. #[inline] - fn set_bits(&self, data: &mut u64, bit_offset: usize, num_bits: usize, value: usize) { + fn set_scalar_bits( + &self, + state: &mut CompactState, + bit_offset: usize, + num_bits: usize, + value: usize, + ) { + debug_assert!(bit_offset + num_bits <= self.total_scalar_bits); let mask = (1u64 << num_bits) - 1; - *data = (*data & !(mask << bit_offset)) | ((value as u64 & mask) << bit_offset); + state.scalars = (state.scalars & !(mask << bit_offset)) | ((value as u64 & mask) << bit_offset); + } + + /// Read a single scalar bit. + #[inline] + fn get_scalar_bit(&self, state: CompactState, bit_offset: usize) -> bool { + debug_assert!(bit_offset < self.total_scalar_bits); + (state.scalars >> bit_offset) & 1 == 1 + } + + /// Write a single scalar bit. + #[inline] + fn set_scalar_bit(&self, state: &mut CompactState, bit_offset: usize, value: bool) { + debug_assert!(bit_offset < self.total_scalar_bits); + if value { + state.scalars |= 1u64 << bit_offset; + } else { + state.scalars &= !(1u64 << bit_offset); + } } /// Check if a wall is present at the given position - pub fn get_wall(&self, data: u64, row: usize, col: usize, orientation: usize) -> bool { + pub fn get_wall(&self, state: CompactState, row: usize, col: usize, orientation: usize) -> bool { let wall_index = self.wall_position_to_index(row, col, orientation); - self.get_bit(data, self.walls_offset + wall_index) + self.get_wall_bit(state, wall_index) } /// Set a wall at the given position pub fn set_wall( &self, - data: &mut u64, + state: &mut CompactState, row: usize, col: usize, orientation: usize, present: bool, ) { let wall_index = self.wall_position_to_index(row, col, orientation); - self.set_bit(data, self.walls_offset + wall_index, present); + self.set_wall_bit(state, wall_index, present); } /// Get a player's position - pub fn get_player_position(&self, data: u64, player: usize) -> (usize, usize) { + pub fn get_player_position(&self, state: CompactState, player: usize) -> (usize, usize) { debug_assert!(player < 2); - let index = self.get_bits(data, self.player_pos_offsets[player], self.position_bits); + let index = self.get_scalar_bits(state, self.player_pos_offsets[player], self.position_bits); self.index_to_position(index) } /// Set a player's position - pub fn set_player_position(&self, data: &mut u64, player: usize, row: usize, col: usize) { + pub fn set_player_position( + &self, + state: &mut CompactState, + player: usize, + row: usize, + col: usize, + ) { debug_assert!(player < 2); let new_index = self.position_to_index(row, col); debug_assert!(new_index < self.num_player_positions); - self.set_bits( - data, + self.set_scalar_bits( + state, self.player_pos_offsets[player], self.position_bits, new_index, ); } - /// Get player 1's remaining walls - pub fn get_walls_remaining(&self, data: u64, player: usize) -> usize { + /// Get a player's remaining walls + pub fn get_walls_remaining(&self, state: CompactState, player: usize) -> usize { debug_assert!(player < 2); - self.get_bits( - data, + self.get_scalar_bits( + state, self.walls_remaining_offsets[player], self.walls_remaining_bits, ) } - /// Set player 1's remaining walls - pub fn set_walls_remaining(&self, data: &mut u64, player: usize, walls: usize) { + /// Set a player's remaining walls + pub fn set_walls_remaining(&self, state: &mut CompactState, player: usize, walls: usize) { debug_assert!(player < 2); debug_assert!(walls <= self.max_walls); - self.set_bits( - data, + self.set_scalar_bits( + state, self.walls_remaining_offsets[player], self.walls_remaining_bits, walls, @@ -200,8 +271,8 @@ impl QBitRepr { } /// Get the current player (0 or 1) - pub fn get_current_player(&self, data: u64) -> usize { - if self.get_bit(data, self.current_player_offset) { + pub fn get_current_player(&self, state: CompactState) -> usize { + if self.get_scalar_bit(state, self.current_player_offset) { 1 } else { 0 @@ -209,21 +280,21 @@ impl QBitRepr { } /// Set the current player - pub fn set_current_player(&self, data: &mut u64, player: usize) { + pub fn set_current_player(&self, state: &mut CompactState, player: usize) { debug_assert!(player < 2); - self.set_bit(data, self.current_player_offset, player == 1); + self.set_scalar_bit(state, self.current_player_offset, player == 1); } /// Get the number of completed steps - pub fn get_completed_steps(&self, data: u64) -> usize { - self.get_bits(data, self.completed_steps_offset, self.completed_steps_bits) + pub fn get_completed_steps(&self, state: CompactState) -> usize { + self.get_scalar_bits(state, self.completed_steps_offset, self.completed_steps_bits) } /// Set the number of completed steps - pub fn set_completed_steps(&self, data: &mut u64, steps: usize) { + pub fn set_completed_steps(&self, state: &mut CompactState, steps: usize) { debug_assert!(steps <= self.max_steps); - self.set_bits( - data, + self.set_scalar_bits( + state, self.completed_steps_offset, self.completed_steps_bits, steps, @@ -283,8 +354,8 @@ impl QBitRepr { } /// Display the board state as text art - pub fn print(&self, data: u64) { - println!("{}", self.display(data)); + pub fn print(&self, state: CompactState) { + println!("{}", self.display(state)); } /// Create string with text art of the board @@ -295,16 +366,16 @@ impl QBitRepr { /// - '|' for vertical walls /// - '-' for horizontal walls /// - Metadata (steps, walls) shown on the right side - pub fn display(&self, data: u64) -> String { + pub fn display(&self, state: CompactState) -> String { let mut output = String::new(); // Get metadata - let (p0_row, p0_col) = self.get_player_position(data, 0); - let (p1_row, p1_col) = self.get_player_position(data, 1); - let p0_walls = self.get_walls_remaining(data, 0); - let p1_walls = self.get_walls_remaining(data, 1); - let current_player = self.get_current_player(data); - let steps = self.get_completed_steps(data); + let (p0_row, p0_col) = self.get_player_position(state, 0); + let (p1_row, p1_col) = self.get_player_position(state, 1); + let p0_walls = self.get_walls_remaining(state, 0); + let p1_walls = self.get_walls_remaining(state, 1); + let current_player = self.get_current_player(state); + let steps = self.get_completed_steps(state); // Build metadata lines let meta_lines = vec![ @@ -332,9 +403,9 @@ impl QBitRepr { // Print vertical wall to the right (if not last column) if col < self.board_size - 1 { - if row < self.board_size - 1 && self.get_wall(data, row, col, WALL_VERTICAL) { + if row < self.board_size - 1 && self.get_wall(state, row, col, WALL_VERTICAL) { cell_line.push('|'); - } else if row > 0 && self.get_wall(data, row - 1, col, WALL_VERTICAL) { + } else if row > 0 && self.get_wall(state, row - 1, col, WALL_VERTICAL) { cell_line.push('|'); } else { cell_line.push(' '); @@ -355,9 +426,9 @@ impl QBitRepr { let mut wall_line = String::new(); for col in 0..self.board_size { // Print horizontal wall below this cell - if col < self.board_size - 1 && self.get_wall(data, row, col, WALL_HORIZONTAL) { + if col < self.board_size - 1 && self.get_wall(state, row, col, WALL_HORIZONTAL) { wall_line.push('-'); - } else if col > 0 && self.get_wall(data, row, col - 1, WALL_HORIZONTAL) { + } else if col > 0 && self.get_wall(state, row, col - 1, WALL_HORIZONTAL) { wall_line.push('-'); } else { wall_line.push(' '); @@ -390,15 +461,29 @@ mod tests { #[test] fn test_size_calculation() { let q = QBitRepr::new(5, 10, 100); - // 5x5 board: 2*(5-1)*(5-1) = 2*4*4 = 32 wall bits + // 5x5 board: 2*(5-1)*(5-1) = 32 wall bits // Positions: ceil(log2(25)) = 5 bits each, 2 players = 10 bits // Walls remaining: ceil(log2(10)) = 4 bits each = 8 bits // Current player: 1 bit // Steps: ceil(log2(100)) = 7 bits - // Total: 32 + 10 + 8 + 1 + 7 = 58 bits + // Scalars total: 10 + 8 + 1 + 7 = 26 bits + // Walls + scalars: 32 + 26 = 58 bits assert_eq!(q.size_bits(), 58); } + #[test] + fn test_size_calculation_9x9_10w() { + let q = QBitRepr::new(9, 10, 200); + // 9x9 board: 2*8*8 = 128 wall bits + // Positions: ceil(log2(81)) = 7 bits each, 2 players = 14 bits + // Walls remaining: ceil(log2(10)) = 4 bits each = 8 bits + // Current player: 1 bit + // Steps: ceil(log2(200)) = 8 bits + // Scalars total: 14 + 8 + 1 + 8 = 31 bits + assert_eq!(q.num_wall_positions(), 128); + assert_eq!(q.size_bits(), 128 + 31); + } + #[test] fn test_player_positions() { let q = QBitRepr::new(5, 10, 100); @@ -425,6 +510,22 @@ mod tests { assert!(!q.get_wall(data, 0, 1, WALL_VERTICAL)); } + #[test] + fn test_walls_9x9() { + let q = QBitRepr::new(9, 10, 200); + let mut data = q.create_data(); + + // Walls at all four corners of the wall grid, in both orientations. + for &(r, c) in &[(0usize, 0usize), (0, 7), (7, 0), (7, 7)] { + for &orient in &[WALL_VERTICAL, WALL_HORIZONTAL] { + q.set_wall(&mut data, r, c, orient, true); + assert!(q.get_wall(data, r, c, orient)); + } + } + // The wall bitmap should occupy the high bits too (bit 127 = last horizontal corner). + assert_ne!(data.walls >> 64, 0); + } + #[test] fn test_walls_remaining() { let q = QBitRepr::new(5, 10, 100); @@ -434,6 +535,16 @@ mod tests { assert_eq!(q.get_walls_remaining(data, 0), 10); } + #[test] + fn test_walls_remaining_9x9() { + let q = QBitRepr::new(9, 10, 200); + let mut data = q.create_data(); + q.set_walls_remaining(&mut data, 0, 10); + q.set_walls_remaining(&mut data, 1, 7); + assert_eq!(q.get_walls_remaining(data, 0), 10); + assert_eq!(q.get_walls_remaining(data, 1), 7); + } + #[test] fn test_current_player() { let q = QBitRepr::new(5, 10, 100); @@ -473,4 +584,21 @@ mod tests { let (row, col, orientation) = q.wall_index_to_position(idx); assert_eq!((row, col, orientation), (3, 1, WALL_HORIZONTAL)); } + + #[test] + fn test_compact_state_bytes_roundtrip() { + let q = QBitRepr::new(9, 10, 200); + let mut state = q.create_data(); + q.set_player_position(&mut state, 0, 0, 4); + q.set_player_position(&mut state, 1, 8, 4); + q.set_walls_remaining(&mut state, 0, 10); + q.set_walls_remaining(&mut state, 1, 10); + q.set_current_player(&mut state, 1); + q.set_completed_steps(&mut state, 137); + q.set_wall(&mut state, 7, 7, WALL_HORIZONTAL, true); + + let bytes = state.to_bytes(); + let restored = CompactState::from_bytes(bytes); + assert_eq!(state, restored); + } } diff --git a/deep_quoridor/rust/src/compact/q_bit_repr_conversions.rs b/deep_quoridor/rust/src/compact/q_bit_repr_conversions.rs index 86af90d8..5020c665 100644 --- a/deep_quoridor/rust/src/compact/q_bit_repr_conversions.rs +++ b/deep_quoridor/rust/src/compact/q_bit_repr_conversions.rs @@ -1,4 +1,4 @@ -use super::q_bit_repr::QBitRepr; +use super::q_bit_repr::{CompactState, QBitRepr}; /// Conversion functions between QBitRepr packed format and game state arrays. /// /// This module contains methods for converting between the bit-packed representation @@ -17,7 +17,7 @@ impl QBitRepr { /// * `completed_steps` - Number of steps completed so far pub fn from_game_state( &self, - data: &mut u64, + data: &mut CompactState, grid: &ArrayView2, player_positions: &ArrayView2, walls_remaining: &ArrayView1, @@ -71,7 +71,7 @@ impl QBitRepr { /// Extract player positions as a 2x2 array (format used by minimax) /// Returns [[p1_row, p1_col], [p2_row, p2_col]] #[allow(dead_code)] - pub fn to_player_positions(&self, data: u64) -> Array2 { + pub fn to_player_positions(&self, data: CompactState) -> Array2 { let (p1_row, p1_col) = self.get_player_position(data, 0); let (p2_row, p2_col) = self.get_player_position(data, 1); @@ -85,7 +85,7 @@ impl QBitRepr { /// Extract walls remaining as a 1D array (format used by minimax) /// Returns [p1_walls, p2_walls] #[allow(dead_code)] - pub fn to_walls_remaining(&self, data: u64) -> Array1 { + pub fn to_walls_remaining(&self, data: CompactState) -> Array1 { Array1::from_vec(vec![ self.get_walls_remaining(data, 0) as i32, self.get_walls_remaining(data, 1) as i32, @@ -95,7 +95,7 @@ impl QBitRepr { /// Reconstruct the full grid with walls and player positions /// This creates a grid in the format used by minimax: (2*board_size + 3) x (2*board_size + 3) #[allow(dead_code)] - pub fn to_grid(&self, data: u64) -> Array2 { + pub fn to_grid(&self, data: CompactState) -> Array2 { use crate::grid::{CELL_FREE, CELL_PLAYER1, CELL_PLAYER2, CELL_WALL}; let grid_size = 2 * self.board_size() + 3; diff --git a/deep_quoridor/rust/src/compact/q_game_mechanics.rs b/deep_quoridor/rust/src/compact/q_game_mechanics.rs index 8a4be1a3..4d19d603 100644 --- a/deep_quoridor/rust/src/compact/q_game_mechanics.rs +++ b/deep_quoridor/rust/src/compact/q_game_mechanics.rs @@ -4,7 +4,7 @@ /// the bit-packed representation instead of converting to/from grid arrays. use std::collections::VecDeque; -use super::q_bit_repr::{QBitRepr, WALL_HORIZONTAL, WALL_VERTICAL}; +use super::q_bit_repr::{CompactState, QBitRepr, WALL_HORIZONTAL, WALL_VERTICAL}; /// Reusable buffer for BFS operations, avoiding repeated heap allocations. pub struct BfsBuffer { @@ -61,7 +61,7 @@ impl QGameMechanics { /// Create initial game state #[allow(dead_code)] - pub fn create_initial_state(&self) -> u64 { + pub fn create_initial_state(&self) -> CompactState { let mut data = self.repr.create_data(); let board_size = self.repr.board_size(); @@ -90,7 +90,7 @@ impl QGameMechanics { /// Doesn't check that players can reach their goal still. pub fn is_wall_placement_free( &self, - data: u64, + data: CompactState, row: usize, col: usize, orientation: usize, @@ -149,21 +149,21 @@ impl QGameMechanics { /// Place a wall (no validation - use is_wall_placement_valid first) #[inline] - pub fn place_wall(&self, data: &mut u64, row: usize, col: usize, orientation: usize) { + pub fn place_wall(&self, data: &mut CompactState, row: usize, col: usize, orientation: usize) { self.repr.set_wall(data, row, col, orientation, true); } /// Remove a wall #[inline] #[allow(dead_code)] - pub fn remove_wall(&self, data: &mut u64, row: usize, col: usize, orientation: usize) { + pub fn remove_wall(&self, data: &mut CompactState, row: usize, col: usize, orientation: usize) { self.repr.set_wall(data, row, col, orientation, false); } /// Check if there's a wall blocking movement between two adjacent cells fn is_wall_between( &self, - data: u64, + data: CompactState, from_row: usize, from_col: usize, to_row: usize, @@ -264,7 +264,7 @@ impl QGameMechanics { /// Check if a player can reach their goal row using BFS. /// Uses a reusable BfsBuffer to avoid allocations. - fn can_reach_goal(&self, data: u64, player: usize, buf: &mut BfsBuffer) -> bool { + fn can_reach_goal(&self, data: CompactState, player: usize, buf: &mut BfsBuffer) -> bool { let board_size = self.repr.board_size(); let goal_row = self.goal_rows[player]; @@ -328,7 +328,7 @@ impl QGameMechanics { /// Temporarily mutates `data` in-place (places then removes the wall) to avoid cloning. pub fn is_wall_placement_valid( &self, - data: &mut u64, + data: &mut CompactState, row: usize, col: usize, orientation: usize, @@ -348,7 +348,7 @@ impl QGameMechanics { } /// Execute a move action - pub fn execute_move(&self, data: &mut u64, player: usize, dest_row: usize, dest_col: usize) { + pub fn execute_move(&self, data: &mut CompactState, player: usize, dest_row: usize, dest_col: usize) { self.repr .set_player_position(data, player, dest_row, dest_col); } @@ -356,7 +356,7 @@ impl QGameMechanics { /// Execute a wall placement action pub fn execute_wall_placement( &self, - data: &mut u64, + data: &mut CompactState, player: usize, row: usize, col: usize, @@ -372,7 +372,7 @@ impl QGameMechanics { } /// Switch to the next player - pub fn switch_player(&self, data: &mut u64) { + pub fn switch_player(&self, data: &mut CompactState) { let current = self.repr.get_current_player(*data); self.repr.set_current_player(data, 1 - current); @@ -382,19 +382,19 @@ impl QGameMechanics { } /// Check if a player has won - pub fn check_win(&self, data: u64, player: usize) -> bool { + pub fn check_win(&self, data: CompactState, player: usize) -> bool { let (row, _col) = self.repr.get_player_position(data, player); row == self.goal_rows[player] } /// Check if the game is a draw (max steps reached) #[allow(dead_code)] - pub fn is_draw(&self, data: u64) -> bool { + pub fn is_draw(&self, data: CompactState) -> bool { self.repr.get_completed_steps(data) >= self.repr.max_steps() } /// Get all valid wall placements for the current player - pub fn get_valid_wall_placements(&self, data: &mut u64) -> Vec<(usize, usize, usize)> { + pub fn get_valid_wall_placements(&self, data: &mut CompactState) -> Vec<(usize, usize, usize)> { let current_player = self.repr.get_current_player(*data); // Check if player has walls remaining @@ -422,7 +422,7 @@ impl QGameMechanics { } /// Get all valid moves for the current player - pub fn get_valid_moves(&self, data: u64) -> Vec<(usize, usize)> { + pub fn get_valid_moves(&self, data: CompactState) -> Vec<(usize, usize)> { let current_player = self.repr.get_current_player(data); let board_size = self.repr.board_size(); @@ -581,11 +581,11 @@ impl QGameMechanics { } /// Display the board state as text art - pub fn print(&self, data: u64) { + pub fn print(&self, data: CompactState) { println!("{}", self.display(data)); } - pub fn display(&self, data: u64) -> String { + pub fn display(&self, data: CompactState) -> String { self.repr.display(data) } @@ -593,7 +593,7 @@ impl QGameMechanics { /// /// `data` is modified during the call (wall validation places-then-removes /// walls in place) but restored before return; observers see no change. - pub fn get_action_mask(&self, data: &mut u64) -> Vec { + pub fn get_action_mask(&self, data: &mut CompactState) -> Vec { let bs = self.repr.board_size(); let board_size_i = bs as i32; let total = crate::actions::policy_size(board_size_i); @@ -613,8 +613,8 @@ impl QGameMechanics { mask } - /// Immutable wrapper around `get_action_mask` for callers that hold `u64` by value. - pub fn get_action_mask_immut(&self, data: u64) -> Vec { + /// Immutable wrapper around `get_action_mask` for callers that hold the state by value. + pub fn get_action_mask_immut(&self, data: CompactState) -> Vec { let mut d = data; self.get_action_mask(&mut d) } @@ -623,7 +623,7 @@ impl QGameMechanics { /// /// Decodes the action, executes the appropriate move/wall placement, then /// switches to the next player (which also increments `completed_steps`). - pub fn apply_action_index(&self, data: &mut u64, action_idx: usize) { + pub fn apply_action_index(&self, data: &mut CompactState, action_idx: usize) { let bs = self.repr.board_size() as i32; let action = crate::actions::action_index_to_action(bs, action_idx); let (r, c, t) = (action[0] as usize, action[1] as usize, action[2]); @@ -644,12 +644,12 @@ impl QGameMechanics { } /// Returns true if either player has reached their goal row. - pub fn is_game_over(&self, data: u64) -> bool { + pub fn is_game_over(&self, data: CompactState) -> bool { self.check_win(data, 0) || self.check_win(data, 1) } /// Returns `Some(player)` if a player has won, `None` otherwise. - pub fn winner(&self, data: u64) -> Option { + pub fn winner(&self, data: CompactState) -> Option { if self.check_win(data, 0) { Some(0) } else if self.check_win(data, 1) { @@ -838,7 +838,7 @@ mod tests { board_str: &str, ) -> ( QGameMechanics, - u64, + CompactState, Vec<(usize, usize)>, Vec<(usize, usize, usize)>, ) { diff --git a/deep_quoridor/rust/src/compact/q_minimax.rs b/deep_quoridor/rust/src/compact/q_minimax.rs index 72a966c0..bdff5d0c 100644 --- a/deep_quoridor/rust/src/compact/q_minimax.rs +++ b/deep_quoridor/rust/src/compact/q_minimax.rs @@ -4,7 +4,7 @@ use rand::seq::SliceRandom; use rayon::prelude::*; use std::sync::Arc; -use super::q_bit_repr::{WALL_HORIZONTAL, WALL_VERTICAL}; +use super::q_bit_repr::{CompactState, WALL_HORIZONTAL, WALL_VERTICAL}; use super::q_game_mechanics::{BfsBuffer, QGameMechanics}; pub const WINNING_REWARD: f32 = 1e6; @@ -21,7 +21,7 @@ pub struct TranspositionEntry { } /// Compute distance to goal for a player using QBitRepr state -fn distance_to_goal(mechanics: &QGameMechanics, data: u64, player: usize) -> i32 { +fn distance_to_goal(mechanics: &QGameMechanics, data: CompactState, player: usize) -> i32 { let (row, col) = mechanics.repr().get_player_position(data, player); let board_size = mechanics.repr().board_size(); let goal_row = mechanics.get_goal_row(player); @@ -96,7 +96,7 @@ fn distance_to_goal(mechanics: &QGameMechanics, data: u64, player: usize) -> i32 /// Compute heuristic for QBitRepr state fn compute_heuristic( mechanics: &QGameMechanics, - data: u64, + data: CompactState, agent_player: usize, heuristic: i32, ) -> f32 { @@ -130,7 +130,7 @@ fn compute_heuristic( /// rather than validating all possible wall placements upfront. fn sample_actions( mechanics: &QGameMechanics, - data: &mut u64, + data: &mut CompactState, branching_factor: usize, ) -> Vec<(usize, usize, usize)> { let mut rng = rand::thread_rng(); @@ -184,7 +184,7 @@ fn sample_actions( #[allow(clippy::too_many_arguments)] fn minimax( mechanics: &QGameMechanics, - data: &mut u64, + data: &mut CompactState, current_player: usize, agent_player: usize, search_depth: usize, @@ -194,7 +194,7 @@ fn minimax( heuristic: i32, mut alpha: f32, mut beta: f32, - transposition_table: Arc>, + transposition_table: Arc>, ) -> f32 { // Check transposition table for cached result if let Some(entry) = transposition_table.get(data) { @@ -312,7 +312,7 @@ fn minimax( /// Evaluate actions using QBitRepr-based minimax (parallelized) pub fn evaluate_actions( mechanics: &QGameMechanics, - data: u64, + data: CompactState, max_search_depth: usize, branching_factor: usize, discount_factor: f32, @@ -320,7 +320,7 @@ pub fn evaluate_actions( ) -> ( Vec<(usize, usize, usize)>, Vec, - DashMap, + DashMap, ) { let current_player = mechanics.repr().get_current_player(data); let agent_player = current_player; diff --git a/deep_quoridor/rust/src/game_runner.rs b/deep_quoridor/rust/src/game_runner.rs index ab073709..c4d9b3fd 100644 --- a/deep_quoridor/rust/src/game_runner.rs +++ b/deep_quoridor/rust/src/game_runner.rs @@ -11,6 +11,7 @@ use crate::actions::{ action_index_to_action, ACTION_MOVE, ACTION_WALL_HORIZONTAL, ACTION_WALL_VERTICAL, }; use crate::agents::ActionSelector; +use crate::compact::q_bit_repr::CompactState; use crate::compact::q_game_mechanics::QGameMechanics; use crate::game_state::GameState; use crate::grid_helpers::compact_state_to_resnet_input; @@ -38,7 +39,7 @@ fn format_action(_board_size: i32, row: i32, col: i32, action_type: i32) -> Stri } } -/// Build a transient `GameState` from a compact `u64` state. +/// Build a transient `GameState` from a compact state. /// /// Used only to feed `PlayGameObserver::on_state_snapshot`, which keeps its /// `&GameState` signature so the cross-language trace observer in @@ -47,7 +48,7 @@ fn format_action(_board_size: i32, row: i32, col: i32, action_type: i32) -> Stri /// observer. fn compact_to_game_state( mechanics: &QGameMechanics, - data: u64, + data: CompactState, board_size: i32, max_walls: i32, ) -> GameState { @@ -85,7 +86,7 @@ pub struct GameResult { pub replay_items: Vec, } -/// Play a complete game between two agents using the compact `u64` state. +/// Play a complete game between two agents using the compact state. /// /// `agent_p1` controls player 0 and `agent_p2` controls player 1. /// Player 0 moves first. Action selection runs in original orientation; any @@ -177,7 +178,7 @@ pub fn play_game( player: current_player, }); - // Apply action on the canonical u64 state + // Apply action on the canonical compact state mechanics.apply_action_index(&mut data, action_idx); if trace { @@ -244,7 +245,7 @@ mod tests { impl ActionSelector for FirstValidAgent { fn select_action( &mut self, - _data: u64, + _data: CompactState, _mechanics: &QGameMechanics, action_mask: &[bool], ) -> anyhow::Result<(usize, Vec)> { diff --git a/deep_quoridor/rust/src/grid_helpers.rs b/deep_quoridor/rust/src/grid_helpers.rs index b240e92b..96ad2242 100644 --- a/deep_quoridor/rust/src/grid_helpers.rs +++ b/deep_quoridor/rust/src/grid_helpers.rs @@ -1,6 +1,6 @@ #![allow(dead_code)] -use crate::compact::q_bit_repr::{WALL_HORIZONTAL, WALL_VERTICAL}; +use crate::compact::q_bit_repr::{CompactState, WALL_HORIZONTAL, WALL_VERTICAL}; use crate::compact::q_game_mechanics::QGameMechanics; use crate::game_state::GameState; use crate::grid::CELL_WALL; @@ -65,14 +65,14 @@ pub fn grid_game_state_to_resnet_input(state: &GameState) -> ndarray::Array4 ndarray::Array4 { let repr = mechanics.repr(); let bs = repr.board_size(); @@ -136,8 +136,7 @@ mod tests { #[test] fn test_compact_resnet_matches_gamestate_resnet_initial() { - // QBitRepr packs state into a u64; only smaller boards fit. - for bs in [3, 5] { + for bs in [3, 5, 9] { let state = GameState::new(bs, 3); let mech = QGameMechanics::new(bs as usize, 3, 200); let data = mech.create_initial_state(); diff --git a/deep_quoridor/rust/src/lib.rs b/deep_quoridor/rust/src/lib.rs index 83313011..24aa869a 100644 --- a/deep_quoridor/rust/src/lib.rs +++ b/deep_quoridor/rust/src/lib.rs @@ -5,6 +5,8 @@ use numpy::{ }; #[cfg(feature = "python")] use pyo3::prelude::*; +#[cfg(feature = "python")] +use pyo3::types::PyBytes; pub mod actions; pub mod compact; @@ -454,6 +456,20 @@ fn policy_db_lookup<'py>( } } +/// Decode a Python bytes-like state into a CompactState. +#[cfg(feature = "python")] +fn state_from_bytes(state: &[u8]) -> PyResult { + if state.len() != 24 { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "state must be 24 bytes, got {}", + state.len() + ))); + } + let mut buf = [0u8; 24]; + buf.copy_from_slice(state); + Ok(compact::q_bit_repr::CompactState::from_bytes(buf)) +} + /// Convert a compact state blob to full game state arrays. /// /// Returns (grid, player_positions, walls_remaining, old_style_walls, current_player, completed_steps). @@ -461,20 +477,21 @@ fn policy_db_lookup<'py>( #[pyfunction] fn compact_state_to_game_state<'py>( py: Python<'py>, - state: u64, + state: Vec, board_size: usize, max_walls: usize, max_steps: usize, -) -> ( +) -> PyResult<( Bound<'py, PyArray2>, Bound<'py, PyArray2>, Bound<'py, PyArray1>, Bound<'py, PyArray3>, i32, i32, -) { +)> { use compact::q_bit_repr::QBitRepr; + let state = state_from_bytes(&state)?; let repr = QBitRepr::new(board_size, max_walls, max_steps); let grid = repr.to_grid(state); @@ -498,30 +515,33 @@ fn compact_state_to_game_state<'py>( } } - ( + Ok(( PyArray2::from_owned_array_bound(py, grid), PyArray2::from_owned_array_bound(py, player_positions), PyArray1::from_owned_array_bound(py, walls_remaining), PyArray3::from_owned_array_bound(py, old_style_walls), current_player, completed_steps, - ) + )) } /// Return all child states reachable from a compact state. /// /// Returns a list of (row, col, action_type, child_state_bytes) tuples where -/// action_type is 0=vertical wall, 1=horizontal wall, 2=pawn move. +/// action_type is 0=vertical wall, 1=horizontal wall, 2=pawn move and +/// `child_state_bytes` is a 24-byte buffer. #[cfg(feature = "python")] #[pyfunction] fn get_compact_child_states( - state: u64, + py: Python<'_>, + state: Vec, board_size: usize, max_walls: usize, max_steps: usize, -) -> Vec<(usize, usize, usize, u64)> { +) -> PyResult)>> { use compact::q_game_mechanics::QGameMechanics; + let state = state_from_bytes(&state)?; let mechanics = QGameMechanics::new(board_size, max_walls, max_steps); let current_player = mechanics.repr().get_current_player(state); let mut data = state; @@ -534,7 +554,8 @@ fn get_compact_child_states( let mut child = data; mechanics.execute_move(&mut child, current_player, row, col); mechanics.switch_player(&mut child); - children.push((row, col, 2usize, child)); + let child_bytes = PyBytes::new_bound(py, &child.to_bytes()).unbind(); + children.push((row, col, 2usize, child_bytes)); } // Wall placements (action_type = 0 or 1) @@ -543,10 +564,11 @@ fn get_compact_child_states( let mut child = data; mechanics.execute_wall_placement(&mut child, current_player, row, col, orientation); mechanics.switch_player(&mut child); - children.push((row, col, orientation, child)); + let child_bytes = PyBytes::new_bound(py, &child.to_bytes()).unbind(); + children.push((row, col, orientation, child_bytes)); } - children + Ok(children) } /// Python wrapper around PolicyDb for database access from Python. @@ -587,18 +609,41 @@ impl PyPolicyDb { .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}"))) } - /// Fetch (state, value) tuples by rowid. State is a u64 packed integer. - fn fetch_states_by_rowid(&self, rowids: Vec) -> PyResult> { - self.db + /// Fetch (state, value) tuples by rowid. State is a 24-byte buffer. + fn fetch_states_by_rowid( + &self, + py: Python<'_>, + rowids: Vec, + ) -> PyResult, i32)>> { + let rows = self + .db .fetch_states_by_rowid(&rowids) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}"))) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?; + Ok(rows + .into_iter() + .map(|(s, v)| (PyBytes::new_bound(py, &s.to_bytes()).unbind(), v)) + .collect()) } - /// Look up (state, value) for the given states. - fn lookup_values_by_state(&self, states: Vec) -> PyResult> { - self.db - .lookup_values_by_state(&states) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}"))) + /// Look up (state, value) for the given states. Each input state is a + /// 24-byte buffer. + fn lookup_values_by_state( + &self, + py: Python<'_>, + states: Vec>, + ) -> PyResult, i32)>> { + let parsed: Vec = states + .iter() + .map(|s| state_from_bytes(s)) + .collect::>()?; + let rows = self + .db + .lookup_values_by_state(&parsed) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?; + Ok(rows + .into_iter() + .map(|(s, v)| (PyBytes::new_bound(py, &s.to_bytes()).unbind(), v)) + .collect()) } /// Look up action values for a compact state. @@ -606,7 +651,11 @@ impl PyPolicyDb { /// Returns None if no valid actions, otherwise returns /// (actions, values) where actions is a list of (row, col, action_type) /// and values are from the acting player's perspective. - fn lookup_action_values(&self, state: u64) -> PyResult, Vec)>> { + fn lookup_action_values( + &self, + state: Vec, + ) -> PyResult, Vec)>> { + let state = state_from_bytes(&state)?; self.db .lookup_action_values(state) .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}"))) @@ -617,14 +666,15 @@ impl PyPolicyDb { #[cfg(feature = "python")] #[pyfunction] fn compact_state_display( - state: u64, + state: Vec, board_size: usize, max_walls: usize, max_steps: usize, -) -> String { +) -> PyResult { use compact::q_bit_repr::QBitRepr; + let state = state_from_bytes(&state)?; let repr = QBitRepr::new(board_size, max_walls, max_steps); - repr.display(state) + Ok(repr.display(state)) } /// A Python module implemented in Rust. diff --git a/deep_quoridor/rust/src/rotation.rs b/deep_quoridor/rust/src/rotation.rs index 543688e7..64c0b159 100644 --- a/deep_quoridor/rust/src/rotation.rs +++ b/deep_quoridor/rust/src/rotation.rs @@ -10,7 +10,7 @@ use ndarray::{Array1, Array2, ArrayView1, ArrayView2}; use crate::actions::{action_index_to_action, action_to_index, policy_size, ACTION_MOVE}; -use crate::compact::q_bit_repr::{WALL_HORIZONTAL, WALL_VERTICAL}; +use crate::compact::q_bit_repr::{CompactState, WALL_HORIZONTAL, WALL_VERTICAL}; use crate::compact::q_game_mechanics::QGameMechanics; use crate::game_state::GameState; @@ -117,14 +117,14 @@ pub fn remap_mask(mask: &[bool], mapping: &[usize]) -> Vec { out } -/// Build a 180°-rotated u64 game state. +/// Build a 180°-rotated compact game state. /// /// Mirrors `build_rotated_state` semantics for the compact representation: /// - Player positions flipped: `(r,c) -> (bs-1-r, bs-1-c)` for both players. /// - All wall positions flipped: `(r,c,o) -> (ws-1-r, ws-1-c, o)` where `ws = bs-1`. /// Orientation is preserved under 180° rotation. /// - Walls remaining, current player, completed steps: unchanged. -pub fn rotate_compact_state(mechanics: &QGameMechanics, data: u64) -> u64 { +pub fn rotate_compact_state(mechanics: &QGameMechanics, data: CompactState) -> CompactState { let repr = mechanics.repr(); let bs = repr.board_size(); let ws = bs - 1; @@ -256,8 +256,7 @@ mod tests { #[test] fn test_rotate_compact_state_matches_build_rotated_state_initial() { - // QBitRepr packs state into a u64; only smaller boards fit. - for bs in [3, 5] { + for bs in [3, 5, 9] { let state = GameState::new(bs, 3); let rotated_state = build_rotated_state(&state); From 51d7198b81f8add436d99410ec0ea1845b75338f Mon Sep 17 00:00:00 2001 From: Jon Binney Date: Sun, 17 May 2026 19:27:08 -0400 Subject: [PATCH 6/6] Fix rust formatting --- .../rust/src/agents/alphazero/mcts.rs | 4 +++- deep_quoridor/rust/src/compact/policy_db.rs | 23 +++++++------------ deep_quoridor/rust/src/compact/q_bit_repr.rs | 23 +++++++++++++++---- .../rust/src/compact/q_game_mechanics.rs | 8 ++++++- deep_quoridor/rust/src/game_runner.rs | 7 ++---- 5 files changed, 38 insertions(+), 27 deletions(-) diff --git a/deep_quoridor/rust/src/agents/alphazero/mcts.rs b/deep_quoridor/rust/src/agents/alphazero/mcts.rs index aea711ef..cfee49dd 100644 --- a/deep_quoridor/rust/src/agents/alphazero/mcts.rs +++ b/deep_quoridor/rust/src/agents/alphazero/mcts.rs @@ -402,7 +402,9 @@ pub fn search( .iter() .map(|&child_idx| { let child = arena.get(child_idx); - let ai = child.action_index.expect("child node must have action_index"); + let ai = child + .action_index + .expect("child node must have action_index"); let action = action_index_to_action(bs, ai); ChildInfo { action, diff --git a/deep_quoridor/rust/src/compact/policy_db.rs b/deep_quoridor/rust/src/compact/policy_db.rs index 52d919de..dadbe75f 100644 --- a/deep_quoridor/rust/src/compact/policy_db.rs +++ b/deep_quoridor/rust/src/compact/policy_db.rs @@ -114,7 +114,11 @@ fn parse_meta(kv: Option<&Vec>, key: &str) -> Option { fn bytes_to_state(b: &[u8]) -> Result> { if b.len() != STATE_BYTES { - return Err(format!("state column row width is {}, expected {STATE_BYTES}", b.len()).into()); + return Err(format!( + "state column row width is {}, expected {STATE_BYTES}", + b.len() + ) + .into()); } let mut a = [0u8; STATE_BYTES]; a.copy_from_slice(b); @@ -168,16 +172,8 @@ impl PolicyDb { .ok_or_else(|| format!("row group {i} missing statistics on state column"))?; let (min, max) = match stats { Statistics::FixedLenByteArray(s) => ( - slice_to_arr( - s.min_opt() - .ok_or("missing min stat on state")? - .as_ref(), - )?, - slice_to_arr( - s.max_opt() - .ok_or("missing max stat on state")? - .as_ref(), - )?, + slice_to_arr(s.min_opt().ok_or("missing min stat on state")?.as_ref())?, + slice_to_arr(s.max_opt().ok_or("missing max stat on state")?.as_ref())?, ), _ => { return Err( @@ -647,10 +643,7 @@ fn append_batch( } /// Get all valid actions (moves + wall placements) for the current player. -fn get_all_actions( - mechanics: &QGameMechanics, - data: &mut CompactState, -) -> Vec<(u8, u8, u8)> { +fn get_all_actions(mechanics: &QGameMechanics, data: &mut CompactState) -> Vec<(u8, u8, u8)> { let moves = mechanics.get_valid_moves(*data); let mut actions: Vec<(u8, u8, u8)> = moves .into_iter() diff --git a/deep_quoridor/rust/src/compact/q_bit_repr.rs b/deep_quoridor/rust/src/compact/q_bit_repr.rs index f7515fc7..715e9232 100644 --- a/deep_quoridor/rust/src/compact/q_bit_repr.rs +++ b/deep_quoridor/rust/src/compact/q_bit_repr.rs @@ -182,7 +182,8 @@ impl QBitRepr { ) { debug_assert!(bit_offset + num_bits <= self.total_scalar_bits); let mask = (1u64 << num_bits) - 1; - state.scalars = (state.scalars & !(mask << bit_offset)) | ((value as u64 & mask) << bit_offset); + state.scalars = + (state.scalars & !(mask << bit_offset)) | ((value as u64 & mask) << bit_offset); } /// Read a single scalar bit. @@ -204,7 +205,13 @@ impl QBitRepr { } /// Check if a wall is present at the given position - pub fn get_wall(&self, state: CompactState, row: usize, col: usize, orientation: usize) -> bool { + pub fn get_wall( + &self, + state: CompactState, + row: usize, + col: usize, + orientation: usize, + ) -> bool { let wall_index = self.wall_position_to_index(row, col, orientation); self.get_wall_bit(state, wall_index) } @@ -225,7 +232,8 @@ impl QBitRepr { /// Get a player's position pub fn get_player_position(&self, state: CompactState, player: usize) -> (usize, usize) { debug_assert!(player < 2); - let index = self.get_scalar_bits(state, self.player_pos_offsets[player], self.position_bits); + let index = + self.get_scalar_bits(state, self.player_pos_offsets[player], self.position_bits); self.index_to_position(index) } @@ -287,7 +295,11 @@ impl QBitRepr { /// Get the number of completed steps pub fn get_completed_steps(&self, state: CompactState) -> usize { - self.get_scalar_bits(state, self.completed_steps_offset, self.completed_steps_bits) + self.get_scalar_bits( + state, + self.completed_steps_offset, + self.completed_steps_bits, + ) } /// Set the number of completed steps @@ -426,7 +438,8 @@ impl QBitRepr { let mut wall_line = String::new(); for col in 0..self.board_size { // Print horizontal wall below this cell - if col < self.board_size - 1 && self.get_wall(state, row, col, WALL_HORIZONTAL) { + if col < self.board_size - 1 && self.get_wall(state, row, col, WALL_HORIZONTAL) + { wall_line.push('-'); } else if col > 0 && self.get_wall(state, row, col - 1, WALL_HORIZONTAL) { wall_line.push('-'); diff --git a/deep_quoridor/rust/src/compact/q_game_mechanics.rs b/deep_quoridor/rust/src/compact/q_game_mechanics.rs index 4d19d603..ede177eb 100644 --- a/deep_quoridor/rust/src/compact/q_game_mechanics.rs +++ b/deep_quoridor/rust/src/compact/q_game_mechanics.rs @@ -348,7 +348,13 @@ impl QGameMechanics { } /// Execute a move action - pub fn execute_move(&self, data: &mut CompactState, player: usize, dest_row: usize, dest_col: usize) { + pub fn execute_move( + &self, + data: &mut CompactState, + player: usize, + dest_row: usize, + dest_col: usize, + ) { self.repr .set_player_position(data, player, dest_row, dest_col); } diff --git a/deep_quoridor/rust/src/game_runner.rs b/deep_quoridor/rust/src/game_runner.rs index c4d9b3fd..2c622cfc 100644 --- a/deep_quoridor/rust/src/game_runner.rs +++ b/deep_quoridor/rust/src/game_runner.rs @@ -103,11 +103,8 @@ pub fn play_game( trace: bool, mut observer: Option<&mut dyn PlayGameObserver>, ) -> anyhow::Result { - let mechanics = QGameMechanics::new( - board_size as usize, - max_walls as usize, - max_steps as usize, - ); + let mechanics = + QGameMechanics::new(board_size as usize, max_walls as usize, max_steps as usize); let mut data = mechanics.create_initial_state(); let (original_to_rotated, _) = create_rotation_mapping(board_size);