Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"source.organizeImports": "explicit"
},
"rust-analyzer.linkedProjects": [
"./deep_quoridor/rust/Cargo.toml"
"${workspaceFolder}/deep_quoridor/rust/Cargo.toml"
],
"python-envs.pythonProjects": [],
"chat.tools.terminal.outputLocation": "terminal",
Expand Down
4 changes: 3 additions & 1 deletion deep_quoridor/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
gymnasium
maturin
matplotlib
numba
numba<=0.61
numpy
onnx
onnxscript
open_spiel
pettingzoo
prettytable
Expand Down
29 changes: 15 additions & 14 deletions deep_quoridor/rust/src/agents/alphazero/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use anyhow::Result;
use rand::Rng;

use crate::agents::{ActionSelectionTrace, ActionSelector};
use crate::game_state::GameState;
use crate::compact::q_bit_repr::CompactState;
use crate::compact::q_game_mechanics::QGameMechanics;

use super::evaluator::OnnxEvaluator;
use super::mcts::{search, MCTSConfig};
Expand Down Expand Up @@ -118,7 +119,7 @@ pub fn apply_temperature_and_sample(
pub struct AlphaZeroAgent {
evaluator: OnnxEvaluator,
config: AlphaZeroAgentConfig,
visited_states: HashSet<u64>,
visited_states: HashSet<CompactState>,
last_selection_trace: Option<ActionSelectionTrace>,
}

Expand All @@ -144,7 +145,8 @@ impl AlphaZeroAgent {
impl ActionSelector for AlphaZeroAgent {
fn select_action(
&mut self,
state: &GameState,
data: CompactState,
mechanics: &QGameMechanics,
action_mask: &[bool],
) -> Result<(usize, Vec<f32>)> {
// Run MCTS search - only pass visited states when penalization is enabled
Expand All @@ -156,7 +158,8 @@ impl ActionSelector for AlphaZeroAgent {
};
let (children, root_value) = search(
&self.config.mcts,
state.clone(),
data,
mechanics,
&mut self.evaluator,
visited_ref,
)?;
Expand All @@ -166,8 +169,9 @@ impl ActionSelector for AlphaZeroAgent {
let action_indices: Vec<usize> = children.iter().map(|c| c.action_index).collect();

// Determine effective temperature
let completed_steps = mechanics.repr().get_completed_steps(data);
let temperature = if let Some(threshold) = self.config.drop_t_on_step {
if state.completed_steps >= threshold {
if completed_steps >= threshold {
0.0
} else {
self.config.temperature
Expand All @@ -193,16 +197,13 @@ impl ActionSelector for AlphaZeroAgent {
}
}

// Optionally add state to visited set
// Optionally add the resulting state to the visited set.
// Keyed on the compact state including completed_steps; a position seen
// at two different step counts is treated as two states.
if self.config.penalize_visited_states {
// Get the hash of the resulting state after taking the action
let action = children
.iter()
.find(|c| c.action_index == selected_idx)
.map(|c| c.action)
.unwrap_or([0, 0, 2]);
let next_state = state.clone_and_step(action);
self.visited_states.insert(next_state.get_fast_hash());
let mut next_data = data;
mechanics.apply_action_index(&mut next_data, selected_idx);
self.visited_states.insert(next_data);
}

self.last_selection_trace = Some(ActionSelectionTrace {
Expand Down
69 changes: 48 additions & 21 deletions deep_quoridor/rust/src/agents/alphazero/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,21 @@ use anyhow::{Context, Result};
use ort::session::Session;

use crate::agents::onnx_agent::softmax;
use crate::game_state::GameState;
use crate::grid_helpers::grid_game_state_to_resnet_input;
use crate::rotation::{build_rotated_state, create_rotation_mapping, remap_policy};
use crate::compact::q_bit_repr::CompactState;
use crate::compact::q_game_mechanics::QGameMechanics;
use crate::grid_helpers::compact_state_to_resnet_input;
use crate::rotation::{create_rotation_mapping, remap_mask, remap_policy, rotate_compact_state};

/// Trait for evaluating game positions.
///
/// Returns `(value_for_current_player, masked_softmax_priors)`.
pub trait Evaluator {
fn evaluate(&mut self, state: &GameState, action_mask: &[bool]) -> Result<(f32, Vec<f32>)>;
fn evaluate(
&mut self,
data: CompactState,
mechanics: &QGameMechanics,
action_mask: &[bool],
) -> Result<(f32, Vec<f32>)>;
}

/// ONNX-based evaluator for MCTS.
Expand All @@ -23,7 +29,7 @@ pub trait Evaluator {
/// returning both a value estimate and policy priors.
pub struct OnnxEvaluator {
session: Session,
rotated_to_original_by_board_size: HashMap<i32, Vec<usize>>,
rotation_mappings_by_board_size: HashMap<i32, (Vec<usize>, Vec<usize>)>,
}

/// Deterministic evaluator for cross-language consistency tests.
Expand All @@ -40,32 +46,48 @@ impl OnnxEvaluator {
.context("Failed to load ONNX model")?;
Ok(Self {
session,
rotated_to_original_by_board_size: HashMap::new(),
rotation_mappings_by_board_size: HashMap::new(),
})
}
}

impl Evaluator for OnnxEvaluator {
fn evaluate(&mut self, state: &GameState, action_mask: &[bool]) -> Result<(f32, Vec<f32>)> {
let rotated_to_original = self
.rotated_to_original_by_board_size
.entry(state.board_size)
.or_insert_with(|| create_rotation_mapping(state.board_size).1);
let (work_state, work_action_mask, rotated_to_original) = if state.current_player == 1 {
let rotated_state = build_rotated_state(state);
let mask = rotated_state.get_action_mask();
(rotated_state, mask, Some(rotated_to_original.as_slice()))
fn evaluate(
&mut self,
data: CompactState,
mechanics: &QGameMechanics,
action_mask: &[bool],
) -> Result<(f32, Vec<f32>)> {
let bs = mechanics.repr().board_size() as i32;
let current_player = mechanics.repr().get_current_player(data);

let mappings = self
.rotation_mappings_by_board_size
.entry(bs)
.or_insert_with(|| create_rotation_mapping(bs));
let (orig_to_rot, rot_to_orig) = (&mappings.0, &mappings.1);

// For player 1, the network always sees the board rotated 180° so the
// current player faces downward. The compact rotation matches the tensor
// of `build_rotated_state`, but `QGameMechanics::goal_rows` is owned by
// the mechanics and does not flip — so wall-mask validation on rotated
// data treats walls that block the rotated player's path as legal. Remap
// the (correct) original mask into rotated index space instead.
let (work_data, work_action_mask, rot_to_orig_slice) = if current_player == 1 {
let rotated_data = rotate_compact_state(mechanics, data);
let rotated_mask = remap_mask(action_mask, orig_to_rot);
(rotated_data, rotated_mask, Some(rot_to_orig.as_slice()))
} else {
(state.clone(), action_mask.to_vec(), None)
(data, action_mask.to_vec(), None)
};

// Build ResNet input tensor
let resnet_input = grid_game_state_to_resnet_input(&work_state);
let resnet_input = compact_state_to_resnet_input(mechanics, work_data);

// Convert to flat vec for ORT
let shape = resnet_input.shape().to_vec();
let data: Vec<f32> = resnet_input.iter().copied().collect();
let input_value = ort::value::Value::from_array((shape.as_slice(), data))
let input_data: Vec<f32> = resnet_input.iter().copied().collect();
let input_value = ort::value::Value::from_array((shape.as_slice(), input_data))
.context("Failed to create ONNX input value")?;

// Run inference
Expand All @@ -87,7 +109,7 @@ impl Evaluator for OnnxEvaluator {

// Apply masked softmax to get priors
let priors_work = masked_softmax(policy_logits.1, &work_action_mask);
let priors = if let Some(rot_to_orig) = rotated_to_original {
let priors = if let Some(rot_to_orig) = rot_to_orig_slice {
remap_policy(&priors_work, rot_to_orig)
} else {
priors_work
Expand All @@ -98,7 +120,12 @@ impl Evaluator for OnnxEvaluator {
}

impl Evaluator for UniformMockEvaluator {
fn evaluate(&mut self, _state: &GameState, action_mask: &[bool]) -> Result<(f32, Vec<f32>)> {
fn evaluate(
&mut self,
_data: CompactState,
_mechanics: &QGameMechanics,
action_mask: &[bool],
) -> Result<(f32, Vec<f32>)> {
let valid_count = action_mask.iter().filter(|&&valid| valid).count();
let mut priors = vec![0.0f32; action_mask.len()];
if valid_count > 0 {
Expand Down
Loading
Loading