diff --git a/Cargo.lock b/Cargo.lock index 1a1854d..7f604ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1329,6 +1329,7 @@ dependencies = [ "criterion", "nom", "num_cpus", + "numpy", "parquet", "pgn-reader", "pyo3", diff --git a/Cargo.toml b/Cargo.toml index 99462c4..7ded532 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,11 +18,12 @@ rayon = "1.11" num_cpus = "1.17" arrow-array = "57" pyo3-arrow = "0.15" +numpy = "0.27" +parquet = "57" +arrow = "57" [dev-dependencies] criterion = "0.8" -parquet = "57" -arrow = "57" [[bench]] harness = false diff --git a/README.md b/README.md index b66c9a6..d4c2261 100644 --- a/README.md +++ b/README.md @@ -5,26 +5,45 @@ This project adds Python bindings to [rust-pgn-reader](https://github.com/niklas ## Installing `pip install rust_pgn_reader_python_binding` +## API + +Three entry points are available: + +- `parse_game(pgn)` - Parse a single PGN string +- `parse_games(chunked_array)` - Parse games from a PyArrow ChunkedArray (multithreaded) +- `parse_games_from_strings(pgns)` - Parse a list of PGN strings (multithreaded) + +All return a `ParsedGames` container with flat NumPy arrays, supporting: +- Indexing (`result[i]`), slicing (`result[1:3]`), and iteration (`for game in result`) +- Per-game views (`PyGameView`) with zero-copy array slices +- Position-to-game and move-to-game mapping for ML workflows +- Optional comment storage (`store_comments=True`) +- Optional legal move storage (`store_legal_moves=True`) + ## Benchmarks -Below are some benchmarks on Lichess's 2013-07 chess games (293,459 games) on an 7800X3D. +Below are some benchmarks on Lichess's 2013-07 chess games (293,459 games) on a 7800X3D. | Parser | File format | Time | |----------------------------------------------------------------------------|-------------|--------| | [rust-pgn-reader](https://github.com/niklasf/rust-pgn-reader/tree/master) | PGN | 1s | -| rust_pgn_reader_python_binding | PGN | 4.7s | -| rust_pgn_reader_python_binding, parse_game (single_threaded) | parquet | 3.3s | -| rust_pgn_reader_python_binding, parse_games (multithreaded) | parquet | 0.5s | -| rust_pgn_reader_python_binding, parse_game_moves_arrow_chunked_array (multithreaded) | parquet | 0.35s | +| rust_pgn_reader_python_binding, parse_games (multithreaded) | parquet | 0.35s | +| rust_pgn_reader_python_binding, parse_games_from_strings (multithreaded) | parquet | 0.5s | +| rust_pgn_reader_python_binding, parse_game (single-threaded) | parquet | 3.3s | +| rust_pgn_reader_python_binding, parse_game (single-threaded) | PGN | 4.7s | | [chess-library](https://github.com/Disservin/chess-library) | PGN | 2s | | [python-chess](https://github.com/niklasf/python-chess) | PGN | 3+ min | To replicate, download `2013-07-train-00000-of-00001.parquet` and then run: -`python bench_parquet.py` (single-threaded parse_game) +`python src/bench_parse_games.py` (recommended — multithreaded parse_games via Arrow) + +`python src/bench_parse_games_from_strings.py` (multithreaded parse_games_from_strings) + +`python src/bench_parse_game.py` (single-threaded parse_game from parquet) -`python bench_parquet_parallel.py` (multithreaded parse_games) +`python src/bench_parse_game_pgn.py` (single-threaded parse_game from .pgn file) -`python bench_parquet_arrow.py` (multithreaded parse_game_moves_arrow_chunked_array) +`python src/bench_data_access.py 2013-07-train-00000-of-00001.parquet` (parsing + data access + memory) ## Building `maturin develop` @@ -34,12 +53,12 @@ To replicate, download `2013-07-train-00000-of-00001.parquet` and then run: For a more thorough tutorial, follow https://lukesalamone.github.io/posts/how-to-create-rust-python-bindings/ ## Profiling -`py-spy record -s -F -f speedscope --output profile.speedscope -- python ./src/bench_parquet.py` +`py-spy record -s -F -f speedscope --output profile.speedscope -- python ./src/bench_parse_games.py` Linux/WSL-only: -`py-spy record -s -F -n -f speedscope --output profile.speedscope -- python ./src/bench_parquet.py` +`py-spy record -s -F -n -f speedscope --output profile.speedscope -- python ./src/bench_parse_games.py` ## Testing `cargo test` -`python -m unittest src/test.py` \ No newline at end of file +`python -m unittest src/test.py` diff --git a/benches/parquet_bench.rs b/benches/parquet_bench.rs index fb69596..4af6b63 100644 --- a/benches/parquet_bench.rs +++ b/benches/parquet_bench.rs @@ -1,15 +1,28 @@ +//! Benchmark for PGN parsing API, designed to mirror the Python workflow. +//! +//! `cargo bench --bench parquet_bench` +//! `samply record --rate 10000 cargo bench --bench parquet_bench` + use arrow::array::{Array, StringArray}; use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; +use rayon::prelude::*; +use rayon::ThreadPoolBuilder; use std::fs::File; use std::path::Path; use std::time::Instant; -use rust_pgn_reader_python_binding::parse_multiple_games_native; +use rust_pgn_reader_python_binding::{parse_game_to_buffers, Buffers, ParseConfig}; + +const FILE_PATH: &str = "2013-07-train-00000-of-00001.parquet"; -pub fn bench_parquet() { - let file_path = "2013-07-train-00000-of-00001.parquet"; +/// Chunk multiplier for explicit chunking. +/// 1 = exactly num_threads chunks (minimal overhead) +/// Higher values provide better load balancing at cost of more buffers. +const CHUNK_MULTIPLIER: usize = 1; - // Open the Parquet file +/// Read parquet file and return the raw Arrow StringArrays. +/// This preserves Arrow's memory layout for zero-copy string access. +fn read_parquet_to_string_arrays(file_path: &str) -> Vec { let file = File::open(Path::new(file_path)).expect("Unable to open file"); let builder = ParquetRecordBatchReaderBuilder::try_new(file) .expect("Failed to create ParquetRecordBatchReaderBuilder"); @@ -17,47 +30,125 @@ pub fn bench_parquet() { .build() .expect("Failed to build ParquetRecordBatchReader"); - // Process record batches - let mut movetexts = Vec::new(); + let mut arrays = Vec::new(); while let Some(batch) = reader .next() .transpose() .expect("Error reading record batch") { - // Extract "movetext" column from the record batch if let Some(array) = batch .column_by_name("movetext") .and_then(|col| col.as_any().downcast_ref::()) { - for i in 0..array.len() { - if array.is_valid(i) { - movetexts.push(array.value(i).to_string()); - } - } + arrays.push(array.clone()); } else { panic!("movetext column not found or not a StringArray"); } } + arrays +} + +/// Extract &str slices from Arrow StringArrays (zero-copy). +fn extract_str_slices<'a>(arrays: &'a [StringArray]) -> Vec<&'a str> { + let total_len: usize = arrays.iter().map(|a| a.len()).sum(); + let mut slices = Vec::with_capacity(total_len); + + for array in arrays { + for i in 0..array.len() { + if array.is_valid(i) { + slices.push(array.value(i)); + } + } + } + slices +} - println!("Read {} rows.", movetexts.len()); - // Measure start time +/// Benchmark the parsing API workflow. +/// +/// 1. Read parquet to Arrow arrays +/// 2. Extract &str slices from StringArray +/// 3. Parse in parallel with explicit chunking (par_chunks) -> fixed number of Buffers +/// +/// No merge step - the chunked architecture keeps per-thread buffers as-is. +pub fn bench_parse_api() { + let config = ParseConfig { + store_comments: false, + store_legal_moves: false, + }; + + // Step 1: Read parquet to Arrow StringArrays + let arrays = read_parquet_to_string_arrays(FILE_PATH); + + // Step 2: Extract &str slices (zero-copy) + let pgn_slices = extract_str_slices(&arrays); + println!("Read {} games from parquet.", pgn_slices.len()); + + // Step 3: Build thread pool and compute capacity estimates + let num_threads = num_cpus::get(); + let n_games = pgn_slices.len(); + let moves_per_game = 70; + + // Calculate chunk size for explicit chunking + let num_chunks = num_threads * CHUNK_MULTIPLIER; + let chunk_size = (n_games + num_chunks - 1) / num_chunks; + let chunk_size = chunk_size.max(1); + let games_per_chunk = chunk_size; + + println!( + "Using {} threads, {} chunks, {} games/chunk", + num_threads, num_chunks, games_per_chunk + ); + + let thread_pool = ThreadPoolBuilder::new() + .num_threads(num_threads) + .build() + .expect("Failed to build Rayon thread pool"); + + // Step 4: Parse in parallel using par_chunks let start = Instant::now(); - let result = parse_multiple_games_native(&movetexts, None, false); + let chunk_results: Vec = thread_pool.install(|| { + pgn_slices + .par_chunks(chunk_size) + .map(|chunk| { + let mut buffers = Buffers::with_capacity(games_per_chunk, moves_per_game, &config); + for &pgn in chunk { + let _ = parse_game_to_buffers(pgn, &mut buffers, &config); + } + buffers + }) + .collect() + }); - let duration = start.elapsed(); - println!("Time taken: {:?}", duration); + let duration_parallel = start.elapsed(); + println!("Parallel parsing time: {:?}", duration_parallel); + println!( + "Created {} Buffers chunks (no merge needed)", + chunk_results.len() + ); - match result { - Ok(parsed) => println!("Parsed {} games.", parsed.len()), - Err(err) => eprintln!("Error parsing games: {}", err), - } + // Compute totals from chunks + let total_games: usize = chunk_results.iter().map(|b| b.num_games()).sum(); + let total_positions: usize = chunk_results.iter().map(|b| b.total_positions()).sum(); + + let duration_total = start.elapsed(); + println!("Total time (parsing, no merge): {:?}", duration_total); + println!( + "Parsed {} games, {} total positions.", + total_games, total_positions + ); - let duration2 = start.elapsed(); + // Measure cleanup time + let drop_start = Instant::now(); + drop(chunk_results); + let drop_duration = drop_start.elapsed(); - println!("Time after checks: {:?}", duration2); + let total_duration = start.elapsed(); + println!("Cleanup time (drop): {:?}", drop_duration); + println!("Total time (parsing + cleanup): {:?}", total_duration); } fn main() { - bench_parquet(); + println!("=== Parse API (Buffers) ===\n"); + bench_parse_api(); } diff --git a/rust_pgn_reader_python_binding.pyi b/rust_pgn_reader_python_binding.pyi index d4745eb..b45543b 100644 --- a/rust_pgn_reader_python_binding.pyi +++ b/rust_pgn_reader_python_binding.pyi @@ -1,54 +1,384 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Dict, Iterator, Union, overload import pyarrow +import numpy as np +import numpy.typing as npt +from numpy.typing import NDArray -class PyUciMove: - from_square: int - to_square: int - promotion: Optional[int] +class PyGameView: + """Zero-copy view into a single game's data within a ParsedGames result. - def __init__( - self, from_square: int, to_square: int, promotion: Optional[int] - ) -> None: ... + Board indexing note: Boards use square indexing (a1=0, h8=63). + To convert to rank/file: + rank = square // 8 + file = square % 8 + """ + + def __len__(self) -> int: + """Number of moves in this game.""" + ... + + @property + def num_positions(self) -> int: + """Number of positions recorded for this game.""" + ... + + # === Board state views === + + @property + def boards(self) -> NDArray[np.uint8]: + """Board positions, shape (num_positions, 8, 8).""" + ... + + @property + def initial_board(self) -> NDArray[np.uint8]: + """Initial board position, shape (8, 8).""" + ... + + @property + def final_board(self) -> NDArray[np.uint8]: + """Final board position, shape (8, 8).""" + ... + + @property + def castling(self) -> NDArray[np.bool_]: + """Castling rights [K,Q,k,q], shape (num_positions, 4).""" + ... + + @property + def en_passant(self) -> NDArray[np.int8]: + """En passant file (-1 if none), shape (num_positions,).""" + ... + + @property + def halfmove_clock(self) -> NDArray[np.uint8]: + """Halfmove clock, shape (num_positions,).""" + ... + + @property + def turn(self) -> NDArray[np.bool_]: + """Side to move (True=white), shape (num_positions,).""" + ... + + # === Move views === + + @property + def from_squares(self) -> NDArray[np.uint8]: + """From squares, shape (num_moves,).""" + ... + + @property + def to_squares(self) -> NDArray[np.uint8]: + """To squares, shape (num_moves,).""" + ... + + @property + def promotions(self) -> NDArray[np.int8]: + """Promotions (-1=none, 2=N, 3=B, 4=R, 5=Q), shape (num_moves,).""" + ... + + @property + def clocks(self) -> NDArray[np.float32]: + """Clock times in seconds (NaN if missing), shape (num_moves,).""" + ... + + @property + def evals(self) -> NDArray[np.float32]: + """Engine evals (NaN if missing), shape (num_moves,).""" + ... + + # === Per-game metadata === + + @property + def headers(self) -> Dict[str, str]: + """Raw PGN headers as dict.""" + ... + + @property + def outcome(self) -> Optional[str]: + """Game outcome from movetext: 'White', 'Black', 'Draw', 'Unknown', or None.""" + ... + + @property + def is_checkmate(self) -> bool: + """Final position is checkmate.""" + ... + + @property + def is_stalemate(self) -> bool: + """Final position is stalemate.""" + ... + + @property + def is_insufficient(self) -> Tuple[bool, bool]: + """Insufficient material (white, black).""" + ... + + @property + def legal_move_count(self) -> int: + """Legal move count in final position.""" + ... + + @property + def is_valid(self) -> bool: + """Whether game parsed successfully.""" + ... + + @property + def is_game_over(self) -> bool: + """Whether the game is over (checkmate, stalemate, or both sides insufficient).""" + ... + + @property + def comments(self) -> List[Optional[str]]: + """Raw text comments per move (only populated when store_comments=True).""" + ... + + @property + def legal_moves(self) -> List[List[Tuple[int, int, int]]]: + """Legal moves at each position (only populated when store_legal_moves=True). + Each entry is a list of (from_square, to_square, promotion) tuples.""" + ... + + # === Convenience methods === + + def move_uci(self, move_idx: int) -> str: + """Get UCI string for move at index.""" + ... + + def moves_uci(self) -> List[str]: + """Get all moves as UCI strings.""" + ... + + def __repr__(self) -> str: ... + +class ParsedGamesIter: + """Iterator over games in a ParsedGames result.""" + + def __iter__(self) -> "ParsedGamesIter": ... + def __next__(self) -> PyGameView: ... + +class PyChunkView: + """View into a single chunk's raw numpy arrays. + + Access via ``parsed_games.chunks[i]``. Each chunk corresponds to one + parsing thread's output. Use this for advanced access patterns like + manual concatenation or custom batching. + """ + + @property + def num_games(self) -> int: ... + @property + def num_moves(self) -> int: ... + @property + def num_positions(self) -> int: ... @property - def get_from_square_name(self) -> str: ... + def boards(self) -> NDArray[np.uint8]: + """Board positions, shape (N_positions, 8, 8), dtype uint8.""" + ... @property - def get_to_square_name(self) -> str: ... + def castling(self) -> NDArray[np.bool_]: + """Castling rights [K,Q,k,q], shape (N_positions, 4), dtype bool.""" + ... @property - def get_promotion_name(self) -> Optional[str]: ... - def __str__(self) -> str: ... + def en_passant(self) -> NDArray[np.int8]: ... + @property + def halfmove_clock(self) -> NDArray[np.uint8]: ... + @property + def turn(self) -> NDArray[np.bool_]: ... + @property + def from_squares(self) -> NDArray[np.uint8]: ... + @property + def to_squares(self) -> NDArray[np.uint8]: ... + @property + def promotions(self) -> NDArray[np.int8]: ... + @property + def clocks(self) -> NDArray[np.float32]: ... + @property + def evals(self) -> NDArray[np.float32]: ... + @property + def move_offsets(self) -> NDArray[np.uint32]: ... + @property + def position_offsets(self) -> NDArray[np.uint32]: ... + @property + def is_checkmate(self) -> NDArray[np.bool_]: ... + @property + def is_stalemate(self) -> NDArray[np.bool_]: ... + @property + def is_insufficient(self) -> NDArray[np.bool_]: ... + @property + def legal_move_count(self) -> NDArray[np.uint16]: ... + @property + def valid(self) -> NDArray[np.bool_]: ... + @property + def headers(self) -> List[Dict[str, str]]: ... + @property + def outcome(self) -> List[Optional[str]]: ... + @property + def comments(self) -> List[Optional[str]]: ... + @property + def legal_move_from_squares(self) -> NDArray[np.uint8]: ... + @property + def legal_move_to_squares(self) -> NDArray[np.uint8]: ... + @property + def legal_move_promotions(self) -> NDArray[np.int8]: ... + @property + def legal_move_offsets(self) -> NDArray[np.uint32]: ... def __repr__(self) -> str: ... -class PositionStatus: - is_checkmate: bool - is_stalemate: bool - legal_move_count: int - is_game_over: bool - insufficient_material: Tuple[bool, bool] - turn: bool - -class MoveExtractor: - moves: List[PyUciMove] - valid_moves: bool - comments: List[Optional[str]] - evals: List[Optional[float]] - clock_times: List[Optional[Tuple[int, int, float]]] - outcome: Optional[str] - headers: List[Tuple[str, str]] - castling_rights: List[Optional[Tuple[bool, bool, bool, bool]]] - position_status: Optional[PositionStatus] - - def __init__(self, store_legal_moves: bool = False) -> None: ... - def turn(self) -> bool: ... - def update_position_status(self) -> None: ... - @property - def legal_moves(self) -> List[List[PyUciMove]]: ... - -def parse_game(pgn: str, store_legal_moves: bool = False) -> MoveExtractor: ... +class ParsedGames: + """Chunked container for parsed chess games, optimized for ML training. + + Internally stores data in multiple chunks (one per parsing thread) to + avoid the cost of merging. Per-game access is O(log(num_chunks)) via + binary search on precomputed boundaries. + + Board layout: + Boards use square indexing: a1=0, b1=1, ..., h8=63 + Piece encoding: 0=empty, 1-6=white PNBRQK, 7-12=black pnbrqk + """ + + # === Computed properties === + + @property + def num_games(self) -> int: + """Number of games in the result.""" + ... + + @property + def num_moves(self) -> int: + """Total number of moves across all games.""" + ... + + @property + def num_positions(self) -> int: + """Total number of board positions recorded.""" + ... + + @property + def num_chunks(self) -> int: + """Number of internal chunks.""" + ... + + # === Escape hatch: raw chunk access === + + @property + def chunks(self) -> List[PyChunkView]: + """Access raw per-chunk data. + + Each chunk corresponds to one parsing thread's output. Use this + for advanced access patterns like manual concatenation. + """ + ... + + # === Sequence protocol === + + def __len__(self) -> int: + """Number of games in the result.""" + ... + + @overload + def __getitem__(self, idx: int) -> PyGameView: ... + @overload + def __getitem__(self, idx: slice) -> List[PyGameView]: ... + def __getitem__( + self, idx: Union[int, slice] + ) -> Union[PyGameView, List[PyGameView]]: + """Access game(s) by index or slice.""" + ... + + def __iter__(self) -> ParsedGamesIter: + """Iterate over all games.""" + ... + + # === Mapping utilities === + + def position_to_game(self, position_indices: npt.ArrayLike) -> NDArray[np.int64]: + """Map global position indices to game indices. + + Useful after shuffling/sampling positions to look up game metadata. + + Args: + position_indices: Array of indices into the global position space. + Accepts any integer dtype; int64 is optimal (avoids conversion). + + Returns: + Array of game indices (same shape as input) + """ + ... + + def move_to_game(self, move_indices: npt.ArrayLike) -> NDArray[np.int64]: + """Map global move indices to game indices. + + Args: + move_indices: Array of indices into the global move space. + Accepts any integer dtype; int64 is optimal (avoids conversion). + + Returns: + Array of game indices (same shape as input) + """ + ... + +def parse_game( + pgn: str, + store_comments: bool = False, + store_legal_moves: bool = False, +) -> ParsedGames: + """Parse a single PGN game string. + + Convenience wrapper for parsing a single game. Returns a ParsedGames + container with one game. + + Args: + pgn: PGN game string + store_comments: Whether to store raw text comments (default: False) + store_legal_moves: Whether to store legal moves at each position (default: False) + + Returns: + ParsedGames object containing the parsed game + """ + ... + def parse_games( - pgns: List[str], num_threads: Optional[int] = None, store_legal_moves: bool = False -) -> List[MoveExtractor]: ... -def parse_game_moves_arrow_chunked_array( pgn_chunked_array: pyarrow.ChunkedArray, num_threads: Optional[int] = None, + chunk_multiplier: Optional[int] = None, + store_comments: bool = False, store_legal_moves: bool = False, -) -> List[MoveExtractor]: ... +) -> ParsedGames: + """Parse chess games from a PyArrow ChunkedArray into flat NumPy arrays. + + This API is optimized for ML training pipelines, returning flat NumPy arrays + that can be efficiently batched and processed. + + Args: + pgn_chunked_array: PyArrow ChunkedArray containing PGN strings + num_threads: Number of threads for parallel parsing (default: all CPUs) + chunk_multiplier: Multiplier for number of chunks (default: 1) + store_comments: Whether to store raw text comments (default: False) + store_legal_moves: Whether to store legal moves at each position (default: False) + + Returns: + ParsedGames object containing flat arrays and iteration support + """ + ... + +def parse_games_from_strings( + pgns: List[str], + num_threads: Optional[int] = None, + store_comments: bool = False, + store_legal_moves: bool = False, +) -> ParsedGames: + """Parse multiple PGN game strings in parallel. + + Convenience wrapper for when you have a list of strings rather than an Arrow array. + + Args: + pgns: List of PGN game strings + num_threads: Number of threads for parallel parsing (default: all CPUs) + store_comments: Whether to store raw text comments (default: False) + store_legal_moves: Whether to store legal moves at each position (default: False) + + Returns: + ParsedGames object containing flat arrays and iteration support + """ + ... diff --git a/src/bench_data_access.py b/src/bench_data_access.py new file mode 100644 index 0000000..19b9f68 --- /dev/null +++ b/src/bench_data_access.py @@ -0,0 +1,95 @@ +""" +Benchmark for parse_games() — parsing speed and data access patterns. + +Usage: + python bench_data_access.py 2013-07-train-00000-of-00001.parquet +""" + +import sys +import time + +import numpy as np +import pyarrow.parquet as pq + +import rust_pgn_reader_python_binding as pgn + + +def main(): + if len(sys.argv) < 2: + print("Usage: python bench_data_access.py ") + return 1 + + file_path = sys.argv[1] + print(f"Loading from: {file_path}") + + pf = pq.ParquetFile(file_path) + chunked_array = pf.read(columns=["movetext"]).column("movetext") + print(f"Loaded {len(chunked_array):,} games") + + # Warmup + _ = pgn.parse_games(chunked_array) + + # Timed parse + start = time.perf_counter() + result = pgn.parse_games(chunked_array) + elapsed = time.perf_counter() - start + + print(f"\nParsing: {elapsed:.3f}s") + print( + f" {result.num_games:,} games, {result.num_moves:,} moves, {result.num_positions:,} positions" + ) + print(f" {result.num_games / elapsed:,.0f} games/sec") + print(f" {result.num_chunks} chunks") + + # Data access: chunk-level array access + start = time.perf_counter() + for chunk in result.chunks: + _ = chunk.boards.sum() + _ = chunk.from_squares.sum() + _ = chunk.to_squares.sum() + elapsed = time.perf_counter() - start + print(f"\nChunk array access: {elapsed:.3f}s") + + # Data access: per-game views + n_access = min(1000, result.num_games) + start = time.perf_counter() + for i in range(n_access): + _ = result[i].boards + elapsed = time.perf_counter() - start + print( + f"Per-game access ({n_access} games): {elapsed:.3f}s ({elapsed / n_access * 1000:.3f}ms/game)" + ) + + # Position-to-game mapping + indices = np.random.randint(0, result.num_positions, size=1000, dtype=np.int64) + start = time.perf_counter() + _ = result.position_to_game(indices) + elapsed = time.perf_counter() - start + print(f"Position-to-game (1000 lookups): {elapsed * 1000:.3f}ms") + + # Memory usage + total_bytes = 0 + for chunk in result.chunks: + total_bytes += ( + chunk.boards.nbytes + + chunk.castling.nbytes + + chunk.en_passant.nbytes + + chunk.halfmove_clock.nbytes + + chunk.turn.nbytes + + chunk.from_squares.nbytes + + chunk.to_squares.nbytes + + chunk.promotions.nbytes + + chunk.clocks.nbytes + + chunk.evals.nbytes + + chunk.move_offsets.nbytes + + chunk.position_offsets.nbytes + ) + print( + f"\nMemory: {total_bytes / 1024 / 1024:.1f} MB ({total_bytes / result.num_positions:.0f} bytes/position)" + ) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/bench_parquet.py b/src/bench_parse_game.py similarity index 71% rename from src/bench_parquet.py rename to src/bench_parse_game.py index f1b01d1..408ebba 100644 --- a/src/bench_parquet.py +++ b/src/bench_parse_game.py @@ -1,21 +1,19 @@ import rust_pgn_reader_python_binding import pyarrow.parquet as pq - from datetime import datetime file_path = "2013-07-train-00000-of-00001.parquet" - pf = pq.ParquetFile(file_path) pylist = pf.read(columns=["movetext"]).column("movetext").to_pylist() a = datetime.now() for row in pylist: - extractor = rust_pgn_reader_python_binding.parse_game(row) - moves = extractor.moves - comments = extractor.comments + result = rust_pgn_reader_python_binding.parse_game(row) + game = result[0] + moves = game.moves_uci() b = datetime.now() print(b - a) diff --git a/src/bench.py b/src/bench_parse_game_pgn.py similarity index 61% rename from src/bench.py rename to src/bench_parse_game_pgn.py index 28978b3..e5b91f7 100644 --- a/src/bench.py +++ b/src/bench_parse_game_pgn.py @@ -1,5 +1,4 @@ import rust_pgn_reader_python_binding - from datetime import datetime @@ -8,12 +7,12 @@ def split_pgn(file_path): with open(file_path, "r", encoding="utf-8") as file: game_lines = [] for line in file: - if line.strip() == "" and game_lines: # End of a game + if line.strip() == "" and game_lines: yield "".join(game_lines) - game_lines = [] # Reset for the next game + game_lines = [] else: game_lines.append(line) - if game_lines: # Yield the last game if the file doesn't end with a blank line + if game_lines: yield "".join(game_lines) @@ -21,9 +20,9 @@ def split_pgn(file_path): a = datetime.now() for game_pgn in split_pgn(file_path): - extractor = rust_pgn_reader_python_binding.parse_game(game_pgn) - moves = extractor.moves - comments = extractor.comments + result = rust_pgn_reader_python_binding.parse_game(game_pgn) + game = result[0] + moves = game.moves_uci() b = datetime.now() print(b - a) diff --git a/src/bench_parquet_arrow.py b/src/bench_parse_games.py similarity index 73% rename from src/bench_parquet_arrow.py rename to src/bench_parse_games.py index 59a8caf..dff4c94 100644 --- a/src/bench_parquet_arrow.py +++ b/src/bench_parse_games.py @@ -1,20 +1,16 @@ import rust_pgn_reader_python_binding import pyarrow.parquet as pq - from datetime import datetime -file_path = "2013-07-train-00000-of-00001.parquet" +file_path = "2013-07-train-00000-of-00001.parquet" pf = pq.ParquetFile(file_path) - movetext_arrow_array = pf.read(columns=["movetext"]).column("movetext") a = datetime.now() -extractors = rust_pgn_reader_python_binding.parse_game_moves_arrow_chunked_array( - movetext_arrow_array -) +result = rust_pgn_reader_python_binding.parse_games(movetext_arrow_array) b = datetime.now() print(b - a) diff --git a/src/bench_parquet_parallel.py b/src/bench_parse_games_from_strings.py similarity index 80% rename from src/bench_parquet_parallel.py rename to src/bench_parse_games_from_strings.py index 18e504a..f4b8949 100644 --- a/src/bench_parquet_parallel.py +++ b/src/bench_parse_games_from_strings.py @@ -1,18 +1,16 @@ import rust_pgn_reader_python_binding import pyarrow.parquet as pq - from datetime import datetime file_path = "2013-07-train-00000-of-00001.parquet" - pf = pq.ParquetFile(file_path) pylist = pf.read(columns=["movetext"]).column("movetext").to_pylist() a = datetime.now() -extractors = rust_pgn_reader_python_binding.parse_games(pylist) +result = rust_pgn_reader_python_binding.parse_games_from_strings(pylist) b = datetime.now() print(b - a) diff --git a/src/board_serialization.rs b/src/board_serialization.rs new file mode 100644 index 0000000..53146b6 --- /dev/null +++ b/src/board_serialization.rs @@ -0,0 +1,133 @@ +//! Board state serialization helpers for ParsedGames output. +//! +//! Board encoding: 0=empty, 1-6=white PNBRQK, 7-12=black pnbrqk +//! This matches python-chess piece_type (1-6) with color offset (+6 for black). +//! +//! Square indexing: a1=0, b1=1, ..., h1=7, a2=8, ..., h8=63 +//! Note: This differs from some Python code that uses [7 - rank, file] indexing. +//! The Python wrapper can transpose if needed. + +use shakmaty::{Chess, Color, EnPassantMode, Position, Role, Square}; + +/// Serialize board position to 64-byte array. +/// Index mapping: square index (a1=0, h8=63) -> piece value (0-12) +/// +/// Optimized to iterate by piece type using bitboards directly, +/// avoiding expensive piece_at() lookups for all 64 squares. +pub fn serialize_board(pos: &Chess) -> [u8; 64] { + let mut board = [0u8; 64]; + let b = pos.board(); + + for role in Role::ALL { + let piece_val = role as u8; // Role enum is 1-6 + for sq in b.by_role(role) & b.white() { + board[sq as usize] = piece_val; + } + for sq in b.by_role(role) & b.black() { + board[sq as usize] = piece_val + 6; + } + } + + board +} + +/// Get en passant file (0-7) or -1 if none. +/// Uses Always mode to report the e.p. square whenever a double pawn push occurred, +/// regardless of whether a legal capture is available. +pub fn get_en_passant_file(pos: &Chess) -> i8 { + pos.ep_square(EnPassantMode::Always) + .map(|sq| sq.file() as i8) + .unwrap_or(-1) +} + +/// Get halfmove clock (for 50-move rule). +pub fn get_halfmove_clock(pos: &Chess) -> u8 { + pos.halfmoves().min(255) as u8 +} + +/// Get side to move: true = white, false = black. +pub fn get_turn(pos: &Chess) -> bool { + pos.turn() == Color::White +} + +/// Get castling rights as [K, Q, k, q] (white kingside, white queenside, black kingside, black queenside). +pub fn get_castling_rights(pos: &Chess) -> [bool; 4] { + let rights = pos.castles().castling_rights(); + [ + rights.contains(Square::H1), // White kingside (rook on h1) + rights.contains(Square::A1), // White queenside (rook on a1) + rights.contains(Square::H8), // Black kingside (rook on h8) + rights.contains(Square::A8), // Black queenside (rook on a8) + ] +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_serialize_initial_board() { + let pos = Chess::default(); + let board = serialize_board(&pos); + + // White pieces on rank 1 (indices 0-7) + assert_eq!(board[0], 4); // a1 = white rook + assert_eq!(board[1], 2); // b1 = white knight + assert_eq!(board[2], 3); // c1 = white bishop + assert_eq!(board[3], 5); // d1 = white queen + assert_eq!(board[4], 6); // e1 = white king + assert_eq!(board[5], 3); // f1 = white bishop + assert_eq!(board[6], 2); // g1 = white knight + assert_eq!(board[7], 4); // h1 = white rook + + // White pawns on rank 2 (indices 8-15) + for i in 8..16 { + assert_eq!(board[i], 1); // white pawn + } + + // Empty squares (ranks 3-6, indices 16-47) + for i in 16..48 { + assert_eq!(board[i], 0); + } + + // Black pawns on rank 7 (indices 48-55) + for i in 48..56 { + assert_eq!(board[i], 7); // black pawn (1 + 6) + } + + // Black pieces on rank 8 (indices 56-63) + assert_eq!(board[56], 10); // a8 = black rook (4 + 6) + assert_eq!(board[57], 8); // b8 = black knight (2 + 6) + assert_eq!(board[58], 9); // c8 = black bishop (3 + 6) + assert_eq!(board[59], 11); // d8 = black queen (5 + 6) + assert_eq!(board[60], 12); // e8 = black king (6 + 6) + assert_eq!(board[61], 9); // f8 = black bishop (3 + 6) + assert_eq!(board[62], 8); // g8 = black knight (2 + 6) + assert_eq!(board[63], 10); // h8 = black rook (4 + 6) + } + + #[test] + fn test_initial_castling_rights() { + let pos = Chess::default(); + let rights = get_castling_rights(&pos); + assert_eq!(rights, [true, true, true, true]); // [K, Q, k, q] + } + + #[test] + fn test_initial_en_passant() { + let pos = Chess::default(); + assert_eq!(get_en_passant_file(&pos), -1); + } + + #[test] + fn test_initial_halfmove_clock() { + let pos = Chess::default(); + assert_eq!(get_halfmove_clock(&pos), 0); + } + + #[test] + fn test_initial_turn() { + let pos = Chess::default(); + assert!(get_turn(&pos)); // White to move + } +} diff --git a/src/example_chess960.py b/src/example_chess960.py index 5105556..c7c0290 100644 --- a/src/example_chess960.py +++ b/src/example_chess960.py @@ -14,15 +14,11 @@ 1.g3 d5 2.d4 g6 3.b3 Nf6 4.Ne3 b6 5.Nh3 Ne6 6.f4 Ng7 7.g4 h6 8.Nf2 Ne6 9.f5 Nf4 10.Nf1 gxf5 11.Qd2 Ne6 12.gxf5 Ng5 13.Ng3 Qd7 14.Rg1 Rg8 15.O-O-O O-O-O 16.Kb1 Kb8 17.Bb2 h5 18.Qf4 Bb7 19.h4 Nge4 20.Nfxe4 dxe4 21.Nxe4 Nxe4 22.Bxe4 Bxe4 23.Qxe4 e6 24.Rxg8 Rxg8 25.fxe6 fxe6 26.Rf1 Qe7 27.e3 Bf6 28.Ba3 Qd8 29.Qxe6 Bxh4 30.Rf5 Bg5 31.d5 Qc8 32.e4 h4 33.Qf7 Be3 34.d6 Rg1+ 35.Rf1 Rxf1+ 36.Qxf1 h3 37.dxc7+ Kxc7 38.Qc4+ Kb7 39.Qf7+ Qc7 40.Qd5+ Qc6 41.Qf7+ Qc7 42.Qd5+ Qc6 43.Qf7+ Qc7 1/2-1/2 {OL: 0} """ -extractor = rust_pgn_reader_python_binding.parse_game(pgn_moves) +result = rust_pgn_reader_python_binding.parse_game(pgn_moves) +game = result[0] -print("moves", extractor.moves) -print("comments", extractor.comments) -print("valid", extractor.valid_moves) -# print(extractor.evals) -print("clock", extractor.clock_times) -# print(extractor.outcome) -# print(extractor.position_status.is_checkmate) -# print(extractor.position_status.is_stalemate) -# print(extractor.position_status.is_game_over) -# print(extractor.position_status.legal_move_count) +print("Moves (UCI):", game.moves_uci()) +print("Valid:", game.is_valid) +print("Outcome:", game.outcome) +print("Num positions:", game.num_positions) +print("Clocks (first 5):", game.clocks[:5].tolist()) diff --git a/src/example_complex.py b/src/example_complex.py new file mode 100644 index 0000000..f56fb2a --- /dev/null +++ b/src/example_complex.py @@ -0,0 +1,41 @@ +import rust_pgn_reader_python_binding +import numpy as np + +pgn_moves = """ +[Event "Casual Correspondence game"] +[Site "https:///gdEj47Dv"] +[Date ""] +[White "lichess AI level 8"] +[Black "TherealARB"] +[Result "0-1"] +[UTCDate ""] +[UTCTime ":15"] +[WhiteElo "?"] +[BlackElo "1500"] +[Variant "Standard"] +[TimeControl "-"] +[ECO "C00"] +[Opening "Rat Defense: Small Center Defense"] +[Termination "Normal"] +[Annotator ""] + +1. e4 { [%eval ] } 1... e6 { [%eval ] } +2. d4 { [%eval ] } 2... d6?! { (0 → ) Inaccuracy. d5 was best. } { [%eval ] } { C00 Rat Defense: Small Center Defense } (2... d5 3. Nc3 Nf6 4. e5 Nfd7 5. f4 c5 6. Nf3 Nc6 7. Be3) +3. c4?! { ( → ) Inaccuracy. Bd3 was best. } { [%eval ] } (3. Bd3 c5 4. dxc5 dxc5 5. Nf3 Ne7 6. Qe2 Bd7 7. Nc3 Nec6) 3... h6? { ( → ) Mistake. d5 was best. } { [%eval ] } (3... d5 4. cxd5 exd5 5. exd5 Nf6 6. Nc3 Be7 7. Nf3 O-O 8. Bc4) +4. Nf3 { [%eval ] } 4... a6 { [%eval ] } +5. Nc3 { [%eval 1] } 5... g6 { [%eval ] } 6. Be3 { [%eval ] } 6... b6 { [%eval ] } 7. Bd3 { [%eval ] } 7... Bg7 { [%eval ] } 8. Qd2 { [%eval ] } 8... Bb7 { [%eval 7] } 9. O-O { [%eval ] } 9... Ne7 { [%eval ] } 10. d5 { [%eval ] } 10... e5 { [%eval ] } 11. g3 { [%eval ] } 11... Nd7 { [%eval 4] } 12. a3 { [%eval ] } 12... g5 { [%eval ] } 13. h4 { [%eval 3] } 13... f6 { [%eval ] } 14. Rfc1 { [%eval ] } 14... Ng6 { [%eval ] } 15. h5 { [%eval ] } 15... Nf4 { [%eval ] } 16. gxf4 { [%eval ] } 16... gxf4 { [%eval ] } 17. Bxf4?! { ( → ) Inaccuracy. Kh1 was best. } { [%eval ] } (17. Kh1 fxe3 18. fxe3 f5 19. Rg1 Bf6 20. exf5 Bg5 21. Ne4 Nf6 22. Qg2 c6 23. Nxf6+ Qxf6) 17... exf4 { [%eval 1] } 18. Qxf4 { [%eval ] } 18... Ne5 { [%eval ] } 19. Be2?! { (0 → ) Inaccuracy. Rd1 was best. } { [%eval ] } (19. Rd1 Nxf3+ 20. Qxf3 O-O 21. Kf1 Bc8 22. Re1 Qe7 23. Re3 Qf7 24. Rae1 Bd7 25. Ne2 f5) 19... Nxf3+ { [%eval 1] } 20. Bxf3 { [%eval 5] } 20... Bc8 { [%eval ] } 21. Re1 { [%eval ] } 21... O-O { [%eval ] } 22. Kh2 { [%eval ] } 22... Rf7 { [%eval ] } 23. Rg1 { [%eval ] } 23... f5 { [%eval ] } 24. Rg6 { [%eval -3] } 24... Kh8 { [%eval ] } 25. Rxh6+ { [%eval ] } 25... Bxh6 { [%eval ] } 26. Qxh6+ { [%eval ] } 26... Rh7 { [%eval ] } 27. Qf4 { [%eval ] } 27... Qf6 { [%eval ] } 28. Re1 { [%eval -7] } 28... Bd7 { [%eval -] } 29. e5 { [%eval 8] } 29... Qg7 { [%eval -5] } 30. exd6 { [%eval ] } 30... Rg8 { [%eval ] } 31. Qg3 { [%eval -4] } 31... Qf6 { [%eval -] } 32. Qf4 { [%eval ] } 32... cxd6 { [%eval ] } 33. Ne2 { [%eval ] } 33... Qe5 { [%eval -5] } 34. b4?! { (-5 → ) Inaccuracy. Rg1 was best. } { [%eval ] } (34. Rg1 Qxf4+ 35. Nxf4 Rxg1 36. Kxg1 Kg7 37. Kf1 b5 38. cxb5 Bxb5+ 39. Ke1 Kf6 40. Kd2 a5) 34... Be8 { [%eval ] } 35. Qxe5+ { [%eval ] } 35... dxe5 { [%eval ] } 36. Ng3 { [%eval ] } 36... f4 { [%eval ] } 37. Rxe5 { [%eval ] } 37... fxg3+ { [%eval ] } 38. fxg3 { [%eval -] } 38... Rhg7 { [%eval ] } 39. g4 { [%eval ] } 39... Bd7 { [%eval ] } 40. h6 { [%eval ] } 40... Rg6 { [%eval ] } 41. g5?! ... +66. Kc2 { [%eval #-25] } 66... b4 { [%eval #-6] } +67. Kd2 { [%eval #-9] } 67... b3 { [%eval #-8] } 68. Bd3 { [%eval #-7] } 68... b2 { [%eval #-7] } 69. Ke3 { [%eval #-6] } 69... Ba2 { [%eval #-11] } 70. Kf2 { [%eval #-8] } 70... Bb1 { [%eval #-6] } 71. Bc4 { [%eval #-6] } 71... Bg6 { [%eval #-4] } 72. Bd3 { [%eval #-4] } 72... Bxd3 { [%eval #-3] } 73. Kf3 { [%eval #-3] } 73... b1=Q { [%eval #-2] } 74. Ke3 { [%eval #-2] } 74... Qf1 { [%eval #-1] } 75. Kd2 { [%eval #-1] } 75... Qae1# { Black wins by checkmate. } 0-1 +""" + +result = rust_pgn_reader_python_binding.parse_game(pgn_moves, store_comments=True) +game = result[0] + +print("Moves (UCI):", game.moves_uci()) +print("Comments (first 5):", game.comments[:5]) +print("Valid:", game.is_valid) +print("Evals (first 10):", game.evals[:10].tolist()) +print("Clocks (first 10):", game.clocks[:10].tolist()) +print("Outcome:", game.outcome) +print("Is checkmate:", game.is_checkmate) +print("Is game over:", game.is_game_over) diff --git a/src/example_parse_games.py b/src/example_parse_games.py new file mode 100644 index 0000000..7981306 --- /dev/null +++ b/src/example_parse_games.py @@ -0,0 +1,169 @@ +""" +Example demonstrating the parse_games() API for ML-optimized PGN parsing. + +This API returns flat NumPy arrays suitable for efficient batch processing +in machine learning pipelines. +""" + +import numpy as np +import pyarrow as pa + +import rust_pgn_reader_python_binding as pgn + + +# Sample PGN games with annotations +sample_pgns = [ + # Game 1: Sicilian Defense with eval/clock annotations + """[Event "Online Blitz"] +[White "Player1"] +[Black "Player2"] +[Result "1-0"] +[WhiteElo "1850"] +[BlackElo "1790"] + +1. e4 { [%eval 0.17] [%clk 0:03:00] } 1... c5 { [%eval 0.19] [%clk 0:03:00] } +2. Nf3 { [%eval 0.25] [%clk 0:02:55] } 2... d6 { [%eval 0.30] [%clk 0:02:58] } +3. d4 { [%eval 0.35] [%clk 0:02:50] } 3... cxd4 { [%eval 0.32] [%clk 0:02:55] } +4. Nxd4 { [%eval 0.28] [%clk 0:02:48] } 4... Nf6 { [%eval 0.25] [%clk 0:02:52] } +5. Nc3 { [%eval 0.30] [%clk 0:02:45] } 1-0""", + # Game 2: Italian Game + """[Event "Club Championship"] +[White "Magnus"] +[Black "Hikaru"] +[Result "0-1"] + +1. e4 e5 2. Nf3 Nc6 3. Bc4 Bc5 4. c3 Nf6 5. d4 exd4 6. cxd4 Bb4+ 0-1""", + # Game 3: Scholar's Mate (checkmate) + """[Event "Beginner Game"] +[White "NewPlayer"] +[Black "Victim"] +[Result "1-0"] + +1. e4 e5 2. Qh5 Nc6 3. Bc4 Nf6 4. Qxf7# 1-0""", + # Game 4: Queen's Gambit + """[Event "Tournament"] +[White "Carlsen"] +[Black "Caruana"] +[Result "1/2-1/2"] + +1. d4 d5 2. c4 e6 3. Nc3 Nf6 4. Bg5 Be7 5. e3 O-O 1/2-1/2""", +] + + +def main(): + print("=" * 60) + print("parse_games() API Example") + print("=" * 60) + + # Create PyArrow chunked array from PGN strings + chunked_array = pa.chunked_array([pa.array(sample_pgns)]) + + # Parse all games at once - returns flat NumPy arrays + result = pgn.parse_games(chunked_array) + + # === Basic Statistics === + print(f"\n--- Basic Statistics ---") + print(f"Number of games: {result.num_games}") + print(f"Total moves: {result.num_moves}") + print(f"Total positions: {result.num_positions}") + print(f"Number of chunks: {result.num_chunks}") + + # === Chunk Details (escape hatch for raw array access) === + print(f"\n--- Chunk Details ---") + for i, chunk in enumerate(result.chunks): + print( + f" Chunk {i}: {chunk.num_games} games, " + f"{chunk.num_moves} moves, " + f"{chunk.num_positions} positions" + ) + print(f" boards: {chunk.boards.shape} ({chunk.boards.dtype})") + print( + f" from_squares: {chunk.from_squares.shape} ({chunk.from_squares.dtype})" + ) + + # === Iterate Over Games === + print(f"\n--- Game Details ---") + for i, game in enumerate(result): + print( + f"\nGame {i + 1}: {game.headers.get('White', '?')} vs {game.headers.get('Black', '?')}" + ) + print(f" Event: {game.headers.get('Event', 'N/A')}") + print(f" Moves: {len(game)}") + print(f" Positions: {game.num_positions}") + print(f" Valid: {game.is_valid}") + print(f" Checkmate: {game.is_checkmate}") + print(f" Outcome: {game.outcome}") + print( + f" UCI moves: {' '.join(game.moves_uci()[:10])}{'...' if len(game) > 10 else ''}" + ) + + # Show eval annotations if available + valid_evals = game.evals[~np.isnan(game.evals)] + if len(valid_evals) > 0: + print(f" Evals: {valid_evals[:5].tolist()}...") + + # === Direct Array Access for ML === + print(f"\n--- ML-Ready Data Access ---") + + # Access boards via chunks (no single merged array) + # To concatenate all boards: np.concatenate([c.boards for c in result.chunks]) + chunk0_boards = result.chunks[0].boards + print(f"Chunk 0 boards: {chunk0_boards.shape}") + + # Get initial position of first game + game0 = result[0] + initial_board = game0.initial_board + print(f"\nInitial board (Game 1):") + # Print board with piece symbols + piece_chars = " PNBRQKpnbrqk" + for rank in range(7, -1, -1): # Print from rank 8 to rank 1 + row = "" + for file in range(8): + sq_idx = rank * 8 + file + piece = initial_board.flat[sq_idx] + row += piece_chars[piece] + " " + print(f" {rank + 1} | {row}") + print(f" +----------------") + print(f" a b c d e f g h") + + # === Position-to-Game Mapping === + print(f"\n--- Position-to-Game Mapping ---") + # Useful for shuffling positions while keeping track of game metadata + sample_positions = np.array([0, 5, 10, 15, 20], dtype=np.int64) + game_indices = result.position_to_game(sample_positions) + print(f"Position indices: {sample_positions}") + print(f"Game indices: {game_indices}") + + # === Slicing === + print(f"\n--- Slicing ---") + # Get games 1-2 (0-indexed) + subset = result[1:3] + print(f"Slice [1:3]: {len(subset)} games") + for game in subset: + print( + f" - {game.headers.get('White', '?')} vs {game.headers.get('Black', '?')}" + ) + + # === Negative Indexing === + print(f"\n--- Negative Indexing ---") + last_game = result[-1] + print( + f"Last game: {last_game.headers.get('White', '?')} vs {last_game.headers.get('Black', '?')}" + ) + print(f" Moves: {last_game.moves_uci()}") + + # === Checkmate Detection === + print(f"\n--- Checkmate Detection ---") + for i, game in enumerate(result): + if game.is_checkmate: + print(f"Game {i + 1} ended in checkmate!") + print(f" Final position legal moves: {game.legal_move_count}") + print(f" Is game over: {game.is_game_over}") + + print("\n" + "=" * 60) + print("Example complete!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/src/example_rowgroup.py b/src/example_rowgroup.py index 9dddc4b..6acddaf 100644 --- a/src/example_rowgroup.py +++ b/src/example_rowgroup.py @@ -12,10 +12,11 @@ pf = pq.ParquetFile(file_path) for i in range(pf.num_row_groups): - table = pf.read_row_group(0, columns=["movetext"]) - extractors = rust_pgn_reader_python_binding.parse_games( - table.column("movetext").to_pylist(), num_threads=4 + table = pf.read_row_group(i, columns=["movetext"]) + result = rust_pgn_reader_python_binding.parse_games( + table.column("movetext"), num_threads=4 ) + print(f"Row group {i}: {result.num_games} games, {result.num_moves} moves") b = datetime.now() print(b - a) diff --git a/src/example_str.py b/src/example_str.py index 58f4ffd..a41bbd9 100644 --- a/src/example_str.py +++ b/src/example_str.py @@ -1,4 +1,5 @@ import rust_pgn_reader_python_binding +import numpy as np pgn_moves = """ 1. e4 { [%eval 0.17] [%clk 0:00:30] } 1... c5 { [%eval 0.19] [%clk 0:00:30] } @@ -16,16 +17,16 @@ 13. b3?? { [%eval -4.14] [%clk 0:00:02] } 13... Nf4? { [%eval -2.73] [%clk 0:00:21] } 0-1 """ -extractor = rust_pgn_reader_python_binding.parse_game(pgn_moves) +result = rust_pgn_reader_python_binding.parse_game(pgn_moves, store_comments=True) +game = result[0] -print(extractor.moves) -print(extractor.comments) -print(extractor.valid_moves) -print(extractor.evals) -print(extractor.clock_times) -print(extractor.outcome) -assert extractor.position_status is not None -print(extractor.position_status.is_checkmate) -print(extractor.position_status.is_stalemate) -print(extractor.position_status.is_game_over) -print(extractor.position_status.legal_move_count) +print("Moves (UCI):", game.moves_uci()) +print("Comments:", game.comments) +print("Valid:", game.is_valid) +print("Evals:", game.evals.tolist()) +print("Clocks:", game.clocks.tolist()) +print("Outcome:", game.outcome) +print("Is checkmate:", game.is_checkmate) +print("Is stalemate:", game.is_stalemate) +print("Is game over:", game.is_game_over) +print("Legal move count:", game.legal_move_count) diff --git a/src/lib.rs b/src/lib.rs index d38b337..2a6f5a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,529 +1,53 @@ -use crate::comment_parsing::{CommentContent, ParsedTag, parse_comments}; use arrow_array::{Array, LargeStringArray, StringArray}; -use pgn_reader::{KnownOutcome, Outcome, RawComment, RawTag, Reader, SanPlus, Skip, Visitor}; +use numpy::{PyArray1, PyArrayMethods}; use pyo3::prelude::*; use pyo3_arrow::PyChunkedArray; use rayon::ThreadPoolBuilder; use rayon::prelude::*; -use shakmaty::CastlingMode; -use shakmaty::Color; -use shakmaty::fen::Fen; -use shakmaty::{Chess, Position, Role, Square, uci::UciMove}; -use std::io::Cursor; -use std::ops::ControlFlow; +mod board_serialization; mod comment_parsing; - -// Definition of PyUciMove -#[pyclass(get_all, set_all, module = "rust_pgn_reader_python_binding")] -#[derive(Clone, Debug)] -pub struct PyUciMove { - pub from_square: u8, - pub to_square: u8, - pub promotion: Option, -} - -#[pymethods] -impl PyUciMove { - #[new] - fn new(from_square: u8, to_square: u8, promotion: Option) -> Self { - PyUciMove { - from_square, - to_square, - promotion, - } - } - - #[getter] - fn get_from_square_name(&self) -> String { - Square::new(self.from_square as u32).to_string() - } - - #[getter] - fn get_to_square_name(&self) -> String { - Square::new(self.to_square as u32).to_string() - } - - #[getter] - fn get_promotion_name(&self) -> Option { - self.promotion.and_then(|p_u8| { - Role::try_from(p_u8) - .map(|role| format!("{:?}", role)) // Get the debug representation (e.g., "Queen") - .ok() - }) - } - - // __str__ method for Python representation - fn __str__(&self) -> String { - let promo_str = self.promotion.map_or("".to_string(), |p_u8| { - Role::try_from(p_u8) - .map(|role| role.char().to_string()) - .unwrap_or_else(|_| "".to_string()) // Handle potential error if u8 is not a valid Role - }); - format!( - "{}{}{}", - Square::new(self.from_square as u32), - Square::new(self.to_square as u32), - promo_str - ) - } - - // __repr__ for a more developer-friendly representation - fn __repr__(&self) -> String { - let promo_repr = self.promotion.map_or("None".to_string(), |p_u8| { - Role::try_from(p_u8) - .map(|role| format!("Some('{}')", role.char())) - .unwrap_or_else(|_| format!("Some(InvalidRole({}))", p_u8)) - }); - format!( - "PyUciMove(from_square={}, to_square={}, promotion={})", - Square::new(self.from_square as u32), - Square::new(self.to_square as u32), - promo_repr - ) - } -} - -#[pyclass] -/// Holds the status of a chess position. -#[derive(Clone)] -pub struct PositionStatus { - #[pyo3(get)] - is_checkmate: bool, - - #[pyo3(get)] - is_stalemate: bool, - - #[pyo3(get)] - legal_move_count: usize, - - #[pyo3(get)] - is_game_over: bool, - - #[pyo3(get)] - insufficient_material: (bool, bool), - - #[pyo3(get)] - turn: bool, -} - -#[pyclass] -/// A Visitor to extract SAN moves and comments from PGN movetext -pub struct MoveExtractor { - #[pyo3(get)] - moves: Vec, - - store_legal_moves: bool, - flat_legal_moves: Vec, - legal_moves_offsets: Vec, - - #[pyo3(get)] - valid_moves: bool, - - #[pyo3(get)] - comments: Vec>, - - #[pyo3(get)] - evals: Vec>, - - #[pyo3(get)] - clock_times: Vec>, - - #[pyo3(get)] - outcome: Option, - - #[pyo3(get)] - headers: Vec<(String, String)>, - - #[pyo3(get)] - castling_rights: Vec>, - - #[pyo3(get)] - position_status: Option, - - pos: Chess, -} - -#[pymethods] -impl MoveExtractor { - #[new] - #[pyo3(signature = (store_legal_moves = false))] - fn new(store_legal_moves: bool) -> MoveExtractor { - MoveExtractor { - moves: Vec::with_capacity(100), - store_legal_moves, - flat_legal_moves: Vec::with_capacity(if store_legal_moves { 100 * 30 } else { 0 }), // Pre-allocate for moves - legal_moves_offsets: Vec::with_capacity(if store_legal_moves { 100 } else { 0 }), // Pre-allocate for offsets - pos: Chess::default(), - valid_moves: true, - comments: Vec::with_capacity(100), - evals: Vec::with_capacity(100), - clock_times: Vec::with_capacity(100), - outcome: None, - headers: Vec::with_capacity(10), - castling_rights: Vec::with_capacity(100), - position_status: None, - } - } - - fn turn(&self) -> bool { - match self.pos.turn() { - Color::White => true, - Color::Black => false, - } - } - - fn push_castling_bitboards(&mut self) { - let castling_bitboard = self.pos.castles().castling_rights(); - let castling_rights = ( - castling_bitboard.contains(shakmaty::Square::A1), - castling_bitboard.contains(shakmaty::Square::H1), - castling_bitboard.contains(shakmaty::Square::A8), - castling_bitboard.contains(shakmaty::Square::H8), - ); - - self.castling_rights.push(Some(castling_rights)); - } - - fn push_legal_moves(&mut self) { - // Record the starting offset for the current position's legal moves. - self.legal_moves_offsets.push(self.flat_legal_moves.len()); - - let legal_moves_for_pos = self.pos.legal_moves(); - self.flat_legal_moves.reserve(legal_moves_for_pos.len()); - - for m in legal_moves_for_pos { - let uci_move_obj = UciMove::from_standard(m); - if let UciMove::Normal { - from, - to, - promotion: promo_opt, - } = uci_move_obj - { - self.flat_legal_moves.push(PyUciMove { - from_square: from as u8, - to_square: to as u8, - promotion: promo_opt.map(|p_role| p_role as u8), - }); - } - } - } - - fn update_position_status(&mut self) { - // TODO this checks legal_moves() a bunch of times - self.position_status = Some(PositionStatus { - is_checkmate: self.pos.is_checkmate(), - is_stalemate: self.pos.is_stalemate(), - legal_move_count: self.pos.legal_moves().len(), - is_game_over: self.pos.is_game_over(), - insufficient_material: ( - self.pos.has_insufficient_material(Color::White), - self.pos.has_insufficient_material(Color::Black), - ), - turn: match self.pos.turn() { - Color::White => true, - Color::Black => false, - }, - }); - } - - #[getter] - fn legal_moves(&self) -> Vec> { - let mut result = Vec::with_capacity(self.legal_moves_offsets.len()); - if self.legal_moves_offsets.is_empty() { - return result; - } - - for i in 0..self.legal_moves_offsets.len() - 1 { - let start = self.legal_moves_offsets[i]; - let end = self.legal_moves_offsets[i + 1]; - result.push(self.flat_legal_moves[start..end].to_vec()); - } - - // Handle the last chunk - if let Some(&start) = self.legal_moves_offsets.last() { - result.push(self.flat_legal_moves[start..].to_vec()); - } - - result - } -} - -impl Visitor for MoveExtractor { - type Tags = Vec<(String, String)>; - type Movetext = (); - type Output = bool; - - fn begin_tags(&mut self) -> ControlFlow { - self.headers.clear(); - ControlFlow::Continue(Vec::with_capacity(10)) - } - - fn tag( - &mut self, - tags: &mut Self::Tags, - key: &[u8], - value: RawTag<'_>, - ) -> ControlFlow { - let key_str = String::from_utf8_lossy(key).into_owned(); - let value_str = String::from_utf8_lossy(value.as_bytes()).into_owned(); - tags.push((key_str, value_str)); - ControlFlow::Continue(()) - } - - fn begin_movetext(&mut self, tags: Self::Tags) -> ControlFlow { - self.headers = tags; - self.moves.clear(); - self.flat_legal_moves.clear(); - self.legal_moves_offsets.clear(); - self.valid_moves = true; - self.comments.clear(); - self.evals.clear(); - self.clock_times.clear(); - self.castling_rights.clear(); - - // Determine castling mode from Variant header (case-insensitive) - let castling_mode = self - .headers - .iter() - .find(|(k, _)| k.eq_ignore_ascii_case("Variant")) - .and_then(|(_, v)| { - let v_lower = v.to_lowercase(); - if v_lower == "chess960" { - Some(CastlingMode::Chess960) - } else { - None - } - }) - .unwrap_or(CastlingMode::Standard); - - // Try to parse FEN from headers, fall back to default position - let fen_header = self - .headers - .iter() - .find(|(k, _)| k.eq_ignore_ascii_case("FEN")) - .map(|(_, v)| v.as_str()); - - if let Some(fen_str) = fen_header { - match fen_str.parse::() { - Ok(fen) => match fen.into_position(castling_mode) { - Ok(pos) => self.pos = pos, - Err(e) => { - eprintln!("invalid FEN position: {}", e); - self.pos = Chess::default(); - self.valid_moves = false; - } - }, - Err(e) => { - eprintln!("failed to parse FEN: {}", e); - self.pos = Chess::default(); - self.valid_moves = false; - } - } - } else { - self.pos = Chess::default(); - } - - self.push_castling_bitboards(); - if self.store_legal_moves { - self.push_legal_moves(); - } - ControlFlow::Continue(()) - } - - // Roughly half the time during parsing is spent here in san() - fn san( - &mut self, - _movetext: &mut Self::Movetext, - san_plus: SanPlus, - ) -> ControlFlow { - if self.valid_moves { - // Most of the function time is spent calculating to_move() - match san_plus.san.to_move(&self.pos) { - Ok(m) => { - self.pos.play_unchecked(m); - if self.store_legal_moves { - self.push_legal_moves(); - } - let uci_move_obj = UciMove::from_standard(m); - - match uci_move_obj { - UciMove::Normal { - from, - to, - promotion: promo_opt, - } => { - let py_uci_move = PyUciMove { - from_square: from as u8, - to_square: to as u8, - promotion: promo_opt.map(|p_role| p_role as u8), - }; - self.moves.push(py_uci_move); - self.push_castling_bitboards(); - - // Push placeholders to keep vectors in sync - self.comments.push(None); - self.evals.push(None); - self.clock_times.push(None); - } - _ => { - // This case handles UciMove::Put and UciMove::Null, - // which are not expected from standard PGN moves - // that PyUciMove is designed to represent. - eprintln!( - "Unexpected UCI move type from standard PGN move: {:?}. Game moves might be invalid.", - uci_move_obj - ); - self.valid_moves = false; - } - } - } - Err(err) => { - eprintln!("error in game: {} {}", err, san_plus); - self.valid_moves = false; - } - } - } - ControlFlow::Continue(()) - } - - fn comment( - &mut self, - _movetext: &mut Self::Movetext, - _comment: RawComment<'_>, - ) -> ControlFlow { - match parse_comments(_comment.as_bytes()) { - Ok((remaining_input, parsed_comments)) => { - if !remaining_input.is_empty() { - eprintln!("Unparsed remaining input: {:?}", remaining_input); - return ControlFlow::Continue(()); - } - - let mut move_comments = String::new(); - - for content in parsed_comments { - match content { - CommentContent::Text(text) => { - if !text.trim().is_empty() { - if !move_comments.is_empty() { - move_comments.push(' '); - } - move_comments.push_str(&text); - } - } - CommentContent::Tag(tag_content) => match tag_content { - ParsedTag::Eval(eval_value) => { - if let Some(last_eval) = self.evals.last_mut() { - *last_eval = Some(eval_value); - } - } - ParsedTag::Mate(mate_value) => { - if !move_comments.is_empty() && !move_comments.ends_with(' ') { - move_comments.push(' '); - } - move_comments.push_str(&format!("[Mate {}]", mate_value)); - } - ParsedTag::ClkTime { - hours, - minutes, - seconds, - } => { - if let Some(last_clk) = self.clock_times.last_mut() { - *last_clk = Some((hours, minutes, seconds)); - } - } - }, - } - } - - if let Some(last_comment) = self.comments.last_mut() { - *last_comment = Some(move_comments); - } - } - Err(e) => { - eprintln!("Error parsing comment: {:?}", e); - } - } - ControlFlow::Continue(()) - } - - fn begin_variation( - &mut self, - _movetext: &mut Self::Movetext, - ) -> ControlFlow { - ControlFlow::Continue(Skip(true)) // stay in the mainline - } - - fn outcome( - &mut self, - _movetext: &mut Self::Movetext, - _outcome: Outcome, - ) -> ControlFlow { - self.outcome = Some(match _outcome { - Outcome::Known(known) => match known { - KnownOutcome::Decisive { winner } => format!("{:?}", winner), - KnownOutcome::Draw => "Draw".to_string(), - }, - Outcome::Unknown => "Unknown".to_string(), - }); - self.update_position_status(); - ControlFlow::Continue(()) - } - - fn end_game(&mut self, _movetext: Self::Movetext) -> Self::Output { - self.valid_moves - } -} - -// --- Native Rust versions (no PyResult) --- -pub fn parse_single_game_native( - pgn: &str, - store_legal_moves: bool, -) -> Result { - let mut reader = Reader::new(Cursor::new(pgn)); - let mut extractor = MoveExtractor::new(store_legal_moves); - match reader.read_game(&mut extractor) { - Ok(Some(_)) => Ok(extractor), - Ok(None) => Err("No game found in PGN".to_string()), - Err(err) => Err(format!("Parsing error: {}", err)), - } -} - -pub fn parse_multiple_games_native( - pgns: &Vec, - num_threads: Option, - store_legal_moves: bool, -) -> Result, String> { - let num_threads = num_threads.unwrap_or_else(num_cpus::get); - - // Build a custom Rayon thread pool with the desired number of threads - let thread_pool = ThreadPoolBuilder::new() - .num_threads(num_threads) - .build() - .expect("Failed to build Rayon thread pool"); - - thread_pool.install(|| { - pgns.par_iter() - .map(|pgn| parse_single_game_native(pgn, store_legal_moves)) - .collect() - }) -} - -fn _parse_game_moves_from_arrow_chunks_native( - pgn_chunked_array: &PyChunkedArray, +mod python_bindings; +mod visitor; + +use python_bindings::{ChunkData, ParsedGames, ParsedGamesIter, PyChunkView, PyGameView}; +pub use visitor::{Buffers, ParseConfig, parse_game_to_buffers}; + +/// Parse games from Arrow chunked array into a chunked ParsedGames container. +/// +/// This implementation uses explicit chunking with a fixed number of chunks +/// (num_chunks = num_threads * chunk_multiplier) to avoid the allocation storm +/// caused by Rayon's dynamic work-stealing with fold_with. +/// +/// Each chunk gets exactly one Buffers instance. Instead of merging all +/// chunks into a single buffer (which was memory-bandwidth-bound), we keep +/// the per-thread buffers and provide virtual indexing across them. +#[pyfunction] +#[pyo3(signature = (pgn_chunked_array, num_threads=None, chunk_multiplier=None, store_comments=false, store_legal_moves=false))] +fn parse_games( + py: Python<'_>, + pgn_chunked_array: PyChunkedArray, num_threads: Option, + chunk_multiplier: Option, + store_comments: bool, store_legal_moves: bool, -) -> Result, String> { +) -> PyResult { + let config = ParseConfig { + store_comments, + store_legal_moves, + }; let num_threads = num_threads.unwrap_or_else(num_cpus::get); - let thread_pool = ThreadPoolBuilder::new() - .num_threads(num_threads) - .build() - .map_err(|e| format!("Failed to build Rayon thread pool: {}", e))?; + // Default multiplier of 1 means exactly num_threads chunks (one per thread). + // Higher values (e.g., 4) create more chunks for better load balancing + // at the cost of slightly more complex indexing. + let chunk_multiplier = chunk_multiplier.unwrap_or(1); + // Extract PGN strings from Arrow chunks let mut num_elements = 0; for chunk in pgn_chunked_array.chunks() { num_elements += chunk.len(); } + let mut pgn_str_slices: Vec<&str> = Vec::with_capacity(num_elements); for chunk in pgn_chunked_array.chunks() { if let Some(string_array) = chunk.as_any().downcast_ref::() { @@ -539,53 +63,320 @@ fn _parse_game_moves_from_arrow_chunks_native( } } } else { - return Err(format!( + return Err(PyErr::new::(format!( "Unsupported array type in ChunkedArray: {:?}", chunk.data_type() - )); + ))); } } - thread_pool.install(|| { + let n_games = pgn_str_slices.len(); + if n_games == 0 { + let empty_chunk = buffers_to_chunk_data(py, Buffers::default())?; + return build_parsed_games(py, vec![empty_chunk]); + } + + // Build thread pool + let thread_pool = ThreadPoolBuilder::new() + .num_threads(num_threads) + .build() + .map_err(|e| { + PyErr::new::(format!( + "Failed to build thread pool: {}", + e + )) + })?; + + // Calculate chunk size for explicit chunking. + // num_chunks = num_threads * chunk_multiplier (e.g., 16 threads * 1 = 16 chunks) + let num_chunks = num_threads * chunk_multiplier; + let chunk_size = (n_games + num_chunks - 1) / num_chunks; // ceiling division + let chunk_size = chunk_size.max(1); // ensure at least 1 game per chunk + + // Estimate capacity per chunk + let games_per_chunk = chunk_size; + let moves_per_game = 70; + + // Parse in parallel using par_chunks for explicit, fixed-size chunking. + // This creates exactly ceil(n_games / chunk_size) Buffers instances, + // avoiding the allocation storm from Rayon's dynamic work-stealing. + let chunk_results: Vec = thread_pool.install(|| { pgn_str_slices - .par_iter() - .map(|&pgn_s| parse_single_game_native(pgn_s, store_legal_moves)) - .collect::, String>>() + .par_chunks(chunk_size) + .map(|chunk| { + let mut buffers = Buffers::with_capacity(games_per_chunk, moves_per_game, &config); + for &pgn in chunk { + let _ = parse_game_to_buffers(pgn, &mut buffers, &config); + } + buffers + }) + .collect() + }); + + // Convert each Buffers to ChunkData (numpy arrays) — no merge needed + let chunk_data_vec: Vec = chunk_results + .into_iter() + .map(|buf| buffers_to_chunk_data(py, buf)) + .collect::>>()?; + + build_parsed_games(py, chunk_data_vec) +} + +/// Convert a single Buffers into a ChunkData with NumPy arrays. +fn buffers_to_chunk_data(py: Python<'_>, buffers: Buffers) -> PyResult { + let n_games = buffers.num_games(); + let total_positions = buffers.total_positions(); + let total_moves = buffers.total_moves(); + + // Compute all CSR offsets BEFORE any from_vec calls consume buffer fields + let move_offsets_vec = buffers.compute_move_offsets(); + let position_offsets_vec = buffers.compute_position_offsets(); + let has_legal_moves = !buffers.legal_move_counts.is_empty(); + let legal_move_offsets_vec = if has_legal_moves { + buffers.compute_legal_move_offsets() + } else { + Vec::new() + }; + let total_legal_moves_count = buffers.total_legal_moves(); + + // Boards: reshape from flat to (N_positions, 8, 8) + let boards_array = PyArray1::from_vec(py, buffers.boards); + let boards_reshaped = boards_array + .reshape([total_positions, 8, 8]) + .map_err(|e| PyErr::new::(e.to_string()))?; + + // Castling: reshape from flat to (N_positions, 4) + let castling_array = PyArray1::from_vec(py, buffers.castling); + let castling_reshaped = castling_array + .reshape([total_positions, 4]) + .map_err(|e| PyErr::new::(e.to_string()))?; + + // 1D arrays + let en_passant_array = PyArray1::from_vec(py, buffers.en_passant); + let halfmove_clock_array = PyArray1::from_vec(py, buffers.halfmove_clock); + let turn_array = PyArray1::from_vec(py, buffers.turn); + + let from_squares_array = PyArray1::from_vec(py, buffers.from_squares); + let to_squares_array = PyArray1::from_vec(py, buffers.to_squares); + let promotions_array = PyArray1::from_vec(py, buffers.promotions); + let clocks_array = PyArray1::from_vec(py, buffers.clocks); + let evals_array = PyArray1::from_vec(py, buffers.evals); + + let move_offsets_array = PyArray1::from_vec(py, move_offsets_vec); + let position_offsets_array = PyArray1::from_vec(py, position_offsets_vec); + + let is_checkmate_array = PyArray1::from_vec(py, buffers.is_checkmate); + let is_stalemate_array = PyArray1::from_vec(py, buffers.is_stalemate); + + let is_insufficient_array = PyArray1::from_vec(py, buffers.is_insufficient); + let is_insufficient_reshaped = if n_games > 0 { + is_insufficient_array + .reshape([n_games, 2]) + .map_err(|e| PyErr::new::(e.to_string()))? + } else { + is_insufficient_array + .reshape([0, 2]) + .map_err(|e| PyErr::new::(e.to_string()))? + }; + + let legal_move_count_array = PyArray1::from_vec(py, buffers.legal_move_count); + let valid_array = PyArray1::from_vec(py, buffers.valid); + + let legal_move_from_squares_array = PyArray1::from_vec(py, buffers.legal_move_from_squares); + let legal_move_to_squares_array = PyArray1::from_vec(py, buffers.legal_move_to_squares); + let legal_move_promotions_array = PyArray1::from_vec(py, buffers.legal_move_promotions); + let legal_move_offsets_array = PyArray1::from_vec(py, legal_move_offsets_vec); + + Ok(ChunkData { + boards: boards_reshaped.unbind().into_any(), + castling: castling_reshaped.unbind().into_any(), + en_passant: en_passant_array.unbind().into_any(), + halfmove_clock: halfmove_clock_array.unbind().into_any(), + turn: turn_array.unbind().into_any(), + from_squares: from_squares_array.unbind().into_any(), + to_squares: to_squares_array.unbind().into_any(), + promotions: promotions_array.unbind().into_any(), + clocks: clocks_array.unbind().into_any(), + evals: evals_array.unbind().into_any(), + move_offsets: move_offsets_array.unbind().into_any(), + position_offsets: position_offsets_array.unbind().into_any(), + is_checkmate: is_checkmate_array.unbind().into_any(), + is_stalemate: is_stalemate_array.unbind().into_any(), + is_insufficient: is_insufficient_reshaped.unbind().into_any(), + legal_move_count: legal_move_count_array.unbind().into_any(), + valid: valid_array.unbind().into_any(), + headers: buffers.headers, + outcome: buffers.outcome, + comments: buffers.comments, + legal_move_from_squares: legal_move_from_squares_array.unbind().into_any(), + legal_move_to_squares: legal_move_to_squares_array.unbind().into_any(), + legal_move_promotions: legal_move_promotions_array.unbind().into_any(), + legal_move_offsets: legal_move_offsets_array.unbind().into_any(), + num_games: n_games, + num_moves: total_moves, + num_positions: total_positions, + num_legal_moves: total_legal_moves_count, }) } -// --- Python-facing wrappers (PyResult) --- -// TODO check if I can call py.allow_threads and release GIL -// see https://docs.rs/pyo3-arrow/0.10.1/pyo3_arrow/ -#[pyfunction] -#[pyo3(signature = (pgn, store_legal_moves = false))] -/// Parses a single PGN game string. -fn parse_game(pgn: &str, store_legal_moves: bool) -> PyResult { - parse_single_game_native(pgn, store_legal_moves) - .map_err(pyo3::exceptions::PyValueError::new_err) +/// Build a ParsedGames from a Vec of ChunkData. +/// +/// Computes prefix-sum boundary arrays and global CSR offsets for +/// position_to_game / move_to_game. +fn build_parsed_games(py: Python<'_>, chunks: Vec) -> PyResult { + let mut total_games: usize = 0; + let mut total_moves: usize = 0; + let mut total_positions: usize = 0; + + // Build boundary arrays (prefix sums) + let mut game_boundaries = Vec::with_capacity(chunks.len() + 1); + let mut move_boundaries = Vec::with_capacity(chunks.len() + 1); + let mut position_boundaries = Vec::with_capacity(chunks.len() + 1); + + game_boundaries.push(0); + move_boundaries.push(0); + position_boundaries.push(0); + + for chunk in &chunks { + total_games += chunk.num_games; + total_moves += chunk.num_moves; + total_positions += chunk.num_positions; + game_boundaries.push(total_games); + move_boundaries.push(total_moves); + position_boundaries.push(total_positions); + } + + // Build global CSR offsets for position_to_game / move_to_game. + // These are the per-chunk local offsets shifted by the chunk's base offset, + // concatenated into a single array. Length = total_games + 1. + let mut global_move_offsets_vec: Vec = Vec::with_capacity(total_games + 1); + let mut global_position_offsets_vec: Vec = Vec::with_capacity(total_games + 1); + + for (chunk_idx, chunk) in chunks.iter().enumerate() { + let move_base = move_boundaries[chunk_idx] as u32; + let pos_base = position_boundaries[chunk_idx] as u32; + + // Read the chunk's local offsets + let local_move_offsets = chunk.move_offsets.bind(py); + let local_move_offsets: &Bound<'_, PyArray1> = local_move_offsets.cast()?; + let local_move_ro = local_move_offsets.readonly(); + let local_move_slice = local_move_ro.as_slice()?; + + let local_pos_offsets = chunk.position_offsets.bind(py); + let local_pos_offsets: &Bound<'_, PyArray1> = local_pos_offsets.cast()?; + let local_pos_ro = local_pos_offsets.readonly(); + let local_pos_slice = local_pos_ro.as_slice()?; + + // Append all but the last offset (which is the total for this chunk) + // The last chunk's final offset will be added after the loop. + for &offset in &local_move_slice[..local_move_slice.len() - 1] { + global_move_offsets_vec.push(move_base + offset); + } + for &offset in &local_pos_slice[..local_pos_slice.len() - 1] { + global_position_offsets_vec.push(pos_base + offset); + } + } + + // Final sentinel value + global_move_offsets_vec.push(total_moves as u32); + global_position_offsets_vec.push(total_positions as u32); + + let global_move_offsets = PyArray1::from_vec(py, global_move_offsets_vec); + let global_position_offsets = PyArray1::from_vec(py, global_position_offsets_vec); + + Ok(ParsedGames { + chunks, + game_boundaries, + move_boundaries, + position_boundaries, + total_games, + total_moves, + total_positions, + global_move_offsets: global_move_offsets.unbind().into_any(), + global_position_offsets: global_position_offsets.unbind().into_any(), + }) } -/// In parallel, parse a set of games +/// Parse a single PGN game string. +/// +/// Convenience wrapper that creates a single-element Arrow array internally. #[pyfunction] -#[pyo3(signature = (pgns, num_threads=None, store_legal_moves=false))] -fn parse_games( - pgns: Vec, - num_threads: Option, +#[pyo3(signature = (pgn, store_comments=false, store_legal_moves=false))] +fn parse_game( + py: Python<'_>, + pgn: &str, + store_comments: bool, store_legal_moves: bool, -) -> PyResult> { - parse_multiple_games_native(&pgns, num_threads, store_legal_moves) - .map_err(pyo3::exceptions::PyValueError::new_err) +) -> PyResult { + let config = ParseConfig { + store_comments, + store_legal_moves, + }; + let mut buffers = Buffers::with_capacity(1, 70, &config); + parse_game_to_buffers(pgn, &mut buffers, &config) + .map_err(pyo3::exceptions::PyValueError::new_err)?; + let chunk = buffers_to_chunk_data(py, buffers)?; + build_parsed_games(py, vec![chunk]) } +/// Parse multiple PGN game strings in parallel. +/// +/// Convenience wrapper for when you have a list of strings rather than an Arrow array. #[pyfunction] -#[pyo3(signature = (pgn_chunked_array, num_threads=None, store_legal_moves=false))] -fn parse_game_moves_arrow_chunked_array( - pgn_chunked_array: PyChunkedArray, +#[pyo3(signature = (pgns, num_threads=None, store_comments=false, store_legal_moves=false))] +fn parse_games_from_strings( + py: Python<'_>, + pgns: Vec, num_threads: Option, + store_comments: bool, store_legal_moves: bool, -) -> PyResult> { - _parse_game_moves_from_arrow_chunks_native(&pgn_chunked_array, num_threads, store_legal_moves) - .map_err(pyo3::exceptions::PyValueError::new_err) +) -> PyResult { + let config = ParseConfig { + store_comments, + store_legal_moves, + }; + let num_threads = num_threads.unwrap_or_else(num_cpus::get); + + let n_games = pgns.len(); + if n_games == 0 { + let empty_chunk = buffers_to_chunk_data(py, Buffers::default())?; + return build_parsed_games(py, vec![empty_chunk]); + } + + let thread_pool = ThreadPoolBuilder::new() + .num_threads(num_threads) + .build() + .map_err(|e| { + PyErr::new::(format!( + "Failed to build thread pool: {}", + e + )) + })?; + + let num_chunks = num_threads; + let chunk_size = (n_games + num_chunks - 1) / num_chunks; + let chunk_size = chunk_size.max(1); + let games_per_chunk = chunk_size; + let moves_per_game = 70; + + let chunk_results: Vec = thread_pool.install(|| { + pgns.par_chunks(chunk_size) + .map(|chunk| { + let mut buffers = Buffers::with_capacity(games_per_chunk, moves_per_game, &config); + for pgn in chunk { + let _ = parse_game_to_buffers(pgn, &mut buffers, &config); + } + buffers + }) + .collect() + }); + + let chunk_data_vec: Vec = chunk_results + .into_iter() + .map(|buf| buffers_to_chunk_data(py, buf)) + .collect::>>()?; + + build_parsed_games(py, chunk_data_vec) } /// Parser for chess PGN notation @@ -593,184 +384,10 @@ fn parse_game_moves_arrow_chunked_array( fn rust_pgn_reader_python_binding(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(parse_game, m)?)?; m.add_function(wrap_pyfunction!(parse_games, m)?)?; - m.add_function(wrap_pyfunction!(parse_game_moves_arrow_chunked_array, m)?)?; - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; + m.add_function(wrap_pyfunction!(parse_games_from_strings, m)?)?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } - -#[cfg(test)] -mod pyucimove_tests { - use super::*; - use shakmaty::{Role, Square}; - - #[test] - fn test_py_uci_move_no_promotion() { - let uci_move = PyUciMove::new(Square::E2 as u8, Square::E4 as u8, None); - assert_eq!(uci_move.from_square, Square::E2 as u8); - assert_eq!(uci_move.to_square, Square::E4 as u8); - assert_eq!(uci_move.promotion, None); - assert_eq!(uci_move.get_from_square_name(), "e2"); - assert_eq!(uci_move.get_to_square_name(), "e4"); - assert_eq!(uci_move.get_promotion_name(), None); - assert_eq!(uci_move.__str__(), "e2e4"); - assert_eq!( - uci_move.__repr__(), - "PyUciMove(from_square=e2, to_square=e4, promotion=None)" - ); - } - - #[test] - fn test_py_uci_move_with_queen_promotion() { - let uci_move = PyUciMove::new(Square::E7 as u8, Square::E8 as u8, Some(Role::Queen as u8)); - assert_eq!(uci_move.from_square, Square::E7 as u8); - assert_eq!(uci_move.to_square, Square::E8 as u8); - assert_eq!(uci_move.promotion, Some(Role::Queen as u8)); - assert_eq!(uci_move.get_from_square_name(), "e7"); - assert_eq!(uci_move.get_to_square_name(), "e8"); - assert_eq!(uci_move.get_promotion_name(), Some("Queen".to_string())); - assert_eq!(uci_move.__str__(), "e7e8q"); - assert_eq!( - uci_move.__repr__(), - "PyUciMove(from_square=e7, to_square=e8, promotion=Some('q'))" - ); - } - - #[test] - fn test_py_uci_move_with_rook_promotion() { - let uci_move = PyUciMove::new(Square::A7 as u8, Square::A8 as u8, Some(Role::Rook as u8)); - assert_eq!(uci_move.from_square, Square::A7 as u8); - assert_eq!(uci_move.to_square, Square::A8 as u8); - assert_eq!(uci_move.promotion, Some(Role::Rook as u8)); - assert_eq!(uci_move.get_from_square_name(), "a7"); - assert_eq!(uci_move.get_to_square_name(), "a8"); - assert_eq!(uci_move.get_promotion_name(), Some("Rook".to_string())); - assert_eq!(uci_move.__str__(), "a7a8r"); - assert_eq!( - uci_move.__repr__(), - "PyUciMove(from_square=a7, to_square=a8, promotion=Some('r'))" - ); - } - - #[test] - fn test_py_uci_move_invalid_promotion_val() { - // Test with a u8 value that doesn't correspond to a valid Role - let uci_move = PyUciMove::new(Square::B7 as u8, Square::B8 as u8, Some(99)); // 99 is not a valid Role - assert_eq!(uci_move.from_square, Square::B7 as u8); - assert_eq!(uci_move.to_square, Square::B8 as u8); - assert_eq!(uci_move.promotion, Some(99)); - assert_eq!(uci_move.get_from_square_name(), "b7"); - assert_eq!(uci_move.get_to_square_name(), "b8"); - assert_eq!(uci_move.get_promotion_name(), None); // Should be None as 99 is invalid - assert_eq!(uci_move.__str__(), "b7b8"); // Should produce no promotion char - assert_eq!( - uci_move.__repr__(), - "PyUciMove(from_square=b7, to_square=b8, promotion=Some(InvalidRole(99)))" - ); - } - - #[test] - fn test_parse_game_without_headers() { - let pgn = "1. Nf3 d5 2. e4 c5 3. exd5 e5 4. dxe6 0-1"; - let result = parse_single_game_native(pgn, false); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert_eq!(extractor.moves.len(), 7); - assert_eq!(extractor.outcome, Some("Black".to_string())); - } - - #[test] - fn test_parse_game_with_standard_fen() { - // A game starting from a mid-game position - let pgn = r#"[FEN "r1bqkbnr/pppp1ppp/2n5/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3"] - -3. Bb5 a6 4. Ba4 Nf6 1-0"#; - let result = parse_single_game_native(pgn, false); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert!(extractor.valid_moves, "Moves should be valid"); - assert_eq!(extractor.moves.len(), 4); - } - - #[test] - fn test_parse_chess960_game() { - // Chess960 game with custom starting position - let pgn = r#"[Variant "chess960"] -[FEN "brkrqnnb/pppppppp/8/8/8/8/PPPPPPPP/BRKRQNNB w KQkq - 0 1"] - -1. g3 d5 2. d4 g6 3. b3 Nf6 1-0"#; - let result = parse_single_game_native(pgn, false); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert!( - extractor.valid_moves, - "Chess960 moves should be valid with proper FEN" - ); - assert_eq!(extractor.moves.len(), 6); - } - - #[test] - fn test_parse_chess960_variant_case_insensitive() { - // Test that variant detection is case-insensitive - let pgn = r#"[Variant "Chess960"] -[FEN "brkrqnnb/pppppppp/8/8/8/8/PPPPPPPP/BRKRQNNB w KQkq - 0 1"] - -1. g3 d5 1-0"#; - let result = parse_single_game_native(pgn, false); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert!( - extractor.valid_moves, - "Should handle Chess960 case variations" - ); - } - - #[test] - fn test_parse_invalid_fen_falls_back() { - // Invalid FEN should fall back to default and mark invalid - let pgn = r#"[FEN "invalid fen string"] - -1. e4 e5 1-0"#; - let result = parse_single_game_native(pgn, false); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert!( - !extractor.valid_moves, - "Should mark as invalid when FEN parsing fails" - ); - } - - #[test] - fn test_fen_header_case_insensitive() { - // FEN header key should be case-insensitive - let pgn = r#"[fen "r1bqkbnr/pppp1ppp/2n5/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3"] - -3. Bb5 1-0"#; - let result = parse_single_game_native(pgn, false); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert!( - extractor.valid_moves, - "Should handle lowercase 'fen' header" - ); - } - - #[test] - fn test_parse_game_with_custom_fen_no_variant() { - // A standard chess game starting from a mid-game position (no Variant header) - // Position after 1.e4 e5 2.Nf3 Nc6 3.Bb5 (Ruy Lopez) - let pgn = r#"[Event "Test Game"] - [FEN "r1bqkbnr/pppp1ppp/2n5/1B2p3/4P3/5N2/PPPP1PPP/RNBQK2R b KQkq - 3 3"] - - 3... a6 4. Ba4 Nf6 5. O-O Be7 1-0"#; - let result = parse_single_game_native(pgn, false); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert!( - extractor.valid_moves, - "Standard game with custom FEN should be valid" - ); - assert_eq!(extractor.moves.len(), 5); // a6, Ba4, Nf6, O-O, Be7 - } -} diff --git a/src/python_bindings.rs b/src/python_bindings.rs new file mode 100644 index 0000000..60d1d93 --- /dev/null +++ b/src/python_bindings.rs @@ -0,0 +1,973 @@ +use numpy::{PyArray1, PyArray2, PyArrayMethods}; +use pyo3::prelude::*; +use pyo3::types::{IntoPyDict, PyList, PySlice}; +use std::collections::HashMap; + +/// Internal per-chunk data. Not exposed to Python directly. +/// Each chunk corresponds to one thread's output during parallel parsing. +pub struct ChunkData { + // Per-position numpy arrays + pub boards: Py, // (N_positions, 8, 8) u8 + pub castling: Py, // (N_positions, 4) bool + pub en_passant: Py, // (N_positions,) i8 + pub halfmove_clock: Py, // (N_positions,) u8 + pub turn: Py, // (N_positions,) bool + + // Per-move numpy arrays + pub from_squares: Py, // (N_moves,) u8 + pub to_squares: Py, // (N_moves,) u8 + pub promotions: Py, // (N_moves,) i8 + pub clocks: Py, // (N_moves,) f32 + pub evals: Py, // (N_moves,) f32 + + // Per-game arrays (CSR offsets are local to this chunk) + pub move_offsets: Py, // (N_games + 1,) u32 + pub position_offsets: Py, // (N_games + 1,) u32 + pub is_checkmate: Py, // (N_games,) bool + pub is_stalemate: Py, // (N_games,) bool + pub is_insufficient: Py, // (N_games, 2) bool + pub legal_move_count: Py, // (N_games,) u16 + pub valid: Py, // (N_games,) bool + pub headers: Vec>, + pub outcome: Vec>, // Per-game: "White", "Black", "Draw", "Unknown", or None + + // Optional: raw text comments (per-move), only populated when store_comments=true + pub comments: Vec>, + + // Optional: legal moves at each position (CSR arrays + offsets) + pub legal_move_from_squares: Py, // (N_legal_moves,) u8 + pub legal_move_to_squares: Py, // (N_legal_moves,) u8 + pub legal_move_promotions: Py, // (N_legal_moves,) i8 + pub legal_move_offsets: Py, // (N_positions + 1,) u32 + + // Metadata + pub num_games: usize, + pub num_moves: usize, + pub num_positions: usize, + pub num_legal_moves: usize, +} + +/// Chunked container for parsed chess games, optimized for ML training. +/// +/// Internally stores data in multiple chunks (one per parsing thread) to +/// avoid the cost of merging. Per-game access is O(log(num_chunks)) via +/// binary search on precomputed boundaries. +#[pyclass] +pub struct ParsedGames { + pub chunks: Vec, + + // Global prefix sums for O(1) chunk lookup. + // game_boundaries[i] = total games in chunks 0..i + // So game_boundaries = [0, chunk0.num_games, chunk0+chunk1, ..., total_games] + pub game_boundaries: Vec, + pub move_boundaries: Vec, + pub position_boundaries: Vec, + + pub total_games: usize, + pub total_moves: usize, + pub total_positions: usize, + + // Global offsets for position_to_game / move_to_game (precomputed) + pub global_move_offsets: Py, + pub global_position_offsets: Py, +} + +impl ParsedGames { + /// Locate which chunk a global game index belongs to. + /// Returns (chunk_index, local_game_index). + fn locate_game(&self, global_idx: usize) -> (usize, usize) { + // Binary search: find the last boundary <= global_idx + // game_boundaries = [0, n0, n0+n1, ...], length = num_chunks + 1 + let chunk_idx = match self.game_boundaries.binary_search(&global_idx) { + Ok(i) => { + // Exact match. If it's the last boundary, back up one. + if i >= self.chunks.len() { + self.chunks.len() - 1 + } else { + i + } + } + Err(i) => i - 1, // insertion point - 1 + }; + let local_idx = global_idx - self.game_boundaries[chunk_idx]; + (chunk_idx, local_idx) + } +} + +#[pymethods] +impl ParsedGames { + /// Number of games in the result. + #[getter] + fn num_games(&self) -> usize { + self.total_games + } + + /// Total number of moves across all games. + #[getter] + fn num_moves(&self) -> usize { + self.total_moves + } + + /// Total number of board positions recorded. + #[getter] + fn num_positions(&self) -> usize { + self.total_positions + } + + /// Number of internal chunks. + #[getter] + fn num_chunks(&self) -> usize { + self.chunks.len() + } + + fn __len__(&self) -> usize { + self.total_games + } + + fn __getitem__(slf: Py, py: Python<'_>, idx: &Bound<'_, PyAny>) -> PyResult> { + let n_games = slf.borrow(py).total_games; + + // Handle integer index + if let Ok(mut i) = idx.extract::() { + if i < 0 { + i += n_games as isize; + } + if i < 0 || i >= n_games as isize { + return Err(pyo3::exceptions::PyIndexError::new_err(format!( + "Game index {} out of range [0, {})", + i, n_games + ))); + } + let game_view = PyGameView::new(py, slf.clone_ref(py), i as usize)?; + return Ok(Py::new(py, game_view)?.into_any()); + } + + // Handle slice — returns list of PyGameView + if let Ok(slice) = idx.cast::() { + let indices = slice.indices(n_games as isize)?; + let start = indices.start as usize; + let stop = indices.stop as usize; + let step = indices.step as usize; + + let mut views: Vec> = Vec::new(); + let mut i = start; + while i < stop { + let game_view = PyGameView::new(py, slf.clone_ref(py), i)?; + views.push(Py::new(py, game_view)?); + i += step; + } + return Ok(PyList::new(py, views)?.into_any().unbind()); + } + + Err(pyo3::exceptions::PyTypeError::new_err(format!( + "Invalid index type: expected int or slice, got {}", + idx.get_type().name()? + ))) + } + + fn __iter__(slf: Py, py: Python<'_>) -> PyResult { + let total = slf.borrow(py).total_games; + Ok(ParsedGamesIter { + data: slf, + index: 0, + length: total, + }) + } + + /// Escape hatch: access raw per-chunk data. + /// + /// Returns a list of chunk view objects, each exposing numpy arrays + /// for that chunk's data. Use this for advanced/custom access patterns. + #[getter] + fn chunks(slf: Py, py: Python<'_>) -> PyResult> { + let n_chunks = slf.borrow(py).chunks.len(); + let mut views: Vec> = Vec::with_capacity(n_chunks); + for i in 0..n_chunks { + views.push(Py::new( + py, + PyChunkView { + parent: slf.clone_ref(py), + chunk_idx: i, + }, + )?); + } + Ok(PyList::new(py, views)?.into_any().unbind()) + } + + /// Map position indices to game indices. + /// + /// Useful after shuffling/sampling positions to look up game metadata. + /// + /// Args: + /// position_indices: Array of indices into the global position space. + /// Accepts any integer dtype; int64 is optimal (avoids conversion). + /// + /// Returns: + /// Array of game indices (same shape as input) + fn position_to_game<'py>( + &self, + py: Python<'py>, + position_indices: &Bound<'py, PyAny>, + ) -> PyResult>> { + let offsets = self.global_position_offsets.bind(py); + let offsets: &Bound<'_, PyArray1> = offsets.cast()?; + + let numpy = py.import("numpy")?; + + let int64_dtype = numpy.getattr("int64")?; + let position_indices = numpy + .call_method1("asarray", (position_indices,))? + .call_method( + "astype", + (int64_dtype,), + Some(&[("copy", false)].into_py_dict(py)?), + )?; + + let len = offsets.len()?; + let slice_obj = PySlice::new(py, 0, (len - 1) as isize, 1); + let offsets_slice = offsets.call_method1("__getitem__", (slice_obj,))?; + + let result = numpy.call_method1( + "searchsorted", + ( + offsets_slice, + position_indices, + pyo3::types::PyString::new(py, "right"), + ), + )?; + + let one = 1i64.into_pyobject(py)?; + let result = result.call_method1("__sub__", (one,))?; + + Ok(result.extract()?) + } + + /// Map move indices to game indices. + /// + /// Args: + /// move_indices: Array of indices into the global move space. + /// Accepts any integer dtype; int64 is optimal (avoids conversion). + /// + /// Returns: + /// Array of game indices (same shape as input) + fn move_to_game<'py>( + &self, + py: Python<'py>, + move_indices: &Bound<'py, PyAny>, + ) -> PyResult>> { + let offsets = self.global_move_offsets.bind(py); + let offsets: &Bound<'_, PyArray1> = offsets.cast()?; + + let numpy = py.import("numpy")?; + + let int64_dtype = numpy.getattr("int64")?; + let move_indices = numpy + .call_method1("asarray", (move_indices,))? + .call_method( + "astype", + (int64_dtype,), + Some(&[("copy", false)].into_py_dict(py)?), + )?; + + let len = offsets.len()?; + let slice_obj = PySlice::new(py, 0, (len - 1) as isize, 1); + let offsets_slice = offsets.call_method1("__getitem__", (slice_obj,))?; + + let result = numpy.call_method1( + "searchsorted", + ( + offsets_slice, + move_indices, + pyo3::types::PyString::new(py, "right"), + ), + )?; + + let one = 1i64.into_pyobject(py)?; + let result = result.call_method1("__sub__", (one,))?; + + Ok(result.extract()?) + } +} + +/// Escape hatch: view into a single chunk's raw numpy arrays. +/// +/// Access via `parsed_games.chunks[i]`. Each chunk corresponds to one +/// parsing thread's output. Use this for advanced access patterns like +/// manual concatenation or custom batching. +#[pyclass] +pub struct PyChunkView { + parent: Py, + chunk_idx: usize, +} + +#[pymethods] +impl PyChunkView { + #[getter] + fn num_games(&self, py: Python<'_>) -> usize { + self.parent.borrow(py).chunks[self.chunk_idx].num_games + } + + #[getter] + fn num_moves(&self, py: Python<'_>) -> usize { + self.parent.borrow(py).chunks[self.chunk_idx].num_moves + } + + #[getter] + fn num_positions(&self, py: Python<'_>) -> usize { + self.parent.borrow(py).chunks[self.chunk_idx].num_positions + } + + #[getter] + fn boards(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .boards + .clone_ref(py) + } + + #[getter] + fn castling(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .castling + .clone_ref(py) + } + + #[getter] + fn en_passant(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .en_passant + .clone_ref(py) + } + + #[getter] + fn halfmove_clock(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .halfmove_clock + .clone_ref(py) + } + + #[getter] + fn turn(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .turn + .clone_ref(py) + } + + #[getter] + fn from_squares(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .from_squares + .clone_ref(py) + } + + #[getter] + fn to_squares(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .to_squares + .clone_ref(py) + } + + #[getter] + fn promotions(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .promotions + .clone_ref(py) + } + + #[getter] + fn clocks(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .clocks + .clone_ref(py) + } + + #[getter] + fn evals(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .evals + .clone_ref(py) + } + + #[getter] + fn move_offsets(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .move_offsets + .clone_ref(py) + } + + #[getter] + fn position_offsets(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .position_offsets + .clone_ref(py) + } + + #[getter] + fn is_checkmate(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .is_checkmate + .clone_ref(py) + } + + #[getter] + fn is_stalemate(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .is_stalemate + .clone_ref(py) + } + + #[getter] + fn is_insufficient(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .is_insufficient + .clone_ref(py) + } + + #[getter] + fn legal_move_count(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .legal_move_count + .clone_ref(py) + } + + #[getter] + fn valid(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .valid + .clone_ref(py) + } + + #[getter] + fn headers(&self, py: Python<'_>) -> Vec> { + self.parent.borrow(py).chunks[self.chunk_idx] + .headers + .clone() + } + + #[getter] + fn outcome(&self, py: Python<'_>) -> Vec> { + self.parent.borrow(py).chunks[self.chunk_idx] + .outcome + .clone() + } + + /// Raw text comments per move (only populated when store_comments=true). + #[getter] + fn comments(&self, py: Python<'_>) -> Vec> { + self.parent.borrow(py).chunks[self.chunk_idx] + .comments + .clone() + } + + /// Legal move from-squares for all positions in this chunk. + #[getter] + fn legal_move_from_squares(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .legal_move_from_squares + .clone_ref(py) + } + + /// Legal move to-squares for all positions in this chunk. + #[getter] + fn legal_move_to_squares(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .legal_move_to_squares + .clone_ref(py) + } + + /// Legal move promotions for all positions in this chunk. + #[getter] + fn legal_move_promotions(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .legal_move_promotions + .clone_ref(py) + } + + /// CSR offsets for legal moves (per-position). Length = num_positions + 1. + #[getter] + fn legal_move_offsets(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .legal_move_offsets + .clone_ref(py) + } + + fn __repr__(&self, py: Python<'_>) -> String { + let borrowed = self.parent.borrow(py); + let chunk = &borrowed.chunks[self.chunk_idx]; + format!( + "", + self.chunk_idx, chunk.num_games, chunk.num_moves, chunk.num_positions + ) + } +} + +/// Iterator over games in a ParsedGames result. +#[pyclass] +pub struct ParsedGamesIter { + data: Py, + index: usize, + length: usize, +} + +#[pymethods] +impl ParsedGamesIter { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(mut slf: PyRefMut<'_, Self>, py: Python<'_>) -> PyResult> { + if slf.index >= slf.length { + return Ok(None); + } + let game_view = PyGameView::new(py, slf.data.clone_ref(py), slf.index)?; + slf.index += 1; + Ok(Some(game_view)) + } +} + +/// Zero-copy view into a single game's data within a ParsedGames result. +/// +/// Board indexing note: Boards use square indexing (a1=0, h8=63). +/// To convert to rank/file: +/// rank = square // 8 +/// file = square % 8 +#[pyclass] +pub struct PyGameView { + data: Py, + /// Index of the chunk this game lives in. + chunk_idx: usize, + /// Local game index within the chunk. + local_idx: usize, + /// Move range within the chunk's move arrays. + move_start: usize, + move_end: usize, + /// Position range within the chunk's position arrays. + pos_start: usize, + pos_end: usize, +} + +impl PyGameView { + /// Create a new game view for global game index `global_idx`. + pub fn new(py: Python<'_>, data: Py, global_idx: usize) -> PyResult { + let borrowed = data.borrow(py); + + let (chunk_idx, local_idx) = borrowed.locate_game(global_idx); + let chunk = &borrowed.chunks[chunk_idx]; + + // Read this chunk's local CSR offsets + let move_offsets = chunk.move_offsets.bind(py); + let move_offsets: &Bound<'_, PyArray1> = move_offsets.cast()?; + let pos_offsets = chunk.position_offsets.bind(py); + let pos_offsets: &Bound<'_, PyArray1> = pos_offsets.cast()?; + + let move_offsets_ro = move_offsets.readonly(); + let move_offsets_slice = move_offsets_ro.as_slice()?; + let pos_offsets_ro = pos_offsets.readonly(); + let pos_offsets_slice = pos_offsets_ro.as_slice()?; + + let move_start = move_offsets_slice + .get(local_idx) + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + let move_end = move_offsets_slice + .get(local_idx + 1) + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + let pos_start = pos_offsets_slice + .get(local_idx) + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + let pos_end = pos_offsets_slice + .get(local_idx + 1) + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + + drop(borrowed); + + Ok(Self { + data, + chunk_idx, + local_idx, + move_start: *move_start as usize, + move_end: *move_end as usize, + pos_start: *pos_start as usize, + pos_end: *pos_end as usize, + }) + } +} + +#[pymethods] +impl PyGameView { + /// Number of moves in this game. + fn __len__(&self) -> usize { + self.move_end - self.move_start + } + + /// Number of positions recorded for this game. + #[getter] + fn num_positions(&self) -> usize { + self.pos_end - self.pos_start + } + + // === Board state views === + + /// Board positions, shape (num_positions, 8, 8). + #[getter] + fn boards<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let boards = borrowed.chunks[self.chunk_idx].boards.bind(py); + let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); + let slice = boards.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Initial board position, shape (8, 8). + #[getter] + fn initial_board<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let boards = borrowed.chunks[self.chunk_idx].boards.bind(py); + let slice = boards.call_method1("__getitem__", (self.pos_start,))?; + Ok(slice.unbind()) + } + + /// Final board position, shape (8, 8). + #[getter] + fn final_board<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let boards = borrowed.chunks[self.chunk_idx].boards.bind(py); + let slice = boards.call_method1("__getitem__", (self.pos_end - 1,))?; + Ok(slice.unbind()) + } + + /// Castling rights [K,Q,k,q], shape (num_positions, 4). + #[getter] + fn castling<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.chunks[self.chunk_idx].castling.bind(py); + let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// En passant file (-1 if none), shape (num_positions,). + #[getter] + fn en_passant<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.chunks[self.chunk_idx].en_passant.bind(py); + let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Halfmove clock, shape (num_positions,). + #[getter] + fn halfmove_clock<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.chunks[self.chunk_idx].halfmove_clock.bind(py); + let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Side to move (True=white), shape (num_positions,). + #[getter] + fn turn<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.chunks[self.chunk_idx].turn.bind(py); + let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + // === Move views === + + /// From squares, shape (num_moves,). + #[getter] + fn from_squares<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.chunks[self.chunk_idx].from_squares.bind(py); + let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// To squares, shape (num_moves,). + #[getter] + fn to_squares<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.chunks[self.chunk_idx].to_squares.bind(py); + let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Promotions (-1=none, 2=N, 3=B, 4=R, 5=Q), shape (num_moves,). + #[getter] + fn promotions<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.chunks[self.chunk_idx].promotions.bind(py); + let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Clock times in seconds (NaN if missing), shape (num_moves,). + #[getter] + fn clocks<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.chunks[self.chunk_idx].clocks.bind(py); + let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Engine evals (NaN if missing), shape (num_moves,). + #[getter] + fn evals<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.chunks[self.chunk_idx].evals.bind(py); + let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + // === Per-game metadata === + + /// Raw PGN headers as dict. + #[getter] + fn headers(&self, py: Python<'_>) -> PyResult> { + let borrowed = self.data.borrow(py); + Ok(borrowed.chunks[self.chunk_idx].headers[self.local_idx].clone()) + } + + /// Final position is checkmate. + #[getter] + fn is_checkmate(&self, py: Python<'_>) -> PyResult { + let borrowed = self.data.borrow(py); + let arr = borrowed.chunks[self.chunk_idx].is_checkmate.bind(py); + let arr: &Bound<'_, PyArray1> = arr.cast()?; + let readonly = arr.readonly(); + let slice = readonly.as_slice()?; + slice + .get(self.local_idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) + } + + /// Final position is stalemate. + #[getter] + fn is_stalemate(&self, py: Python<'_>) -> PyResult { + let borrowed = self.data.borrow(py); + let arr = borrowed.chunks[self.chunk_idx].is_stalemate.bind(py); + let arr: &Bound<'_, PyArray1> = arr.cast()?; + let readonly = arr.readonly(); + let slice = readonly.as_slice()?; + slice + .get(self.local_idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) + } + + /// Insufficient material (white, black). + #[getter] + fn is_insufficient(&self, py: Python<'_>) -> PyResult<(bool, bool)> { + let borrowed = self.data.borrow(py); + let arr = borrowed.chunks[self.chunk_idx].is_insufficient.bind(py); + let arr: &Bound<'_, PyArray2> = arr.cast()?; + let readonly = arr.readonly(); + let slice = readonly.as_slice()?; + let base = self.local_idx * 2; + let white = slice + .get(base) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + let black = slice + .get(base + 1) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + Ok((white, black)) + } + + /// Legal move count in final position. + #[getter] + fn legal_move_count(&self, py: Python<'_>) -> PyResult { + let borrowed = self.data.borrow(py); + let arr = borrowed.chunks[self.chunk_idx].legal_move_count.bind(py); + let arr: &Bound<'_, PyArray1> = arr.cast()?; + let readonly = arr.readonly(); + let slice = readonly.as_slice()?; + slice + .get(self.local_idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) + } + + /// Whether game parsed successfully. + #[getter] + fn is_valid(&self, py: Python<'_>) -> PyResult { + let borrowed = self.data.borrow(py); + let arr = borrowed.chunks[self.chunk_idx].valid.bind(py); + let arr: &Bound<'_, PyArray1> = arr.cast()?; + let readonly = arr.readonly(); + let slice = readonly.as_slice()?; + slice + .get(self.local_idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) + } + + /// Whether the game is over (checkmate, stalemate, or both sides have insufficient material). + #[getter] + fn is_game_over(&self, py: Python<'_>) -> PyResult { + let checkmate = self.is_checkmate(py)?; + let stalemate = self.is_stalemate(py)?; + let (insuf_white, insuf_black) = self.is_insufficient(py)?; + Ok(checkmate || stalemate || (insuf_white && insuf_black)) + } + + /// Game outcome from movetext: "White", "Black", "Draw", "Unknown", or None. + #[getter] + fn outcome(&self, py: Python<'_>) -> PyResult> { + let borrowed = self.data.borrow(py); + Ok(borrowed.chunks[self.chunk_idx].outcome[self.local_idx].clone()) + } + + /// Raw text comments per move (only populated when store_comments=true). + /// Returns list[str | None] of length num_moves. + #[getter] + fn comments(&self, py: Python<'_>) -> PyResult>> { + let borrowed = self.data.borrow(py); + let chunk_comments = &borrowed.chunks[self.chunk_idx].comments; + if chunk_comments.is_empty() { + return Ok(Vec::new()); + } + Ok(chunk_comments[self.move_start..self.move_end].to_vec()) + } + + /// Legal moves at each position in this game. + /// Returns list of lists: [[from, to, promotion], ...] per position. + /// Only populated when store_legal_moves=true. + #[getter] + fn legal_moves(&self, py: Python<'_>) -> PyResult>> { + let borrowed = self.data.borrow(py); + let chunk = &borrowed.chunks[self.chunk_idx]; + + // Check if legal moves were stored + if chunk.num_legal_moves == 0 { + return Ok(Vec::new()); + } + + let offsets_arr = chunk.legal_move_offsets.bind(py); + let offsets_arr: &Bound<'_, PyArray1> = offsets_arr.cast()?; + let offsets_ro = offsets_arr.readonly(); + let offsets_slice = offsets_ro.as_slice()?; + + let from_arr = chunk.legal_move_from_squares.bind(py); + let from_arr: &Bound<'_, PyArray1> = from_arr.cast()?; + let from_ro = from_arr.readonly(); + let from_slice = from_ro.as_slice()?; + + let to_arr = chunk.legal_move_to_squares.bind(py); + let to_arr: &Bound<'_, PyArray1> = to_arr.cast()?; + let to_ro = to_arr.readonly(); + let to_slice = to_ro.as_slice()?; + + let promo_arr = chunk.legal_move_promotions.bind(py); + let promo_arr: &Bound<'_, PyArray1> = promo_arr.cast()?; + let promo_ro = promo_arr.readonly(); + let promo_slice = promo_ro.as_slice()?; + + let mut result = Vec::with_capacity(self.pos_end - self.pos_start); + for pos_idx in self.pos_start..self.pos_end { + let start = offsets_slice[pos_idx] as usize; + let end = offsets_slice[pos_idx + 1] as usize; + let mut moves = Vec::with_capacity(end - start); + for i in start..end { + moves.push((from_slice[i], to_slice[i], promo_slice[i])); + } + result.push(moves); + } + Ok(result) + } + + // === Convenience methods === + + /// Get UCI string for move at index. + fn move_uci(&self, py: Python<'_>, move_idx: usize) -> PyResult { + if move_idx >= self.move_end - self.move_start { + return Err(pyo3::exceptions::PyIndexError::new_err(format!( + "Move index {} out of range [0, {})", + move_idx, + self.move_end - self.move_start + ))); + } + + let borrowed = self.data.borrow(py); + let chunk = &borrowed.chunks[self.chunk_idx]; + let from_arr = chunk.from_squares.bind(py); + let from_arr: &Bound<'_, PyArray1> = from_arr.cast()?; + let to_arr = chunk.to_squares.bind(py); + let to_arr: &Bound<'_, PyArray1> = to_arr.cast()?; + let promo_arr = chunk.promotions.bind(py); + let promo_arr: &Bound<'_, PyArray1> = promo_arr.cast()?; + + let from_ro = from_arr.readonly(); + let to_ro = to_arr.readonly(); + let promo_ro = promo_arr.readonly(); + + let from_slice = from_ro.as_slice()?; + let to_slice = to_ro.as_slice()?; + let promo_slice = promo_ro.as_slice()?; + + let abs_idx = self.move_start + move_idx; + let from_sq = from_slice + .get(abs_idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid move index"))?; + let to_sq = to_slice + .get(abs_idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid move index"))?; + let promo = promo_slice + .get(abs_idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid move index"))?; + + let files = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']; + let ranks = ['1', '2', '3', '4', '5', '6', '7', '8']; + + let mut uci = format!( + "{}{}{}{}", + files[(from_sq % 8) as usize], + ranks[(from_sq / 8) as usize], + files[(to_sq % 8) as usize], + ranks[(to_sq / 8) as usize] + ); + + if promo >= 0 { + let promo_chars = ['_', '_', 'n', 'b', 'r', 'q']; // 2=N, 3=B, 4=R, 5=Q + if (promo as usize) < promo_chars.len() { + uci.push(promo_chars[promo as usize]); + } + } + + Ok(uci) + } + + /// Get all moves as UCI strings. + fn moves_uci(&self, py: Python<'_>) -> PyResult> { + let n_moves = self.move_end - self.move_start; + let mut result = Vec::with_capacity(n_moves); + for i in 0..n_moves { + result.push(self.move_uci(py, i)?); + } + Ok(result) + } + + fn __repr__(&self, py: Python<'_>) -> PyResult { + let headers = self.headers(py)?; + let white = headers.get("White").map(|s| s.as_str()).unwrap_or("?"); + let black = headers.get("Black").map(|s| s.as_str()).unwrap_or("?"); + let n_moves = self.move_end - self.move_start; + let is_valid = self.is_valid(py)?; + Ok(format!( + "", + white, black, n_moves, is_valid + )) + } +} diff --git a/src/test.py b/src/test.py index a4decf2..59de284 100644 --- a/src/test.py +++ b/src/test.py @@ -1,16 +1,404 @@ import unittest +import numpy as np + import rust_pgn_reader_python_binding +from rust_pgn_reader_python_binding import PyGameView # for a typing check + import pyarrow as pa -class TestPgnExtraction(unittest.TestCase): - def run_extractor(self, pgn_string): - extractor = rust_pgn_reader_python_binding.parse_game(pgn_string) - return extractor +class TestParsedGames(unittest.TestCase): + def test_basic_structure(self): + """Test basic parsing returns correct structure.""" + pgns = [ + "1. e4 e5 2. Nf3 Nc6 3. Bb5 1-0", + "1. d4 d5 2. c4 e6 0-1", + ] + chunked = pa.chunked_array([pa.array(pgns)]) + # Use 1 thread to get a single chunk for predictable array shapes + result = rust_pgn_reader_python_binding.parse_games(chunked, num_threads=1) + + # Check game count + self.assertEqual(len(result), 2) + self.assertEqual(result.num_games, 2) + self.assertEqual(result.num_moves, 9) + self.assertEqual(result.num_positions, 11) # 9 moves + 2 initial positions + + # Check per-game structure via game views + game0 = result[0] + self.assertEqual(len(game0), 5) # Game 1: 5 half-moves + self.assertEqual(game0.num_positions, 6) + + game1 = result[1] + self.assertEqual(len(game1), 4) # Game 2: 4 half-moves + self.assertEqual(game1.num_positions, 5) + + # With 1 thread, all games are in a single chunk + self.assertEqual(result.num_chunks, 1) + chunk = result.chunks[0] + total_moves = 9 + total_positions = 9 + 2 + self.assertEqual(chunk.boards.shape, (total_positions, 8, 8)) + self.assertEqual(chunk.castling.shape, (total_positions, 4)) + self.assertEqual(chunk.en_passant.shape, (total_positions,)) + self.assertEqual(chunk.from_squares.shape, (total_moves,)) + self.assertEqual(chunk.valid.shape, (2,)) + + def test_initial_board_encoding(self): + """Test initial board state encoding.""" + pgns = ["1. e4 1-0"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + initial = result[0].initial_board # First position + + # Encoding: 0=empty, 1=P, 2=N, 3=B, 4=R, 5=Q, 6=K, +6 for black + # Square indexing: a1=0, b1=1, ..., h1=7, a2=8, ... + + # a1 (index 0) = white rook = 4 + self.assertEqual(initial.flat[0], 4) + # b1 (index 1) = white knight = 2 + self.assertEqual(initial.flat[1], 2) + # e1 (index 4) = white king = 6 + self.assertEqual(initial.flat[4], 6) + # e2 (index 12) = white pawn = 1 + self.assertEqual(initial.flat[12], 1) + # e4 (index 28) = empty = 0 + self.assertEqual(initial.flat[28], 0) + # e7 (index 52) = black pawn = 7 + self.assertEqual(initial.flat[52], 7) + # e8 (index 60) = black king = 12 + self.assertEqual(initial.flat[60], 12) + + def test_board_after_move(self): + """Test board state updates correctly after move.""" + pgns = ["1. e4 1-0"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + # Position 0: initial, Position 1: after e4 + after_e4 = result[0].boards[1] + + # e2 (index 12) should be empty + self.assertEqual(after_e4.flat[12], 0) + # e4 (index 28) should have white pawn + self.assertEqual(after_e4.flat[28], 1) + + def test_en_passant_tracking(self): + """Test en passant square tracking.""" + pgns = ["1. e4 1-0"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + game = result[0] + # Initial: no en passant + self.assertEqual(game.en_passant[0], -1) + # After e4: en passant on e-file (file index 4) + self.assertEqual(game.en_passant[1], 4) + + def test_castling_rights(self): + """Test castling rights tracking.""" + # White moves rook, losing kingside castling + pgns = ["1. e4 e5 2. Nf3 Nc6 3. Rg1 1-0"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + game = result[0] + # Initial: all castling [K, Q, k, q] = [True, True, True, True] + self.assertTrue(all(game.castling[0])) + + # After Rg1 (position 5): white kingside lost + # Castling order: [K, Q, k, q] + self.assertFalse(game.castling[5, 0]) # White K + self.assertTrue(game.castling[5, 1]) # White Q + self.assertTrue(game.castling[5, 2]) # Black k + self.assertTrue(game.castling[5, 3]) # Black q + + def test_turn_tracking(self): + """Test side-to-move tracking.""" + pgns = ["1. e4 e5 2. Nf3 1-0"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + game = result[0] + # Initial: white to move + self.assertTrue(game.turn[0]) + # After e4: black to move + self.assertFalse(game.turn[1]) + # After e5: white to move + self.assertTrue(game.turn[2]) + + def test_game_view_access(self): + """Test GameView provides correct slices.""" + pgns = ["1. e4 e5 2. Nf3 1-0", "1. d4 d5 0-1"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + game0 = result[0] + self.assertEqual(len(game0), 3) + self.assertEqual(game0.num_positions, 4) + self.assertEqual(game0.boards.shape, (4, 8, 8)) + self.assertTrue(game0.is_valid) + + game1 = result[1] + self.assertEqual(len(game1), 2) + self.assertEqual(game1.num_positions, 3) + + def test_game_view_move_uci(self): + """Test GameView UCI move conversion.""" + pgns = ["1. e4 e5 2. Nf3 1-0"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + game = result[0] + self.assertEqual(game.move_uci(0), "e2e4") + self.assertEqual(game.move_uci(1), "e7e5") + self.assertEqual(game.move_uci(2), "g1f3") + + self.assertEqual(game.moves_uci(), ["e2e4", "e7e5", "g1f3"]) + + def test_iteration(self): + """Test iteration over games.""" + pgns = ["1. e4 1-0", "1. d4 0-1", "1. c4 1/2-1/2"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + games = list(result) + self.assertEqual(len(games), 3) + self.assertIsInstance(games[0], PyGameView) + + def test_slicing(self): + """Test slicing returns list of game views.""" + pgns = ["1. e4 1-0", "1. d4 0-1", "1. c4 1/2-1/2"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + sliced = result[1:3] + self.assertEqual(len(sliced), 2) + games = list(sliced) + self.assertEqual(len(games[0]), 1) # d4 game + + def test_position_to_game_mapping(self): + """Test position to game index mapping.""" + pgns = ["1. e4 e5 1-0", "1. d4 0-1"] # 2 moves (3 pos), 1 move (2 pos) + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + # Positions: 0,1,2 (game 0), 3,4 (game 1) + pos_indices = np.array([0, 1, 2, 3, 4]) + game_indices = result.position_to_game(pos_indices) + + np.testing.assert_array_equal(game_indices, [0, 0, 0, 1, 1]) + + def test_move_to_game_mapping(self): + """Test move to game index mapping.""" + pgns = ["1. e4 e5 1-0", "1. d4 0-1"] # 2 moves, 1 move + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + move_indices = np.array([0, 1, 2]) + game_indices = result.move_to_game(move_indices) + + np.testing.assert_array_equal(game_indices, [0, 0, 1]) + + def test_position_to_game_accepts_various_dtypes(self): + """Test position_to_game accepts various integer dtypes.""" + pgns = ["1. e4 e5 1-0", "1. d4 0-1"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + for dtype in [np.int32, np.int64, np.uint32, np.uint64]: + pos_indices = np.array([0, 1, 2, 3, 4], dtype=dtype) + game_indices = result.position_to_game(pos_indices) + np.testing.assert_array_equal(game_indices, [0, 0, 0, 1, 1]) + + def test_move_to_game_accepts_various_dtypes(self): + """Test move_to_game accepts various integer dtypes.""" + pgns = ["1. e4 e5 1-0", "1. d4 0-1"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + for dtype in [np.int32, np.int64, np.uint32, np.uint64]: + move_indices = np.array([0, 1, 2], dtype=dtype) + game_indices = result.move_to_game(move_indices) + np.testing.assert_array_equal(game_indices, [0, 0, 1]) + + def test_clocks_and_evals(self): + """Test clock and eval parsing.""" + pgn = """1. e4 { [%eval 0.17] [%clk 0:00:30] } 1... e5 { [%eval 0.19] [%clk 0:00:29] } 1-0""" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + game = result[0] + self.assertAlmostEqual(game.evals[0], 0.17, places=2) + self.assertAlmostEqual(game.evals[1], 0.19, places=2) + self.assertAlmostEqual(game.clocks[0], 30.0, places=1) + self.assertAlmostEqual(game.clocks[1], 29.0, places=1) + + def test_missing_clocks_evals_are_nan(self): + """Test missing clocks/evals are NaN.""" + pgns = ["1. e4 e5 1-0"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + game = result[0] + self.assertTrue(np.isnan(game.clocks[0])) + self.assertTrue(np.isnan(game.evals[0])) + + def test_headers_preserved(self): + """Test headers are preserved as dicts.""" + pgn = """[White "Player1"] +[Black "Player2"] +[WhiteElo "1500"] + +1. e4 1-0""" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + game = result[0] + self.assertEqual(game.headers["White"], "Player1") + self.assertEqual(game.headers["Black"], "Player2") + self.assertEqual(game.headers["WhiteElo"], "1500") + + def test_invalid_game_flagged(self): + """Test invalid games are flagged but don't break structure.""" + pgns = [ + "1. e4 e5 1-0", # Valid + "1. e4 Qxd7 1-0", # Invalid move + "1. d4 d5 0-1", # Valid + ] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + self.assertEqual(len(result), 3) + self.assertTrue(result[0].is_valid) + self.assertFalse(result[1].is_valid) + self.assertTrue(result[2].is_valid) + + def test_checkmate_detection(self): + """Test checkmate is detected.""" + # Scholar's mate + pgn = "1. e4 e5 2. Qh5 Nc6 3. Bc4 Nf6 4. Qxf7# 1-0" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + game = result[0] + self.assertTrue(game.is_checkmate) + self.assertFalse(game.is_stalemate) + self.assertEqual(game.legal_move_count, 0) + + def test_promotion(self): + """Test promotion encoding.""" + pgn = """[FEN "8/P7/8/8/8/8/8/4K2k w - - 0 1"] + +1. a8=Q 1-0""" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + game = result[0] + # Promotion to queen = 5 + self.assertEqual(game.promotions[0], 5) + self.assertEqual(game.move_uci(0), "a7a8q") + + def test_num_properties(self): + """Test num_games, num_moves, num_positions properties.""" + pgns = ["1. e4 e5 1-0", "1. d4 0-1"] # 2 + 1 = 3 moves, 3 + 2 = 5 positions + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + self.assertEqual(result.num_games, 2) + self.assertEqual(result.num_moves, 3) + self.assertEqual(result.num_positions, 5) + + def test_negative_indexing(self): + """Test negative index access.""" + pgns = ["1. e4 1-0", "1. d4 0-1", "1. c4 1/2-1/2"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + # -1 should be the last game + last_game = result[-1] + self.assertEqual(len(last_game), 1) + + # Check that from_squares for c4 is c2 + # c2 = file 2 (c) + rank 1 (2nd rank) * 8 = 2 + 8 = 10 + self.assertEqual(last_game.from_squares[0], 10) # c2 + + def test_outcome(self): + """Test outcome is parsed from movetext.""" + pgns = ["1. e4 e5 1-0", "1. d4 d5 0-1", "1. c4 1/2-1/2"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + self.assertEqual(result[0].outcome, "White") + self.assertEqual(result[1].outcome, "Black") + self.assertEqual(result[2].outcome, "Draw") + + def test_outcome_without_headers(self): + """Test outcome works for PGNs without Result header.""" + pgn = "1. e4 e5 0-1" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + self.assertEqual(result[0].outcome, "Black") + + def test_is_game_over(self): + """Test is_game_over derived property.""" + # Scholar's mate + pgn = "1. e4 e5 2. Qh5 Nc6 3. Bc4 Nf6 4. Qxf7# 1-0" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + self.assertTrue(result[0].is_game_over) + + # Not game over + pgn2 = "1. e4 e5 1-0" + chunked2 = pa.chunked_array([pa.array([pgn2])]) + result2 = rust_pgn_reader_python_binding.parse_games(chunked2) + self.assertFalse(result2[0].is_game_over) + + def test_comments_disabled_by_default(self): + """Test comments are empty by default.""" + pgn = "1. e4 { a comment } e5 1-0" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + self.assertEqual(result[0].comments, []) + + def test_comments_enabled(self): + """Test comments are stored when enabled.""" + pgn = "1. e4 {asdf} e5 { [%eval 0.19] } 1-0" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games( + chunked, store_comments=True + ) + comments = result[0].comments + self.assertEqual(len(comments), 2) + self.assertEqual(comments[0], "asdf") + self.assertIsNotNone(comments[1]) # eval-only comment + + def test_parse_game_string(self): + """Test parse_game convenience function for single string.""" + pgn = "1. e4 e5 2. Nf3 Nc6 1-0" + result = rust_pgn_reader_python_binding.parse_game(pgn) + self.assertEqual(result.num_games, 1) + self.assertEqual(result[0].moves_uci(), ["e2e4", "e7e5", "g1f3", "b8c6"]) + self.assertEqual(result[0].outcome, "White") + + def test_parse_games_from_strings(self): + """Test parse_games_from_strings convenience function.""" + pgns = [ + "1. e4 e5 1-0", + "1. d4 d5 0-1", + ] + result = rust_pgn_reader_python_binding.parse_games_from_strings(pgns) + self.assertEqual(result.num_games, 2) + self.assertTrue(result[0].is_valid) + self.assertTrue(result[1].is_valid) + + def test_long_game_with_comments(self): + """Test a long game with inline comments, verifying all moves.""" + pgn = "1. e4 {asdf} e5 2. Nf3 Nc6 3. Bb5 Nf6 4. O-O {hello} Bc5 5. d3 d6 6. h3 h6 7. c3 O-O 8. Be3 a6 9. Ba4 Bd7 10. Bxc5 dxc5 11. Bxc6 Bxc6 12. Nxe5 Bb5 13. Re1 Re8 14. c4 Rxe5 15. cxb5 axb5 16. Nc3 b4 17. Nd5 Nxd5 18. exd5 Rxe1+ 19. Qxe1 Qxd5 20. Qe3 b6 21. b3 c6 22. Qe2 b5 23. Rd1 Qd4 24. Qe7 Rxa2 25. Qe8+ Kh7 26. Qxf7 c4 27. Qf5+ Kh8 28. Qf8+ Kh7 29. Qf5+ Kh8 30. Qf8+ Kh7 31. Qf5+ Kh8 32. Qf8+ Kh7 1/2-1/2" - def test_short_pgn(self): - pgn_moves = "1. e4 {asdf} e5 2. Nf3 Nc6 3. Bb5 Nf6 4. O-O {hello} Bc5 5. d3 d6 6. h3 h6 7. c3 O-O 8. Be3 a6 9. Ba4 Bd7 10. Bxc5 dxc5 11. Bxc6 Bxc6 12. Nxe5 Bb5 13. Re1 Re8 14. c4 Rxe5 15. cxb5 axb5 16. Nc3 b4 17. Nd5 Nxd5 18. exd5 Rxe1+ 19. Qxe1 Qxd5 20. Qe3 b6 21. b3 c6 22. Qe2 b5 23. Rd1 Qd4 24. Qe7 Rxa2 25. Qe8+ Kh7 26. Qxf7 c4 27. Qf5+ Kh8 28. Qf8+ Kh7 29. Qf5+ Kh8 30. Qf8+ Kh7 31. Qf5+ Kh8 32. Qf8+ Kh7 1/2-1/2" - extractor = self.run_extractor(pgn_moves) + result = rust_pgn_reader_python_binding.parse_game(pgn, store_comments=True) + game = result[0] moves_reference = [ "e2e4", @@ -79,34 +467,23 @@ def test_short_pgn(self): "h8h7", ] - comments_reference = [ - "asdf", - None, - None, - None, - None, - None, - "hello", - ] + [None for _ in range(57)] - - valid_reference = True - evals_reference = [None for _ in range(len(moves_reference))] - clock_times_reference = [None for _ in range(len(moves_reference))] - - self.assertTrue([str(move) for move in extractor.moves] == moves_reference) - self.assertTrue(extractor.comments == comments_reference) - self.assertTrue(extractor.valid_moves == valid_reference) - self.assertTrue(extractor.evals == evals_reference) - self.assertTrue(extractor.clock_times == clock_times_reference) - - assert extractor.position_status is not None # appease the type checker - self.assertFalse(extractor.position_status.is_checkmate) - self.assertFalse(extractor.position_status.is_stalemate) - self.assertFalse(extractor.position_status.is_game_over) - - def test_full_pgn(self): - pgn_moves = """ -[Event "Rated Classical game"] + self.assertEqual(game.moves_uci(), moves_reference) + self.assertTrue(game.is_valid) + self.assertEqual(game.outcome, "Draw") + self.assertFalse(game.is_checkmate) + self.assertFalse(game.is_stalemate) + self.assertFalse(game.is_game_over) + + # Comments: "asdf" on move 0, "hello" on move 6, rest should be non-None + # but empty (comments are enabled so placeholders exist) + comments = game.comments + self.assertEqual(len(comments), 64) + self.assertIn("asdf", comments[0]) + self.assertIn("hello", comments[6]) + + def test_full_game_with_headers(self): + """Test a full game with many headers, verifying headers and moves.""" + pgn = """[Event "Rated Classical game"] [Site "https://lichess.org/lhy6ehiv"] [White "goerch"] [Black "niltonrosao001"] @@ -123,9 +500,10 @@ def test_full_pgn(self): [Termination "Normal"] 1. d4 Nf6 2. Nf3 c5 3. e3 b6 4. Nc3 e6 5. Bb5 a6 6. Bd3 Bb7 7. O-O b5 8. b3 d5 9. Bb2 Nbd7 10. a4 b4 11. Ne2 Bd6 12. c4 bxc3 13. Bxc3 O-O 14. Ng3 Rc8 15. dxc5 Nxc5 16. Nd4 Nxd3 17. Qxd3 Qb6 18. Rab1 Bb4 19. Nge2 Ne4 20. Rfc1 Nxc3 21. Nxc3 Rc7 22. Na2 Rfc8 23. Rc2 g6 24. Nxb4 Qxb4 25. Rbc1 Rxc2 26. Nxc2 Qc3 27. Qxc3 Rxc3 28. Kf1 d4 29. exd4 Be4 30. Ke1 Rxc2 31. Rxc2 Bxc2 32. Kd2 Bxb3 33. a5 Bd5 0-1 - """ - extractor = self.run_extractor(pgn_moves) + + result = rust_pgn_reader_python_binding.parse_game(pgn) + game = result[0] moves_reference = [ "d2d4", @@ -195,90 +573,49 @@ def test_full_pgn(self): "a4a5", "b3d5", ] - comments_reference = [None for _ in range(len(moves_reference))] - valid_reference = True - evals_reference = [None for _ in range(len(moves_reference))] - clock_times_reference = [None for _ in range(len(moves_reference))] - headers_reference = [ - ("Event", "Rated Classical game"), - ("Site", "https://lichess.org/lhy6ehiv"), - ("White", "goerch"), - ("Black", "niltonrosao001"), - ("Result", "0-1"), - ("UTCDate", "2013.06.30"), - ("UTCTime", "22:10:02"), - ("WhiteElo", "1702"), - ("BlackElo", "2011"), - ("WhiteRatingDiff", "-3"), - ("BlackRatingDiff", "+5"), - ("ECO", "A46"), - ("Opening", "Indian Game: Spielmann-Indian"), - ("TimeControl", "600+8"), - ("Termination", "Normal"), - ] - - self.assertTrue([str(move) for move in extractor.moves] == moves_reference) - self.assertTrue(extractor.comments == comments_reference) - self.assertTrue(extractor.valid_moves == valid_reference) - self.assertTrue(extractor.evals == evals_reference) - self.assertTrue(extractor.clock_times == clock_times_reference) - self.assertTrue(extractor.headers == headers_reference) - - assert extractor.position_status is not None # appease the type checker - self.assertFalse(extractor.position_status.is_checkmate) - self.assertFalse(extractor.position_status.is_stalemate) - self.assertFalse(extractor.position_status.is_game_over) - - def test_full_pgn_annotated(self): - pgn_moves = """ - 1. e4 { [%eval 0.17] [%clk 0:00:30] } 1... c5 { [%eval 0.19] [%clk 0:00:30] } - 2. Nf3 { [%eval 0.25] [%clk 0:00:29] } 2... Nc6 { [%eval 0.33] [%clk 0:00:30] } - 3. Bc4 { [%eval -0.13] [%clk 0:00:28] } 3... e6 { [%eval -0.04] [%clk 0:00:30] } - 4. c3 { [%eval -0.4] [%clk 0:00:27] } 4... b5? { [%eval 1.18] [%clk 0:00:30] } - 5. Bb3?! { [%eval 0.21] [%clk 0:00:26] } 5... c4 { [%eval 0.32] [%clk 0:00:29] } - 6. Bc2 { [%eval 0.2] [%clk 0:00:25] } 6... a5 { [%eval 0.6] [%clk 0:00:29] } - 7. d4 { [%eval 0.29] [%clk 0:00:23] } 7... cxd3 { [%eval 0.6] [%clk 0:00:27] } - 8. Qxd3 { [%eval 0.12] [%clk 0:00:22] } 8... Nf6 { [%eval 0.52] [%clk 0:00:26] } - 9. e5 { [%eval 0.39] [%clk 0:00:21] } 9... Nd5 { [%eval 0.45] [%clk 0:00:25] } - 10. Bg5?! { [%eval -0.44] [%clk 0:00:18] } 10... Qc7 { [%eval -0.12] [%clk 0:00:23] } - 11. Nbd2?? { [%eval -3.15] [%clk 0:00:14] } 11... h6 { [%eval -2.99] [%clk 0:00:23] } - 12. Bh4 { [%eval -3.0] [%clk 0:00:11] } 12... Ba6? { [%eval -0.12] [%clk 0:00:23] } - 13. b3?? { [%eval -4.14] [%clk 0:00:02] } 13... Nf4? { [%eval -2.73] [%clk 0:00:21] } 0-1 - """ - extractor = self.run_extractor(pgn_moves) - moves_reference = [ - "e2e4", - "c7c5", - "g1f3", - "b8c6", - "f1c4", - "e7e6", - "c2c3", - "b7b5", - "c4b3", - "c5c4", - "b3c2", - "a7a5", - "d2d4", - "c4d3", - "d1d3", - "g8f6", - "e4e5", - "f6d5", - "c1g5", - "d8c7", - "b1d2", - "h7h6", - "g5h4", - "c8a6", - "b2b3", - "d5f4", - ] - comments_reference = [ - "", - ] * 26 - valid_reference = True + self.assertEqual(game.moves_uci(), moves_reference) + self.assertTrue(game.is_valid) + self.assertEqual(game.outcome, "Black") + + # Verify headers + self.assertEqual(game.headers["Event"], "Rated Classical game") + self.assertEqual(game.headers["Site"], "https://lichess.org/lhy6ehiv") + self.assertEqual(game.headers["White"], "goerch") + self.assertEqual(game.headers["Black"], "niltonrosao001") + self.assertEqual(game.headers["Result"], "0-1") + self.assertEqual(game.headers["WhiteElo"], "1702") + self.assertEqual(game.headers["BlackElo"], "2011") + self.assertEqual(game.headers["ECO"], "A46") + self.assertEqual(game.headers["Opening"], "Indian Game: Spielmann-Indian") + self.assertEqual(game.headers["TimeControl"], "600+8") + self.assertEqual(game.headers["Termination"], "Normal") + + def test_annotated_game_all_evals_and_clocks(self): + """Test a fully annotated game with eval and clock on every move.""" + pgn = """1. e4 { [%eval 0.17] [%clk 0:00:30] } 1... c5 { [%eval 0.19] [%clk 0:00:30] } +2. Nf3 { [%eval 0.25] [%clk 0:00:29] } 2... Nc6 { [%eval 0.33] [%clk 0:00:30] } +3. Bc4 { [%eval -0.13] [%clk 0:00:28] } 3... e6 { [%eval -0.04] [%clk 0:00:30] } +4. c3 { [%eval -0.4] [%clk 0:00:27] } 4... b5 { [%eval 1.18] [%clk 0:00:30] } +5. Bb3 { [%eval 0.21] [%clk 0:00:26] } 5... c4 { [%eval 0.32] [%clk 0:00:29] } +6. Bc2 { [%eval 0.2] [%clk 0:00:25] } 6... a5 { [%eval 0.6] [%clk 0:00:29] } +7. d4 { [%eval 0.29] [%clk 0:00:23] } 7... cxd3 { [%eval 0.6] [%clk 0:00:27] } +8. Qxd3 { [%eval 0.12] [%clk 0:00:22] } 8... Nf6 { [%eval 0.52] [%clk 0:00:26] } +9. e5 { [%eval 0.39] [%clk 0:00:21] } 9... Nd5 { [%eval 0.45] [%clk 0:00:25] } +10. Bg5 { [%eval -0.44] [%clk 0:00:18] } 10... Qc7 { [%eval -0.12] [%clk 0:00:23] } +11. Nbd2 { [%eval -3.15] [%clk 0:00:14] } 11... h6 { [%eval -2.99] [%clk 0:00:23] } +12. Bh4 { [%eval -3.0] [%clk 0:00:11] } 12... Ba6 { [%eval -0.12] [%clk 0:00:23] } +13. b3 { [%eval -4.14] [%clk 0:00:02] } 13... Nf4 { [%eval -2.73] [%clk 0:00:21] } 0-1""" + + result = rust_pgn_reader_python_binding.parse_game(pgn) + game = result[0] + + # Verify move count + self.assertEqual(len(game), 26) + self.assertTrue(game.is_valid) + self.assertEqual(game.outcome, "Black") + + # Verify all 26 evals evals_reference = [ 0.17, 0.19, @@ -307,377 +644,124 @@ def test_full_pgn_annotated(self): -4.14, -2.73, ] - clock_times_reference = [ - (0, 0, 30), - (0, 0, 30), - (0, 0, 29), - (0, 0, 30), - (0, 0, 28), - (0, 0, 30), - (0, 0, 27), - (0, 0, 30), - (0, 0, 26), - (0, 0, 29), - (0, 0, 25), - (0, 0, 29), - (0, 0, 23), - (0, 0, 27), - (0, 0, 22), - (0, 0, 26), - (0, 0, 21), - (0, 0, 25), - (0, 0, 18), - (0, 0, 23), - (0, 0, 14), - (0, 0, 23), - (0, 0, 11), - (0, 0, 23), - (0, 0, 2), - (0, 0, 21), - ] - self.assertTrue([str(move) for move in extractor.moves] == moves_reference) - self.assertTrue(extractor.comments == comments_reference) - self.assertTrue(extractor.valid_moves == valid_reference) - self.assertTrue(extractor.evals == evals_reference) - self.assertTrue(extractor.clock_times == clock_times_reference) - - assert extractor.position_status is not None # appease the type checker - self.assertFalse(extractor.position_status.is_checkmate) - self.assertFalse(extractor.position_status.is_stalemate) - self.assertFalse(extractor.position_status.is_game_over) - - def test_multithreaded(self): - pgns = [ - "1. Nf3 g6 2. b3 Bg7 3. Nc3 e5 4. Bb2 e4 5. Ng1 d6 6. Rb1 a5 7. Nxe4 Bxb2 8. Rxb2 Nf6 9. Nxf6+ Qxf6 10. Rb1 Ra6 11. e3 Rb6 12. d4 a4 13. bxa4 Nc6 14. Rxb6 cxb6 15. Bb5 Bd7 16. Bxc6 Bxc6 17. Nf3 Bxa4 18. O-O O-O 19. Re1 Rc8 20. Re2 d5 21. Ne5 Qf5 22. Qd3 Bxc2 23. Qxf5 Bxf5 24. h3 b5 25. Rb2 f6 26. Ng4 Rc6 27. Nh6+ Kg7 28. Nxf5+ gxf5 29. Rxb5 Rc7 30. Rxd5 Kg6 31. f4 Kh5 32. Rxf5+ Kh4 33. Rxf6 Kg3 34. d5 Rc1# 0-1", - "1. e4 {asdf} e5 2. Nf3 Nc6 3. Bb5 Nf6 4. O-O {hello} Bc5 5. d3 d6 6. h3 h6 7. c3 O-O", - ] - extractor = rust_pgn_reader_python_binding.parse_games(pgns) + for i, expected_eval in enumerate(evals_reference): + self.assertAlmostEqual( + game.evals[i], expected_eval, places=2, msg=f"Eval mismatch at move {i}" + ) - moves_reference = [ - [ - "g1f3", - "g7g6", - "b2b3", - "f8g7", - "b1c3", - "e7e5", - "c1b2", - "e5e4", - "f3g1", - "d7d6", - "a1b1", - "a7a5", - "c3e4", - "g7b2", - "b1b2", - "g8f6", - "e4f6", - "d8f6", - "b2b1", - "a8a6", - "e2e3", - "a6b6", - "d2d4", - "a5a4", - "b3a4", - "b8c6", - "b1b6", - "c7b6", - "f1b5", - "c8d7", - "b5c6", - "d7c6", - "g1f3", - "c6a4", - "e1g1", - "e8g8", - "f1e1", - "f8c8", - "e1e2", - "d6d5", - "f3e5", - "f6f5", - "d1d3", - "a4c2", - "d3f5", - "c2f5", - "h2h3", - "b6b5", - "e2b2", - "f7f6", - "e5g4", - "c8c6", - "g4h6", - "g8g7", - "h6f5", - "g6f5", - "b2b5", - "c6c7", - "b5d5", - "g7g6", - "f2f4", - "g6h5", - "d5f5", - "h5h4", - "f5f6", - "h4g3", - "d4d5", - "c7c1", - ], - [ - "e2e4", - "e7e5", - "g1f3", - "b8c6", - "f1b5", - "g8f6", - "e1g1", - "f8c5", - "d2d3", - "d7d6", - "h2h3", - "h7h6", - "c2c3", - "e8g8", - ], + # Verify all 26 clocks (stored as seconds) + clock_seconds_reference = [ + 30, + 30, + 29, + 30, + 28, + 30, + 27, + 30, + 26, + 29, + 25, + 29, + 23, + 27, + 22, + 26, + 21, + 25, + 18, + 23, + 14, + 23, + 11, + 23, + 2, + 21, ] + for i, expected_seconds in enumerate(clock_seconds_reference): + self.assertAlmostEqual( + game.clocks[i], + float(expected_seconds), + places=1, + msg=f"Clock mismatch at move {i}", + ) - comments_reference = [ - [None for _ in range(len(moves_reference[0]))], - [ - "asdf", - None, - None, - None, - None, - None, - "hello", - None, - None, - None, - None, - None, - None, - None, - ], - ] + def test_castling_rights_through_game(self): + """Test castling rights through multiple positions including king and rook moves.""" + pgn = "1. e4 e5 2. Bc4 c6 3. Nf3 d6 4. Rg1 f6 5. Rh1 g6 6. Ke2 b6 7. Ke1 g5 1-0" - self.assertTrue(extractor[0].comments == comments_reference[0]) - self.assertTrue(extractor[1].comments == comments_reference[1]) + result = rust_pgn_reader_python_binding.parse_game(pgn) + game = result[0] - self.assertTrue( - [str(move) for move in extractor[0].moves] == moves_reference[0] - ) - self.assertTrue( - [str(move) for move in extractor[1].moves] == moves_reference[1] - ) - assert extractor[0].position_status is not None # appease the type checker - - self.assertTrue(extractor[0].position_status.is_checkmate) - self.assertFalse(extractor[0].position_status.is_stalemate) - self.assertTrue(extractor[0].position_status.is_game_over) - self.assertTrue(extractor[0].position_status.legal_move_count == 0) - self.assertTrue(extractor[0].position_status.turn == 1) - self.assertTrue(extractor[0].position_status.insufficient_material == (0, 0)) - - self.assertTrue(extractor[1].position_status is None) - extractor[1].update_position_status() - assert extractor[1].position_status is not None # appease the type checker - self.assertFalse(extractor[1].position_status.is_checkmate) - self.assertFalse(extractor[1].position_status.is_stalemate) - self.assertFalse(extractor[1].position_status.is_game_over) - self.assertTrue(extractor[1].position_status.legal_move_count == 36) - self.assertTrue(extractor[1].position_status.turn == 1) - self.assertTrue(extractor[1].position_status.insufficient_material == (0, 0)) - - def test_castling(self): - pgn_moves = """ - 1. e4 e5 2. Bc4 c6 3. Nf3 d6 4. Rg1 f6 5. Rh1 g6 6. Ke2 b6 7. Ke1 g5 - """ - - extractor = self.run_extractor(pgn_moves) - - castling_reference = [ - (True, True, True, True), - (True, True, True, True), - (True, True, True, True), - (True, True, True, True), - (True, True, True, True), - (True, True, True, True), - (True, True, True, True), - (True, False, True, True), - (True, False, True, True), - (True, False, True, True), - (True, False, True, True), - (False, False, True, True), - (False, False, True, True), - (False, False, True, True), - (False, False, True, True), - ] + # 14 moves + 1 initial = 15 positions + self.assertEqual(game.num_positions, 15) - self.assertTrue(extractor.castling_rights == castling_reference) + # Castling order: [K, Q, k, q] + # Positions 0-6 (initial through 3. Nf3): all castling rights intact + for pos in range(7): + self.assertTrue( + all(game.castling[pos]), + f"Position {pos}: all castling should be intact", + ) - def test_parse_game_moves_arrow_chunked_array(self): + # Position 7 (after 4. Rg1): white kingside lost + self.assertFalse(game.castling[7, 0]) # White K gone + self.assertTrue(game.castling[7, 1]) # White Q intact + self.assertTrue(game.castling[7, 2]) # Black k intact + self.assertTrue(game.castling[7, 3]) # Black q intact + + # Positions 8-10: same (Rh1 restores rook but not castling rights) + for pos in [8, 9, 10]: + self.assertFalse(game.castling[pos, 0]) # White K still gone + self.assertTrue(game.castling[pos, 1]) # White Q intact + + # Position 11 (after 6. Ke2): white both castling lost + self.assertFalse(game.castling[11, 0]) # White K + self.assertFalse(game.castling[11, 1]) # White Q + self.assertTrue(game.castling[11, 2]) # Black k + self.assertTrue(game.castling[11, 3]) # Black q + + # Position 12-14: white castling remains lost (Ke1 doesn't restore it) + for pos in [12, 13, 14]: + self.assertFalse(game.castling[pos, 0]) + self.assertFalse(game.castling[pos, 1]) + + def test_parse_games_from_strings_with_status(self): + """Test parse_games_from_strings with checkmate, move verification, and status.""" pgns = [ + # Game 0: ends in checkmate "1. Nf3 g6 2. b3 Bg7 3. Nc3 e5 4. Bb2 e4 5. Ng1 d6 6. Rb1 a5 7. Nxe4 Bxb2 8. Rxb2 Nf6 9. Nxf6+ Qxf6 10. Rb1 Ra6 11. e3 Rb6 12. d4 a4 13. bxa4 Nc6 14. Rxb6 cxb6 15. Bb5 Bd7 16. Bxc6 Bxc6 17. Nf3 Bxa4 18. O-O O-O 19. Re1 Rc8 20. Re2 d5 21. Ne5 Qf5 22. Qd3 Bxc2 23. Qxf5 Bxf5 24. h3 b5 25. Rb2 f6 26. Ng4 Rc6 27. Nh6+ Kg7 28. Nxf5+ gxf5 29. Rxb5 Rc7 30. Rxd5 Kg6 31. f4 Kh5 32. Rxf5+ Kh4 33. Rxf6 Kg3 34. d5 Rc1# 0-1", - "1. e4 {asdf} e5 2. Nf3 Nc6 3. Bb5 Nf6 4. O-O {hello} Bc5 5. d3 d6 6. h3 h6 7. c3 O-O", - ] - - # Create a PyArrow ChunkedArray - arrow_array = pa.array(pgns, type=pa.string()) - chunked_array = pa.chunked_array([arrow_array]) - - extractors = ( - rust_pgn_reader_python_binding.parse_game_moves_arrow_chunked_array( - chunked_array - ) - ) - - moves_reference = [ - [ - "g1f3", - "g7g6", - "b2b3", - "f8g7", - "b1c3", - "e7e5", - "c1b2", - "e5e4", - "f3g1", - "d7d6", - "a1b1", - "a7a5", - "c3e4", - "g7b2", - "b1b2", - "g8f6", - "e4f6", - "d8f6", - "b2b1", - "a8a6", - "e2e3", - "a6b6", - "d2d4", - "a5a4", - "b3a4", - "b8c6", - "b1b6", - "c7b6", - "f1b5", - "c8d7", - "b5c6", - "d7c6", - "g1f3", - "c6a4", - "e1g1", - "e8g8", - "f1e1", - "f8c8", - "e1e2", - "d6d5", - "f3e5", - "f6f5", - "d1d3", - "a4c2", - "d3f5", - "c2f5", - "h2h3", - "b6b5", - "e2b2", - "f7f6", - "e5g4", - "c8c6", - "g4h6", - "g8g7", - "h6f5", - "g6f5", - "b2b5", - "c6c7", - "b5d5", - "g7g6", - "f2f4", - "g6h5", - "d5f5", - "h5h4", - "f5f6", - "h4g3", - "d4d5", - "c7c1", - ], - [ - "e2e4", - "e7e5", - "g1f3", - "b8c6", - "f1b5", - "g8f6", - "e1g1", - "f8c5", - "d2d3", - "d7d6", - "h2h3", - "h7h6", - "c2c3", - "e8g8", - ], - ] - - self.assertTrue( - [str(move) for move in extractors[0].moves] == moves_reference[0] - ) - self.assertTrue( - [str(move) for move in extractors[1].moves] == moves_reference[1] - ) - - comments_reference = [ - [None for _ in range(len(moves_reference[0]))], - [ - "asdf", - None, - None, - None, - None, - None, - "hello", - None, - None, - None, - None, - None, - None, - None, - ], + # Game 1: no checkmate, game abandoned mid-play + "1. e4 e5 2. Nf3 Nc6 3. Bb5 Nf6 4. O-O Bc5 5. d3 d6 6. h3 h6 7. c3 O-O 1-0", ] - - self.assertTrue(extractors[0].comments == comments_reference[0]) - self.assertTrue(extractors[1].comments == comments_reference[1]) - - extractors[0].update_position_status() # Ensure status is calculated - assert extractors[0].position_status is not None # appease the type checker - self.assertTrue(extractors[0].position_status.is_checkmate) - self.assertFalse(extractors[0].position_status.is_stalemate) - self.assertTrue(extractors[0].position_status.is_game_over) - self.assertTrue(extractors[0].position_status.legal_move_count == 0) - self.assertTrue( - extractors[0].position_status.turn == True - ) # White's turn, but Black delivered checkmate - self.assertTrue( - extractors[0].position_status.insufficient_material == (False, False) - ) - - self.assertTrue( - extractors[1].position_status is None - ) # Not set by default for parse_game_moves_arrow_chunked_array - extractors[1].update_position_status() - assert extractors[1].position_status is not None # appease the type checker - self.assertFalse(extractors[1].position_status.is_checkmate) - self.assertFalse(extractors[1].position_status.is_stalemate) - self.assertFalse(extractors[1].position_status.is_game_over) - self.assertTrue(extractors[1].position_status.legal_move_count == 36) - self.assertTrue(extractors[1].position_status.turn == True) # White's turn - self.assertTrue( - extractors[1].position_status.insufficient_material == (False, False) - ) + result = rust_pgn_reader_python_binding.parse_games_from_strings(pgns) + + self.assertEqual(result.num_games, 2) + + # Game 0: checkmate + game0 = result[0] + self.assertTrue(game0.is_valid) + self.assertTrue(game0.is_checkmate) + self.assertFalse(game0.is_stalemate) + self.assertTrue(game0.is_game_over) + self.assertEqual(game0.legal_move_count, 0) + self.assertEqual(game0.outcome, "Black") + # Verify some key moves + moves0 = game0.moves_uci() + self.assertEqual(len(moves0), 68) + self.assertEqual(moves0[0], "g1f3") + self.assertEqual(moves0[-1], "c7c1") + + # Game 1: no checkmate + game1 = result[1] + self.assertTrue(game1.is_valid) + self.assertFalse(game1.is_checkmate) + self.assertFalse(game1.is_stalemate) + self.assertFalse(game1.is_game_over) + self.assertEqual(game1.outcome, "White") + moves1 = game1.moves_uci() + self.assertEqual(len(moves1), 14) + self.assertEqual(moves1[0], "e2e4") if __name__ == "__main__": diff --git a/src/visitor.rs b/src/visitor.rs new file mode 100644 index 0000000..b78fa0a --- /dev/null +++ b/src/visitor.rs @@ -0,0 +1,846 @@ +//! SoA (Struct-of-Arrays) visitor for PGN parsing. +//! +//! This module provides a memory-efficient parsing approach that writes +//! directly to shared buffers instead of allocating per-game Vec structures. +//! Used by `parse_games` for optimal performance. + +use crate::board_serialization::{ + get_castling_rights, get_en_passant_file, get_halfmove_clock, get_turn, serialize_board, +}; +use crate::comment_parsing::{parse_comments, CommentContent, ParsedTag}; +use pgn_reader::{KnownOutcome, Outcome, RawComment, RawTag, SanPlus, Skip, Visitor}; +use shakmaty::{fen::Fen, uci::UciMove, CastlingMode, Chess, Color, Position}; +use std::collections::HashMap; +use std::ops::ControlFlow; + +/// Configuration for what optional data to store during parsing. +#[derive(Clone, Debug)] +pub struct ParseConfig { + pub store_comments: bool, + pub store_legal_moves: bool, +} + +/// Accumulated buffers for multiple parsed games. +/// +/// This struct holds all data in a struct-of-arrays layout, optimized for: +/// - Efficient thread-local accumulation during parallel parsing +/// - Fast merging of thread-local buffers via `extend_from_slice` +/// - Direct conversion to NumPy arrays without intermediate allocations +#[derive(Default, Clone)] +pub struct Buffers { + // Board state arrays (one entry per position) + pub boards: Vec, // Flattened: 64 bytes per position + pub castling: Vec, // Flattened: 4 bools per position [K,Q,k,q] + pub en_passant: Vec, // Per position: -1 or file 0-7 + pub halfmove_clock: Vec, // Per position + pub turn: Vec, // Per position: true=white + + // Move arrays (one entry per move) + pub from_squares: Vec, + pub to_squares: Vec, + pub promotions: Vec, // -1 for no promotion + pub clocks: Vec, // NaN for missing + pub evals: Vec, // NaN for missing + + // Per-game data + pub move_counts: Vec, // Number of moves per game + pub position_counts: Vec, // Number of positions per game + pub is_checkmate: Vec, + pub is_stalemate: Vec, + pub is_insufficient: Vec, // Flattened: 2 bools per game [white, black] + pub legal_move_count: Vec, + pub valid: Vec, + pub headers: Vec>, + pub outcome: Vec>, // "White", "Black", "Draw", "Unknown", or None + + // Optional: raw text comments (per-move), only populated when store_comments=true + pub comments: Vec>, + + // Optional: legal moves at each position, only populated when store_legal_moves=true + // Stored as flat arrays with CSR-style offsets + pub legal_move_from_squares: Vec, + pub legal_move_to_squares: Vec, + pub legal_move_promotions: Vec, + pub legal_move_counts: Vec, // Number of legal moves per position +} + +impl Buffers { + /// Create a new Buffers with pre-allocated capacity. + /// + /// # Arguments + /// * `estimated_games` - Expected number of games + /// * `moves_per_game` - Expected average moves per game (default: 70) + /// * `config` - Configuration for optional features + pub fn with_capacity( + estimated_games: usize, + moves_per_game: usize, + config: &ParseConfig, + ) -> Self { + let estimated_moves = estimated_games * moves_per_game; + let estimated_positions = estimated_moves + estimated_games; // +1 initial position per game + + Buffers { + // Board state arrays + boards: Vec::with_capacity(estimated_positions * 64), + castling: Vec::with_capacity(estimated_positions * 4), + en_passant: Vec::with_capacity(estimated_positions), + halfmove_clock: Vec::with_capacity(estimated_positions), + turn: Vec::with_capacity(estimated_positions), + + // Move arrays + from_squares: Vec::with_capacity(estimated_moves), + to_squares: Vec::with_capacity(estimated_moves), + promotions: Vec::with_capacity(estimated_moves), + clocks: Vec::with_capacity(estimated_moves), + evals: Vec::with_capacity(estimated_moves), + + // Per-game data + move_counts: Vec::with_capacity(estimated_games), + position_counts: Vec::with_capacity(estimated_games), + is_checkmate: Vec::with_capacity(estimated_games), + is_stalemate: Vec::with_capacity(estimated_games), + is_insufficient: Vec::with_capacity(estimated_games * 2), + legal_move_count: Vec::with_capacity(estimated_games), + valid: Vec::with_capacity(estimated_games), + headers: Vec::with_capacity(estimated_games), + outcome: Vec::with_capacity(estimated_games), + + // Optional comments + comments: if config.store_comments { + Vec::with_capacity(estimated_moves) + } else { + Vec::new() + }, + + // Optional legal moves + legal_move_from_squares: if config.store_legal_moves { + Vec::with_capacity(estimated_positions * 30) + } else { + Vec::new() + }, + legal_move_to_squares: if config.store_legal_moves { + Vec::with_capacity(estimated_positions * 30) + } else { + Vec::new() + }, + legal_move_promotions: if config.store_legal_moves { + Vec::with_capacity(estimated_positions * 30) + } else { + Vec::new() + }, + legal_move_counts: if config.store_legal_moves { + Vec::with_capacity(estimated_positions) + } else { + Vec::new() + }, + } + } + + /// Number of games in this buffer. + pub fn num_games(&self) -> usize { + self.headers.len() + } + + /// Total number of moves across all games. + #[allow(dead_code)] + pub fn total_moves(&self) -> usize { + self.from_squares.len() + } + + /// Total number of positions across all games. + pub fn total_positions(&self) -> usize { + self.boards.len() / 64 + } + + /// Compute CSR-style offsets from counts. + pub fn compute_move_offsets(&self) -> Vec { + let mut offsets = Vec::with_capacity(self.move_counts.len() + 1); + offsets.push(0); + for &count in &self.move_counts { + offsets.push(offsets.last().unwrap() + count); + } + offsets + } + + /// Compute CSR-style offsets from position counts. + pub fn compute_position_offsets(&self) -> Vec { + let mut offsets = Vec::with_capacity(self.position_counts.len() + 1); + offsets.push(0); + for &count in &self.position_counts { + offsets.push(offsets.last().unwrap() + count); + } + offsets + } + + /// Compute CSR-style offsets from legal move counts (per position). + pub fn compute_legal_move_offsets(&self) -> Vec { + let mut offsets = Vec::with_capacity(self.legal_move_counts.len() + 1); + offsets.push(0); + for &count in &self.legal_move_counts { + offsets.push(offsets.last().unwrap() + count); + } + offsets + } + + /// Total number of legal moves stored across all positions. + pub fn total_legal_moves(&self) -> usize { + self.legal_move_from_squares.len() + } +} + +/// Visitor that writes directly to shared Buffers. +/// +/// This visitor does not allocate any per-game Vec structures. +/// All data is appended directly to the shared Buffers. +pub struct GameVisitor<'a> { + buffers: &'a mut Buffers, + config: ParseConfig, + pos: Chess, + valid_moves: bool, + current_headers: Vec<(String, String)>, + current_outcome: Option, + // Track counts for current game + current_move_count: u32, + current_position_count: u32, +} + +impl<'a> GameVisitor<'a> { + pub fn new(buffers: &'a mut Buffers, config: &ParseConfig) -> Self { + GameVisitor { + buffers, + config: config.clone(), + pos: Chess::default(), + valid_moves: true, + current_headers: Vec::with_capacity(10), + current_outcome: None, + current_move_count: 0, + current_position_count: 0, + } + } + + /// Record current board state to buffers. + fn push_board_state(&mut self) { + self.buffers + .boards + .extend_from_slice(&serialize_board(&self.pos)); + let castling = get_castling_rights(&self.pos); + self.buffers.castling.extend_from_slice(&castling); + self.buffers.en_passant.push(get_en_passant_file(&self.pos)); + self.buffers + .halfmove_clock + .push(get_halfmove_clock(&self.pos)); + self.buffers.turn.push(get_turn(&self.pos)); + self.current_position_count += 1; + + // Store legal moves if enabled + if self.config.store_legal_moves { + self.push_legal_moves(); + } + } + + /// Record legal moves at current position to buffers. + fn push_legal_moves(&mut self) { + let legal_moves = self.pos.legal_moves(); + let mut count: u32 = 0; + for m in legal_moves { + let uci_move_obj = UciMove::from_standard(m); + if let UciMove::Normal { + from, + to, + promotion, + } = uci_move_obj + { + self.buffers.legal_move_from_squares.push(from as u8); + self.buffers.legal_move_to_squares.push(to as u8); + self.buffers + .legal_move_promotions + .push(promotion.map(|p| p as i8).unwrap_or(-1)); + count += 1; + } + } + self.buffers.legal_move_counts.push(count); + } + + /// Record move data to buffers. + fn push_move(&mut self, from: u8, to: u8, promotion: Option) { + self.buffers.from_squares.push(from); + self.buffers.to_squares.push(to); + self.buffers + .promotions + .push(promotion.map(|p| p as i8).unwrap_or(-1)); + // Push placeholders for clock and eval (will be overwritten by comment()) + self.buffers.clocks.push(f32::NAN); + self.buffers.evals.push(f32::NAN); + // Push comment placeholder if enabled (will be overwritten by comment()) + if self.config.store_comments { + self.buffers.comments.push(None); + } + self.current_move_count += 1; + } + + /// Record final position status. + fn update_position_status(&mut self) { + self.buffers.is_checkmate.push(self.pos.is_checkmate()); + self.buffers.is_stalemate.push(self.pos.is_stalemate()); + self.buffers + .is_insufficient + .push(self.pos.has_insufficient_material(Color::White)); + self.buffers + .is_insufficient + .push(self.pos.has_insufficient_material(Color::Black)); + self.buffers + .legal_move_count + .push(self.pos.legal_moves().len() as u16); + } + + /// Finalize current game - record per-game data. + fn finalize_game(&mut self) { + self.buffers.move_counts.push(self.current_move_count); + self.buffers + .position_counts + .push(self.current_position_count); + self.buffers.valid.push(self.valid_moves); + self.buffers.outcome.push(self.current_outcome.take()); + + // Convert headers to HashMap + let header_map: HashMap = self.current_headers.drain(..).collect(); + self.buffers.headers.push(header_map); + } +} + +impl Visitor for GameVisitor<'_> { + type Tags = Vec<(String, String)>; + type Movetext = (); + type Output = bool; + + fn begin_tags(&mut self) -> ControlFlow { + self.current_headers.clear(); + ControlFlow::Continue(Vec::with_capacity(10)) + } + + fn tag( + &mut self, + tags: &mut Self::Tags, + key: &[u8], + value: RawTag<'_>, + ) -> ControlFlow { + let key_str = String::from_utf8_lossy(key).into_owned(); + let value_str = String::from_utf8_lossy(value.as_bytes()).into_owned(); + tags.push((key_str, value_str)); + ControlFlow::Continue(()) + } + + fn begin_movetext(&mut self, tags: Self::Tags) -> ControlFlow { + self.current_headers = tags; + self.valid_moves = true; + self.current_outcome = None; + self.current_move_count = 0; + self.current_position_count = 0; + + // Determine castling mode from Variant header (case-insensitive) + let castling_mode = self + .current_headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("Variant")) + .and_then(|(_, v)| { + let v_lower = v.to_lowercase(); + if v_lower == "chess960" { + Some(CastlingMode::Chess960) + } else { + None + } + }) + .unwrap_or(CastlingMode::Standard); + + // Try to parse FEN from headers, fall back to default position + let fen_header = self + .current_headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("FEN")) + .map(|(_, v)| v.as_str()); + + if let Some(fen_str) = fen_header { + match fen_str.parse::() { + Ok(fen) => match fen.into_position(castling_mode) { + Ok(pos) => self.pos = pos, + Err(e) => { + eprintln!("invalid FEN position: {}", e); + self.pos = Chess::default(); + self.valid_moves = false; + } + }, + Err(e) => { + eprintln!("failed to parse FEN: {}", e); + self.pos = Chess::default(); + self.valid_moves = false; + } + } + } else { + self.pos = Chess::default(); + } + + // Record initial board state + self.push_board_state(); + ControlFlow::Continue(()) + } + + fn san( + &mut self, + _movetext: &mut Self::Movetext, + san_plus: SanPlus, + ) -> ControlFlow { + if self.valid_moves { + match san_plus.san.to_move(&self.pos) { + Ok(m) => { + self.pos.play_unchecked(m); + + // Record board state after move + self.push_board_state(); + + let uci_move_obj = UciMove::from_standard(m); + match uci_move_obj { + UciMove::Normal { + from, + to, + promotion, + } => { + self.push_move(from as u8, to as u8, promotion.map(|p| p as u8)); + } + _ => { + eprintln!( + "Unexpected UCI move type: {:?}. Game moves might be invalid.", + uci_move_obj + ); + self.valid_moves = false; + } + } + } + Err(err) => { + eprintln!("error in game: {} {}", err, san_plus); + self.valid_moves = false; + } + } + } + ControlFlow::Continue(()) + } + + fn comment( + &mut self, + _movetext: &mut Self::Movetext, + comment: RawComment<'_>, + ) -> ControlFlow { + if let Ok((_, parsed_comments)) = parse_comments(comment.as_bytes()) { + let mut move_comments = if self.config.store_comments { + Some(String::new()) + } else { + None + }; + + for content in parsed_comments { + match content { + CommentContent::Tag(tag_content) => match tag_content { + ParsedTag::Eval(eval_value) => { + // Update the last eval entry + if let Some(last_eval) = self.buffers.evals.last_mut() { + *last_eval = eval_value as f32; + } + } + ParsedTag::ClkTime { + hours, + minutes, + seconds, + } => { + // Convert to seconds and update the last clock entry + if let Some(last_clk) = self.buffers.clocks.last_mut() { + *last_clk = + hours as f32 * 3600.0 + minutes as f32 * 60.0 + seconds as f32; + } + } + ParsedTag::Mate(mate_value) => { + // Mate scores stored as text in comments (matching old API behavior) + if let Some(ref mut comments) = move_comments { + if !comments.is_empty() && !comments.ends_with(' ') { + comments.push(' '); + } + comments.push_str(&format!("[Mate {}]", mate_value)); + } + } + }, + CommentContent::Text(text) => { + if let Some(ref mut comments) = move_comments { + if !text.trim().is_empty() { + if !comments.is_empty() { + comments.push(' '); + } + comments.push_str(&text); + } + } + } + } + } + + // Update the last comment entry if comments are enabled + if let Some(comment_text) = move_comments { + if let Some(last_comment) = self.buffers.comments.last_mut() { + *last_comment = Some(comment_text); + } + } + } + ControlFlow::Continue(()) + } + + fn begin_variation( + &mut self, + _movetext: &mut Self::Movetext, + ) -> ControlFlow { + ControlFlow::Continue(Skip(true)) // Skip variations, stay in mainline + } + + fn outcome( + &mut self, + _movetext: &mut Self::Movetext, + _outcome: Outcome, + ) -> ControlFlow { + self.current_outcome = Some(match _outcome { + Outcome::Known(known) => match known { + KnownOutcome::Decisive { winner } => format!("{:?}", winner), + KnownOutcome::Draw => "Draw".to_string(), + }, + Outcome::Unknown => "Unknown".to_string(), + }); + self.update_position_status(); + ControlFlow::Continue(()) + } + + fn end_game(&mut self, _movetext: Self::Movetext) -> Self::Output { + // Handle case where outcome() was not called (e.g., incomplete game) + if self.buffers.is_checkmate.len() < self.buffers.headers.len() + 1 { + self.update_position_status(); + } + self.finalize_game(); + self.valid_moves + } +} + +/// Parse a single game directly into Buffers. +pub fn parse_game_to_buffers( + pgn: &str, + buffers: &mut Buffers, + config: &ParseConfig, +) -> Result { + use pgn_reader::Reader; + use std::io::Cursor; + + let mut reader = Reader::new(Cursor::new(pgn)); + let mut visitor = GameVisitor::new(buffers, config); + + match reader.read_game(&mut visitor) { + Ok(Some(valid)) => Ok(valid), + Ok(None) => Err("No game found in PGN".to_string()), + Err(err) => Err(format!("Parsing error: {}", err)), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn default_config() -> ParseConfig { + ParseConfig { + store_comments: false, + store_legal_moves: false, + } + } + + #[test] + fn test_parse_simple_game() { + let pgn = r#"[Event "Test"] +[White "Player1"] +[Black "Player2"] +[Result "1-0"] + +1. e4 e5 2. Nf3 Nc6 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert!(result.unwrap()); // valid game + assert_eq!(buffers.num_games(), 1); + assert_eq!(buffers.move_counts[0], 4); // 4 moves + assert_eq!(buffers.position_counts[0], 5); // 5 positions (initial + 4 moves) + assert_eq!(buffers.total_moves(), 4); + assert_eq!(buffers.total_positions(), 5); + assert_eq!(buffers.outcome[0], Some("White".to_string())); + } + + #[test] + fn test_parse_game_with_annotations() { + let pgn = r#"[Event "Test"] +[Result "1-0"] + +1. e4 { [%eval 0.17] [%clk 0:03:00] } 1... e5 { [%eval 0.19] [%clk 0:02:58] } 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert_eq!(buffers.total_moves(), 2); + + // Check that evals were parsed + assert!(!buffers.evals[0].is_nan()); + assert!((buffers.evals[0] - 0.17).abs() < 0.01); + assert!(!buffers.evals[1].is_nan()); + assert!((buffers.evals[1] - 0.19).abs() < 0.01); + + // Check that clocks were parsed (3 minutes = 180 seconds) + assert!(!buffers.clocks[0].is_nan()); + assert!((buffers.clocks[0] - 180.0).abs() < 0.01); + } + + #[test] + fn test_multiple_games_in_one_buffer() { + let pgn1 = r#"[Event "Game1"] +[Result "1-0"] + +1. e4 e5 1-0"#; + + let pgn2 = r#"[Event "Game2"] +[Result "0-1"] + +1. d4 d5 2. c4 0-1"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(2, 70, &config); + + parse_game_to_buffers(pgn1, &mut buffers, &config).unwrap(); + parse_game_to_buffers(pgn2, &mut buffers, &config).unwrap(); + + assert_eq!(buffers.num_games(), 2); + assert_eq!(buffers.total_moves(), 5); // 2 + 3 moves + assert_eq!(buffers.move_counts, vec![2, 3]); + assert_eq!(buffers.outcome[0], Some("White".to_string())); + assert_eq!(buffers.outcome[1], Some("Black".to_string())); + } + + #[test] + fn test_compute_offsets() { + let mut buffers = Buffers::default(); + buffers.move_counts = vec![4, 6, 3]; + buffers.position_counts = vec![5, 7, 4]; + + let move_offsets = buffers.compute_move_offsets(); + assert_eq!(move_offsets, vec![0, 4, 10, 13]); + + let pos_offsets = buffers.compute_position_offsets(); + assert_eq!(pos_offsets, vec![0, 5, 12, 16]); + } + + #[test] + fn test_outcome_without_headers() { + // PGN without Result header - outcome comes from movetext + let pgn = "1. e4 e5 2. Nf3 Nc6 0-1"; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + parse_game_to_buffers(pgn, &mut buffers, &config).unwrap(); + + assert_eq!(buffers.outcome[0], Some("Black".to_string())); + } + + #[test] + fn test_outcome_draw() { + let pgn = "1. e4 e5 1/2-1/2"; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + parse_game_to_buffers(pgn, &mut buffers, &config).unwrap(); + + assert_eq!(buffers.outcome[0], Some("Draw".to_string())); + } + + #[test] + fn test_comments_disabled() { + let pgn = r#"1. e4 { a comment } e5 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + parse_game_to_buffers(pgn, &mut buffers, &config).unwrap(); + + assert!(buffers.comments.is_empty()); + } + + #[test] + fn test_comments_enabled() { + let pgn = r#"1. e4 { a comment } 1... e5 { [%eval 0.19] } 1-0"#; + + let config = ParseConfig { + store_comments: true, + store_legal_moves: false, + }; + let mut buffers = Buffers::with_capacity(1, 70, &config); + parse_game_to_buffers(pgn, &mut buffers, &config).unwrap(); + + assert_eq!(buffers.comments.len(), 2); + // Raw text from PGN includes surrounding spaces from the parser + assert_eq!(buffers.comments[0], Some(" a comment ".to_string())); + // The second comment only has an eval tag, so text portion is empty + assert_eq!(buffers.comments[1], Some("".to_string())); + } + + #[test] + fn test_legal_moves_disabled() { + let pgn = "1. e4 e5 1-0"; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + parse_game_to_buffers(pgn, &mut buffers, &config).unwrap(); + + assert!(buffers.legal_move_from_squares.is_empty()); + assert!(buffers.legal_move_counts.is_empty()); + } + + #[test] + fn test_legal_moves_enabled() { + let pgn = "1. e4 1-0"; + + let config = ParseConfig { + store_comments: false, + store_legal_moves: true, + }; + let mut buffers = Buffers::with_capacity(1, 70, &config); + parse_game_to_buffers(pgn, &mut buffers, &config).unwrap(); + + // 2 positions: initial + after e4 + assert_eq!(buffers.legal_move_counts.len(), 2); + // Initial position has 20 legal moves + assert_eq!(buffers.legal_move_counts[0], 20); + // After e4, black has 20 legal moves + assert_eq!(buffers.legal_move_counts[1], 20); + // Total legal moves stored + assert_eq!(buffers.legal_move_from_squares.len(), 40); + } + + #[test] + fn test_parse_game_without_headers() { + let pgn = "1. Nf3 d5 2. e4 c5 3. exd5 e5 4. dxe6 0-1"; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert!(result.unwrap()); // valid game + assert_eq!(buffers.num_games(), 1); + assert_eq!(buffers.total_moves(), 7); + assert_eq!(buffers.outcome[0], Some("Black".to_string())); + } + + #[test] + fn test_parse_game_with_standard_fen() { + // A game starting from a mid-game position + let pgn = r#"[FEN "r1bqkbnr/pppp1ppp/2n5/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3"] + +3. Bb5 a6 4. Ba4 Nf6 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert!(result.unwrap()); // valid game + assert_eq!(buffers.total_moves(), 4); + } + + #[test] + fn test_parse_chess960_game() { + // Chess960 game with custom starting position + let pgn = r#"[Variant "chess960"] +[FEN "brkrqnnb/pppppppp/8/8/8/8/PPPPPPPP/BRKRQNNB w KQkq - 0 1"] + +1. g3 d5 2. d4 g6 3. b3 Nf6 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert!( + result.unwrap(), + "Chess960 moves should be valid with proper FEN" + ); + assert_eq!(buffers.total_moves(), 6); + } + + #[test] + fn test_parse_chess960_variant_case_insensitive() { + // Test that variant detection is case-insensitive + let pgn = r#"[Variant "Chess960"] +[FEN "brkrqnnb/pppppppp/8/8/8/8/PPPPPPPP/BRKRQNNB w KQkq - 0 1"] + +1. g3 d5 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert!(result.unwrap(), "Should handle Chess960 case variations"); + } + + #[test] + fn test_parse_invalid_fen_falls_back() { + // Invalid FEN should fall back to default and mark invalid + let pgn = r#"[FEN "invalid fen string"] + +1. e4 e5 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert!( + !result.unwrap(), + "Should mark as invalid when FEN parsing fails" + ); + } + + #[test] + fn test_fen_header_case_insensitive() { + // FEN header key should be case-insensitive + let pgn = r#"[fen "r1bqkbnr/pppp1ppp/2n5/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3"] + +3. Bb5 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert!(result.unwrap(), "Should handle lowercase 'fen' header"); + } + + #[test] + fn test_parse_game_with_custom_fen_no_variant() { + // Standard chess from a mid-game position (no Variant header) + // Position after 1.e4 e5 2.Nf3 Nc6 3.Bb5 (Ruy Lopez) + let pgn = r#"[Event "Test Game"] +[FEN "r1bqkbnr/pppp1ppp/2n5/1B2p3/4P3/5N2/PPPP1PPP/RNBQK2R b KQkq - 3 3"] + +3... a6 4. Ba4 Nf6 5. O-O Be7 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert!( + result.unwrap(), + "Standard game with custom FEN should be valid" + ); + assert_eq!(buffers.total_moves(), 5); // a6, Ba4, Nf6, O-O, Be7 + } +}