From f4b2ad849946f434ac6841e04cd1a047979c2140 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Wed, 4 Feb 2026 20:39:02 -0500 Subject: [PATCH 01/31] Add support for a different API returning numpy objects --- Cargo.lock | 1 + Cargo.toml | 1 + python/__init__.py | 5 + python/wrapper.py | 301 ++++++++++++++++++++++++++ src/bench_parse_games_flat.py | 361 ++++++++++++++++++++++++++++++++ src/board_serialization.rs | 134 ++++++++++++ src/example_parse_games_flat.py | 178 ++++++++++++++++ src/lib.rs | 324 +++++++++++++++++++++++++++- src/test.py | 298 ++++++++++++++++++++++++++ 9 files changed, 1599 insertions(+), 4 deletions(-) create mode 100644 python/__init__.py create mode 100644 python/wrapper.py create mode 100644 src/bench_parse_games_flat.py create mode 100644 src/board_serialization.rs create mode 100644 src/example_parse_games_flat.py diff --git a/Cargo.lock b/Cargo.lock index 1a1854d..7f604ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1329,6 +1329,7 @@ dependencies = [ "criterion", "nom", "num_cpus", + "numpy", "parquet", "pgn-reader", "pyo3", diff --git a/Cargo.toml b/Cargo.toml index 99462c4..aa9257b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ rayon = "1.11" num_cpus = "1.17" arrow-array = "57" pyo3-arrow = "0.15" +numpy = "0.27" [dev-dependencies] criterion = "0.8" diff --git a/python/__init__.py b/python/__init__.py new file mode 100644 index 0000000..74ad9ae --- /dev/null +++ b/python/__init__.py @@ -0,0 +1,5 @@ +"""Python wrapper utilities for rust_pgn_reader_python_binding.""" + +from .wrapper import add_ergonomic_methods, GameView, BatchSlice + +__all__ = ["add_ergonomic_methods", "GameView", "BatchSlice"] diff --git a/python/wrapper.py b/python/wrapper.py new file mode 100644 index 0000000..fc1ba45 --- /dev/null +++ b/python/wrapper.py @@ -0,0 +1,301 @@ +""" +Ergonomic Python wrappers for rust_pgn_reader_python_binding.ParsedGames. + +Usage: + import rust_pgn_reader_python_binding as pgn + from python.wrapper import add_ergonomic_methods, GameView + + result = pgn.parse_games_flat(chunked_array) + add_ergonomic_methods(type(result)) + + for game in result: + print(game.headers) +""" + +from __future__ import annotations + +import numpy as np +from typing import Iterator, Dict, TYPE_CHECKING + +if TYPE_CHECKING: + from rust_pgn_reader_python_binding import ParsedGames + + +class GameView: + """ + Zero-copy view into a single game's data within a ParsedGames result. + + Board indexing note: Boards use square indexing (a1=0, h8=63). + To convert to rank/file array indexing used by some Python code: + rank = square // 8 + file = square % 8 + # For [7-rank, file] layout: board_2d[7 - rank, file] + """ + + __slots__ = ( + "_data", + "_idx", + "_move_start", + "_move_end", + "_pos_start", + "_pos_end", + ) + + def __init__(self, data: "ParsedGames", idx: int): + self._data = data + self._idx = idx + self._move_start = int(data.move_offsets[idx]) + self._move_end = int(data.move_offsets[idx + 1]) + self._pos_start = int(data.position_offsets[idx]) + self._pos_end = int(data.position_offsets[idx + 1]) + + def __len__(self) -> int: + """Number of moves in this game.""" + return self._move_end - self._move_start + + @property + def num_positions(self) -> int: + """Number of positions recorded for this game.""" + return self._pos_end - self._pos_start + + # === Board state views === + + @property + def boards(self) -> np.ndarray: + """Board positions, shape (num_positions, 8, 8).""" + return self._data.boards[self._pos_start : self._pos_end] + + @property + def initial_board(self) -> np.ndarray: + """Initial board, shape (8, 8).""" + return self._data.boards[self._pos_start] + + @property + def final_board(self) -> np.ndarray: + """Final board, shape (8, 8).""" + return self._data.boards[self._pos_end - 1] + + @property + def castling(self) -> np.ndarray: + """Castling rights [K,Q,k,q], shape (num_positions, 4).""" + return self._data.castling[self._pos_start : self._pos_end] + + @property + def en_passant(self) -> np.ndarray: + """En passant file (-1 if none), shape (num_positions,).""" + return self._data.en_passant[self._pos_start : self._pos_end] + + @property + def halfmove_clock(self) -> np.ndarray: + """Halfmove clock, shape (num_positions,).""" + return self._data.halfmove_clock[self._pos_start : self._pos_end] + + @property + def turn(self) -> np.ndarray: + """Side to move (True=white), shape (num_positions,).""" + return self._data.turn[self._pos_start : self._pos_end] + + # === Move views === + + @property + def from_squares(self) -> np.ndarray: + """From squares, shape (num_moves,).""" + return self._data.from_squares[self._move_start : self._move_end] + + @property + def to_squares(self) -> np.ndarray: + """To squares, shape (num_moves,).""" + return self._data.to_squares[self._move_start : self._move_end] + + @property + def promotions(self) -> np.ndarray: + """Promotions (-1=none), shape (num_moves,).""" + return self._data.promotions[self._move_start : self._move_end] + + @property + def clocks(self) -> np.ndarray: + """Clock times in seconds (NaN if missing), shape (num_moves,).""" + return self._data.clocks[self._move_start : self._move_end] + + @property + def evals(self) -> np.ndarray: + """Engine evals (NaN if missing), shape (num_moves,).""" + return self._data.evals[self._move_start : self._move_end] + + # === Per-game metadata === + + @property + def headers(self) -> Dict[str, str]: + """Raw PGN headers as dict.""" + return self._data.headers[self._idx] + + @property + def is_checkmate(self) -> bool: + """Final position is checkmate.""" + return bool(self._data.is_checkmate[self._idx]) + + @property + def is_stalemate(self) -> bool: + """Final position is stalemate.""" + return bool(self._data.is_stalemate[self._idx]) + + @property + def is_insufficient(self) -> tuple: + """Insufficient material (white, black).""" + return ( + bool(self._data.is_insufficient[self._idx, 0]), + bool(self._data.is_insufficient[self._idx, 1]), + ) + + @property + def legal_move_count(self) -> int: + """Legal moves in final position.""" + return int(self._data.legal_move_count[self._idx]) + + @property + def is_valid(self) -> bool: + """Whether game parsed successfully.""" + return bool(self._data.valid[self._idx]) + + # === Convenience methods === + + def move_uci(self, move_idx: int) -> str: + """Get UCI string for move at index.""" + files = "abcdefgh" + ranks = "12345678" + from_sq = int(self.from_squares[move_idx]) + to_sq = int(self.to_squares[move_idx]) + promo = int(self.promotions[move_idx]) + + uci = f"{files[from_sq % 8]}{ranks[from_sq // 8]}{files[to_sq % 8]}{ranks[to_sq // 8]}" + if promo >= 0: + promo_chars = {2: "n", 3: "b", 4: "r", 5: "q"} + uci += promo_chars.get(promo, "") + return uci + + def moves_uci(self) -> list: + """Get all moves as UCI strings.""" + return [self.move_uci(i) for i in range(len(self))] + + def __repr__(self) -> str: + white = self.headers.get("White", "?") + black = self.headers.get("Black", "?") + return ( + f"" + ) + + +class BatchSlice: + """Lazy iterator over a slice of games.""" + + __slots__ = ("_data", "_indices") + + def __init__(self, data: "ParsedGames", indices: range): + self._data = data + self._indices = indices + + def __iter__(self) -> Iterator[GameView]: + for i in self._indices: + yield GameView(self._data, i) + + def __len__(self) -> int: + return len(self._indices) + + def __repr__(self) -> str: + return f"" + + +# === Functions to add ergonomic methods to ParsedGames === + + +def _parsed_games_len(self) -> int: + """Number of games in result.""" + return len(self.move_offsets) - 1 + + +def _parsed_games_getitem(self, idx): + """Access game(s) by index or slice.""" + n_games = len(self.move_offsets) - 1 + if isinstance(idx, int): + if idx < 0: + idx += n_games + if not 0 <= idx < n_games: + raise IndexError(f"Game index {idx} out of range [0, {n_games})") + return GameView(self, idx) + elif isinstance(idx, slice): + start, stop, step = idx.indices(n_games) + return BatchSlice(self, range(start, stop, step)) + raise TypeError(f"Invalid index type: {type(idx)}") + + +def _parsed_games_iter(self) -> Iterator[GameView]: + """Iterate over all games.""" + for i in range(len(self.move_offsets) - 1): + yield GameView(self, i) + + +def _position_to_game(self, position_indices: np.ndarray) -> np.ndarray: + """ + Map position indices to game indices. + + Useful after shuffling/sampling positions to look up game metadata. + + Args: + position_indices: Array of indices into boards array + + Returns: + Array of game indices (same shape as input) + """ + return ( + np.searchsorted(self.position_offsets[:-1], position_indices, side="right") - 1 + ) + + +def _move_to_game(self, move_indices: np.ndarray) -> np.ndarray: + """ + Map move indices to game indices. + + Args: + move_indices: Array of indices into from_squares, to_squares, etc. + + Returns: + Array of game indices (same shape as input) + """ + return np.searchsorted(self.move_offsets[:-1], move_indices, side="right") - 1 + + +@property +def _num_games(self) -> int: + """Number of games.""" + return len(self.move_offsets) - 1 + + +@property +def _num_moves(self) -> int: + """Total moves across all games.""" + return int(self.move_offsets[-1]) + + +@property +def _num_positions(self) -> int: + """Total positions recorded.""" + return int(self.position_offsets[-1]) + + +def add_ergonomic_methods(parsed_games_class): + """ + Add ergonomic methods to the ParsedGames class. + + Call once after importing the module: + import rust_pgn_reader_python_binding as pgn + from python.wrapper import add_ergonomic_methods + add_ergonomic_methods(pgn.ParsedGames) + """ + parsed_games_class.__len__ = _parsed_games_len + parsed_games_class.__getitem__ = _parsed_games_getitem + parsed_games_class.__iter__ = _parsed_games_iter + parsed_games_class.position_to_game = _position_to_game + parsed_games_class.move_to_game = _move_to_game + parsed_games_class.num_games = _num_games + parsed_games_class.num_moves = _num_moves + parsed_games_class.num_positions = _num_positions diff --git a/src/bench_parse_games_flat.py b/src/bench_parse_games_flat.py new file mode 100644 index 0000000..10c0f9d --- /dev/null +++ b/src/bench_parse_games_flat.py @@ -0,0 +1,361 @@ +""" +Benchmark comparing parse_games_flat() vs parse_game_moves_arrow_chunked_array(). + +This benchmark measures: +1. Parsing speed (games/second) +2. Memory efficiency (bytes per position) +3. Data access patterns for ML workloads + +Usage: + python bench_parse_games_flat.py [parquet_file] + +If no parquet file is provided, synthetic PGN data will be generated. +""" + +import sys +import os +import time +import argparse +from typing import Optional + +import numpy as np +import pyarrow as pa + +import rust_pgn_reader_python_binding as pgn + +# Add python directory to path for wrapper imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "python")) +from wrapper import add_ergonomic_methods + +# Patch ParsedGames with ergonomic methods +add_ergonomic_methods(pgn.ParsedGames) + + +def generate_synthetic_pgns(num_games: int, moves_per_game: int = 40) -> list[str]: + """Generate synthetic PGN games for benchmarking.""" + # A realistic game template + move_pairs = [ + ("e4", "e5"), + ("Nf3", "Nc6"), + ("Bb5", "a6"), + ("Ba4", "Nf6"), + ("O-O", "Be7"), + ("Re1", "b5"), + ("Bb3", "d6"), + ("c3", "O-O"), + ("h3", "Nb8"), + ("d4", "Nbd7"), + ("Nbd2", "Bb7"), + ("Bc2", "Re8"), + ("Nf1", "Bf8"), + ("Ng3", "g6"), + ("Bg5", "h6"), + ("Bd2", "Bg7"), + ("a4", "c5"), + ("d5", "c4"), + ("b4", "Nc5"), + ("Be3", "Qc7"), + ] + + pgns = [] + for i in range(num_games): + # Build movetext + moves = [] + num_pairs = min(moves_per_game // 2, len(move_pairs)) + for j in range(num_pairs): + white_move, black_move = move_pairs[j] + moves.append(f"{j + 1}. {white_move} {black_move}") + + movetext = " ".join(moves) + result = ["1-0", "0-1", "1/2-1/2"][i % 3] + + pgn_str = f"""[Event "Synthetic Game {i}"] +[White "Player{i * 2}"] +[Black "Player{i * 2 + 1}"] +[Result "{result}"] + +{movetext} {result}""" + pgns.append(pgn_str) + + return pgns + + +def load_parquet_pgns(file_path: str, limit: Optional[int] = None) -> pa.ChunkedArray: + """Load PGN strings from a parquet file.""" + import pyarrow.parquet as pq + + pf = pq.ParquetFile(file_path) + + # Try common column names for PGN/movetext + table = pf.read() + for col_name in ["movetext", "pgn", "moves", "game"]: + if col_name in table.column_names: + arr = table.column(col_name) + if limit: + arr = arr.slice(0, limit) + return arr + + raise ValueError( + f"Could not find PGN column. Available columns: {table.column_names}" + ) + + +def benchmark_parse_games_flat( + chunked_array: pa.ChunkedArray, num_threads: Optional[int] = None, warmup: int = 1 +) -> dict: + """Benchmark parse_games_flat().""" + # Warmup + for _ in range(warmup): + _ = pgn.parse_games_flat(chunked_array, num_threads=num_threads) + + # Timed run + start = time.perf_counter() + result = pgn.parse_games_flat(chunked_array, num_threads=num_threads) + elapsed = time.perf_counter() - start + + return { + "method": "parse_games_flat", + "elapsed_seconds": elapsed, + "num_games": result.num_games, + "num_moves": result.num_moves, + "num_positions": result.num_positions, + "games_per_second": result.num_games / elapsed, + "moves_per_second": result.num_moves / elapsed, + "positions_per_second": result.num_positions / elapsed, + "valid_games": int(result.valid.sum()), + "result": result, + } + + +def benchmark_parse_arrow_chunked( + chunked_array: pa.ChunkedArray, num_threads: Optional[int] = None, warmup: int = 1 +) -> dict: + """Benchmark parse_game_moves_arrow_chunked_array().""" + # Warmup + for _ in range(warmup): + _ = pgn.parse_game_moves_arrow_chunked_array( + chunked_array, num_threads=num_threads + ) + + # Timed run + start = time.perf_counter() + extractors = pgn.parse_game_moves_arrow_chunked_array( + chunked_array, num_threads=num_threads + ) + elapsed = time.perf_counter() - start + + num_games = len(extractors) + num_moves = sum(len(e.moves) for e in extractors) + num_positions = num_moves + num_games # Approximate + valid_games = sum(1 for e in extractors if e.valid_moves) + + return { + "method": "parse_arrow_chunked", + "elapsed_seconds": elapsed, + "num_games": num_games, + "num_moves": num_moves, + "num_positions": num_positions, + "games_per_second": num_games / elapsed, + "moves_per_second": num_moves / elapsed, + "positions_per_second": num_positions / elapsed, + "valid_games": valid_games, + "result": extractors, + } + + +def benchmark_data_access_flat(result) -> dict: + """Benchmark data access patterns for parse_games_flat result.""" + start = time.perf_counter() + + # Simulate ML data loading: access all boards + _ = result.boards.sum() + + # Access moves + _ = result.from_squares.sum() + _ = result.to_squares.sum() + + # Random position access + indices = np.random.randint(0, result.num_positions, size=1000) + _ = result.boards[indices] + + # Position-to-game mapping + _ = result.position_to_game(indices) + + elapsed = time.perf_counter() - start + return {"access_time": elapsed} + + +def benchmark_data_access_extractors(extractors: list) -> dict: + """Benchmark data access patterns for list of MoveExtractors.""" + start = time.perf_counter() + + # Simulate accessing all moves (requires iteration) + total = 0 + for e in extractors: + for m in e.moves: + total += m.from_square + m.to_square + + elapsed = time.perf_counter() - start + return {"access_time": elapsed} + + +def format_number(n: float) -> str: + """Format large numbers with K/M suffix.""" + if n >= 1_000_000: + return f"{n / 1_000_000:.2f}M" + elif n >= 1_000: + return f"{n / 1_000:.2f}K" + else: + return f"{n:.2f}" + + +def print_results(results: dict, label: str): + """Print benchmark results.""" + print(f"\n{label}") + print("-" * 50) + print(f" Time: {results['elapsed_seconds']:.3f}s") + print(f" Games: {results['num_games']:,}") + print(f" Moves: {results['num_moves']:,}") + print(f" Positions: {results['num_positions']:,}") + print(f" Valid games: {results['valid_games']:,}") + print(f" Games/sec: {format_number(results['games_per_second'])}") + print(f" Moves/sec: {format_number(results['moves_per_second'])}") + print(f" Positions/sec: {format_number(results['positions_per_second'])}") + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark parse_games_flat() vs parse_game_moves_arrow_chunked_array()" + ) + parser.add_argument( + "parquet_file", + nargs="?", + help="Path to parquet file with PGN data (optional)", + ) + parser.add_argument( + "--num-games", + type=int, + default=10000, + help="Number of synthetic games to generate if no parquet file", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="Limit number of games from parquet file", + ) + parser.add_argument( + "--threads", + type=int, + default=None, + help="Number of threads (default: all cores)", + ) + parser.add_argument( + "--warmup", + type=int, + default=1, + help="Number of warmup iterations", + ) + parser.add_argument( + "--skip-access", + action="store_true", + help="Skip data access benchmarks", + ) + args = parser.parse_args() + + print("=" * 60) + print("parse_games_flat() Benchmark") + print("=" * 60) + + # Load or generate data + if args.parquet_file: + print(f"\nLoading from: {args.parquet_file}") + try: + chunked_array = load_parquet_pgns(args.parquet_file, limit=args.limit) + print(f"Loaded {len(chunked_array):,} games") + except Exception as e: + print(f"Error loading parquet: {e}") + return 1 + else: + print(f"\nGenerating {args.num_games:,} synthetic games...") + pgns = generate_synthetic_pgns(args.num_games) + chunked_array = pa.chunked_array([pa.array(pgns)]) + print(f"Generated {len(chunked_array):,} games") + + print(f"Threads: {args.threads or 'all cores'}") + print(f"Warmup iterations: {args.warmup}") + + # Benchmark parse_games_flat + print("\n" + "=" * 60) + print("PARSING BENCHMARKS") + print("=" * 60) + + flat_results = benchmark_parse_games_flat( + chunked_array, num_threads=args.threads, warmup=args.warmup + ) + print_results(flat_results, "parse_games_flat()") + + # Benchmark parse_game_moves_arrow_chunked_array + arrow_results = benchmark_parse_arrow_chunked( + chunked_array, num_threads=args.threads, warmup=args.warmup + ) + print_results(arrow_results, "parse_game_moves_arrow_chunked_array()") + + # Comparison + print("\n" + "=" * 60) + print("COMPARISON") + print("=" * 60) + speedup = arrow_results["elapsed_seconds"] / flat_results["elapsed_seconds"] + print( + f"\nparse_games_flat() is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'}" + ) + + # Data access benchmarks + if not args.skip_access: + print("\n" + "=" * 60) + print("DATA ACCESS BENCHMARKS") + print("=" * 60) + + flat_access = benchmark_data_access_flat(flat_results["result"]) + print(f"\nFlat arrays access time: {flat_access['access_time']:.3f}s") + + extractor_access = benchmark_data_access_extractors(arrow_results["result"]) + print(f"Extractor list access time: {extractor_access['access_time']:.3f}s") + + access_speedup = extractor_access["access_time"] / flat_access["access_time"] + print(f"\nFlat arrays are {access_speedup:.2f}x faster for data access") + + # Memory usage (approximate) + print("\n" + "=" * 60) + print("MEMORY USAGE (approximate)") + print("=" * 60) + + flat_result = flat_results["result"] + flat_bytes = ( + flat_result.boards.nbytes + + flat_result.castling.nbytes + + flat_result.en_passant.nbytes + + flat_result.halfmove_clock.nbytes + + flat_result.turn.nbytes + + flat_result.from_squares.nbytes + + flat_result.to_squares.nbytes + + flat_result.promotions.nbytes + + flat_result.clocks.nbytes + + flat_result.evals.nbytes + + flat_result.move_offsets.nbytes + + flat_result.position_offsets.nbytes + ) + + print(f"\nFlat arrays total: {flat_bytes / 1024 / 1024:.2f} MB") + print(f"Bytes per position: {flat_bytes / flat_result.num_positions:.1f}") + print(f"Bytes per move: {flat_bytes / flat_result.num_moves:.1f}") + + print("\n" + "=" * 60) + print("Benchmark complete!") + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/board_serialization.rs b/src/board_serialization.rs new file mode 100644 index 0000000..b8fd021 --- /dev/null +++ b/src/board_serialization.rs @@ -0,0 +1,134 @@ +//! Board state serialization helpers for ParsedGames output. +//! +//! Board encoding: 0=empty, 1-6=white PNBRQK, 7-12=black pnbrqk +//! This matches python-chess piece_type (1-6) with color offset (+6 for black). +//! +//! Square indexing: a1=0, b1=1, ..., h1=7, a2=8, ..., h8=63 +//! Note: This differs from some Python code that uses [7 - rank, file] indexing. +//! The Python wrapper can transpose if needed. + +use shakmaty::{Chess, Color, EnPassantMode, Position, Role, Square}; + +/// Serialize board position to 64-byte array. +/// Index mapping: square index (a1=0, h8=63) -> piece value (0-12) +pub fn serialize_board(pos: &Chess) -> [u8; 64] { + let mut board = [0u8; 64]; + let b = pos.board(); + + for sq in Square::ALL { + if let Some(piece) = b.piece_at(sq) { + let piece_val = match piece.role { + Role::Pawn => 1, + Role::Knight => 2, + Role::Bishop => 3, + Role::Rook => 4, + Role::Queen => 5, + Role::King => 6, + }; + let color_offset = if piece.color == Color::White { 0 } else { 6 }; + board[sq as usize] = piece_val + color_offset; + } + } + board +} + +/// Get en passant file (0-7) or -1 if none. +/// Uses Always mode to report the e.p. square whenever a double pawn push occurred, +/// regardless of whether a legal capture is available. +pub fn get_en_passant_file(pos: &Chess) -> i8 { + pos.ep_square(EnPassantMode::Always) + .map(|sq| sq.file() as i8) + .unwrap_or(-1) +} + +/// Get halfmove clock (for 50-move rule). +pub fn get_halfmove_clock(pos: &Chess) -> u8 { + pos.halfmoves().min(255) as u8 +} + +/// Get side to move: true = white, false = black. +pub fn get_turn(pos: &Chess) -> bool { + pos.turn() == Color::White +} + +/// Get castling rights as [K, Q, k, q] (white kingside, white queenside, black kingside, black queenside). +pub fn get_castling_rights(pos: &Chess) -> [bool; 4] { + let rights = pos.castles().castling_rights(); + [ + rights.contains(Square::H1), // White kingside (rook on h1) + rights.contains(Square::A1), // White queenside (rook on a1) + rights.contains(Square::H8), // Black kingside (rook on h8) + rights.contains(Square::A8), // Black queenside (rook on a8) + ] +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_serialize_initial_board() { + let pos = Chess::default(); + let board = serialize_board(&pos); + + // White pieces on rank 1 (indices 0-7) + assert_eq!(board[0], 4); // a1 = white rook + assert_eq!(board[1], 2); // b1 = white knight + assert_eq!(board[2], 3); // c1 = white bishop + assert_eq!(board[3], 5); // d1 = white queen + assert_eq!(board[4], 6); // e1 = white king + assert_eq!(board[5], 3); // f1 = white bishop + assert_eq!(board[6], 2); // g1 = white knight + assert_eq!(board[7], 4); // h1 = white rook + + // White pawns on rank 2 (indices 8-15) + for i in 8..16 { + assert_eq!(board[i], 1); // white pawn + } + + // Empty squares (ranks 3-6, indices 16-47) + for i in 16..48 { + assert_eq!(board[i], 0); + } + + // Black pawns on rank 7 (indices 48-55) + for i in 48..56 { + assert_eq!(board[i], 7); // black pawn (1 + 6) + } + + // Black pieces on rank 8 (indices 56-63) + assert_eq!(board[56], 10); // a8 = black rook (4 + 6) + assert_eq!(board[57], 8); // b8 = black knight (2 + 6) + assert_eq!(board[58], 9); // c8 = black bishop (3 + 6) + assert_eq!(board[59], 11); // d8 = black queen (5 + 6) + assert_eq!(board[60], 12); // e8 = black king (6 + 6) + assert_eq!(board[61], 9); // f8 = black bishop (3 + 6) + assert_eq!(board[62], 8); // g8 = black knight (2 + 6) + assert_eq!(board[63], 10); // h8 = black rook (4 + 6) + } + + #[test] + fn test_initial_castling_rights() { + let pos = Chess::default(); + let rights = get_castling_rights(&pos); + assert_eq!(rights, [true, true, true, true]); // [K, Q, k, q] + } + + #[test] + fn test_initial_en_passant() { + let pos = Chess::default(); + assert_eq!(get_en_passant_file(&pos), -1); + } + + #[test] + fn test_initial_halfmove_clock() { + let pos = Chess::default(); + assert_eq!(get_halfmove_clock(&pos), 0); + } + + #[test] + fn test_initial_turn() { + let pos = Chess::default(); + assert!(get_turn(&pos)); // White to move + } +} diff --git a/src/example_parse_games_flat.py b/src/example_parse_games_flat.py new file mode 100644 index 0000000..58fadd2 --- /dev/null +++ b/src/example_parse_games_flat.py @@ -0,0 +1,178 @@ +""" +Example demonstrating the parse_games_flat() API for ML-optimized PGN parsing. + +This API returns flat NumPy arrays suitable for efficient batch processing +in machine learning pipelines. +""" + +import sys +import os + +import numpy as np +import pyarrow as pa + +import rust_pgn_reader_python_binding as pgn + +# Add python directory to path for wrapper imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "python")) +from wrapper import add_ergonomic_methods, GameView + +# Patch ParsedGames with ergonomic methods (do this once at startup) +add_ergonomic_methods(pgn.ParsedGames) + + +# Sample PGN games with annotations +sample_pgns = [ + # Game 1: Sicilian Defense with eval/clock annotations + """[Event "Online Blitz"] +[White "Player1"] +[Black "Player2"] +[Result "1-0"] +[WhiteElo "1850"] +[BlackElo "1790"] + +1. e4 { [%eval 0.17] [%clk 0:03:00] } 1... c5 { [%eval 0.19] [%clk 0:03:00] } +2. Nf3 { [%eval 0.25] [%clk 0:02:55] } 2... d6 { [%eval 0.30] [%clk 0:02:58] } +3. d4 { [%eval 0.35] [%clk 0:02:50] } 3... cxd4 { [%eval 0.32] [%clk 0:02:55] } +4. Nxd4 { [%eval 0.28] [%clk 0:02:48] } 4... Nf6 { [%eval 0.25] [%clk 0:02:52] } +5. Nc3 { [%eval 0.30] [%clk 0:02:45] } 1-0""", + # Game 2: Italian Game + """[Event "Club Championship"] +[White "Magnus"] +[Black "Hikaru"] +[Result "0-1"] + +1. e4 e5 2. Nf3 Nc6 3. Bc4 Bc5 4. c3 Nf6 5. d4 exd4 6. cxd4 Bb4+ 0-1""", + # Game 3: Scholar's Mate (checkmate) + """[Event "Beginner Game"] +[White "NewPlayer"] +[Black "Victim"] +[Result "1-0"] + +1. e4 e5 2. Qh5 Nc6 3. Bc4 Nf6 4. Qxf7# 1-0""", + # Game 4: Queen's Gambit + """[Event "Tournament"] +[White "Carlsen"] +[Black "Caruana"] +[Result "1/2-1/2"] + +1. d4 d5 2. c4 e6 3. Nc3 Nf6 4. Bg5 Be7 5. e3 O-O 1/2-1/2""", +] + + +def main(): + print("=" * 60) + print("parse_games_flat() API Example") + print("=" * 60) + + # Create PyArrow chunked array from PGN strings + chunked_array = pa.chunked_array([pa.array(sample_pgns)]) + + # Parse all games at once - returns flat NumPy arrays + result = pgn.parse_games_flat(chunked_array) + + # === Basic Statistics === + print(f"\n--- Basic Statistics ---") + print(f"Number of games: {result.num_games}") + print(f"Total moves: {result.num_moves}") + print(f"Total positions: {result.num_positions}") + print(f"Valid games: {result.valid.sum()} / {result.num_games}") + + # === Array Shapes === + print(f"\n--- Array Shapes ---") + print(f"boards: {result.boards.shape} ({result.boards.dtype})") + print(f"castling: {result.castling.shape} ({result.castling.dtype})") + print(f"en_passant: {result.en_passant.shape} ({result.en_passant.dtype})") + print( + f"from_squares: {result.from_squares.shape} ({result.from_squares.dtype})" + ) + print(f"to_squares: {result.to_squares.shape} ({result.to_squares.dtype})") + print(f"promotions: {result.promotions.shape} ({result.promotions.dtype})") + print(f"clocks: {result.clocks.shape} ({result.clocks.dtype})") + print(f"evals: {result.evals.shape} ({result.evals.dtype})") + print(f"move_offsets: {result.move_offsets.shape}") + print(f"position_offsets: {result.position_offsets.shape}") + + # === Iterate Over Games === + print(f"\n--- Game Details ---") + for i, game in enumerate(result): + print( + f"\nGame {i + 1}: {game.headers.get('White', '?')} vs {game.headers.get('Black', '?')}" + ) + print(f" Event: {game.headers.get('Event', 'N/A')}") + print(f" Moves: {len(game)}") + print(f" Positions: {game.num_positions}") + print(f" Valid: {game.is_valid}") + print(f" Checkmate: {game.is_checkmate}") + print( + f" UCI moves: {' '.join(game.moves_uci()[:10])}{'...' if len(game) > 10 else ''}" + ) + + # Show eval annotations if available + valid_evals = game.evals[~np.isnan(game.evals)] + if len(valid_evals) > 0: + print(f" Evals: {valid_evals[:5].tolist()}...") + + # === Direct Array Access for ML === + print(f"\n--- ML-Ready Data Access ---") + + # Get all board positions as a single tensor + all_boards = result.boards # Shape: (N_positions, 8, 8) + print(f"All boards tensor: {all_boards.shape}") + + # Get initial position of first game + game0 = result[0] + initial_board = game0.initial_board + print(f"\nInitial board (Game 1):") + # Print board with piece symbols + piece_chars = " PNBRQKpnbrqk" + for rank in range(7, -1, -1): # Print from rank 8 to rank 1 + row = "" + for file in range(8): + sq_idx = rank * 8 + file + piece = initial_board.flat[sq_idx] + row += piece_chars[piece] + " " + print(f" {rank + 1} | {row}") + print(f" +----------------") + print(f" a b c d e f g h") + + # === Position-to-Game Mapping === + print(f"\n--- Position-to-Game Mapping ---") + # Useful for shuffling positions while keeping track of game metadata + sample_positions = np.array([0, 5, 10, 15, 20]) + game_indices = result.position_to_game(sample_positions) + print(f"Position indices: {sample_positions}") + print(f"Game indices: {game_indices}") + + # === Slicing === + print(f"\n--- Slicing ---") + # Get games 1-2 (0-indexed) + subset = result[1:3] + print(f"Slice [1:3]: {len(subset)} games") + for game in subset: + print( + f" - {game.headers.get('White', '?')} vs {game.headers.get('Black', '?')}" + ) + + # === Negative Indexing === + print(f"\n--- Negative Indexing ---") + last_game = result[-1] + print( + f"Last game: {last_game.headers.get('White', '?')} vs {last_game.headers.get('Black', '?')}" + ) + print(f" Moves: {last_game.moves_uci()}") + + # === Checkmate Detection === + print(f"\n--- Checkmate Detection ---") + for i, game in enumerate(result): + if game.is_checkmate: + print(f"Game {i + 1} ended in checkmate!") + print(f" Final position legal moves: {game.legal_move_count}") + + print("\n" + "=" * 60) + print("Example complete!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/src/lib.rs b/src/lib.rs index d38b337..3eb20f3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,19 +1,26 @@ -use crate::comment_parsing::{CommentContent, ParsedTag, parse_comments}; +use crate::comment_parsing::{parse_comments, CommentContent, ParsedTag}; use arrow_array::{Array, LargeStringArray, StringArray}; +use numpy::{PyArray1, PyArrayMethods}; use pgn_reader::{KnownOutcome, Outcome, RawComment, RawTag, Reader, SanPlus, Skip, Visitor}; use pyo3::prelude::*; use pyo3_arrow::PyChunkedArray; -use rayon::ThreadPoolBuilder; use rayon::prelude::*; +use rayon::ThreadPoolBuilder; +use shakmaty::fen::Fen; use shakmaty::CastlingMode; use shakmaty::Color; -use shakmaty::fen::Fen; -use shakmaty::{Chess, Position, Role, Square, uci::UciMove}; +use shakmaty::{uci::UciMove, Chess, Position, Role, Square}; +use std::collections::HashMap; use std::io::Cursor; use std::ops::ControlFlow; +mod board_serialization; mod comment_parsing; +use board_serialization::{ + get_castling_rights, get_en_passant_file, get_halfmove_clock, get_turn, serialize_board, +}; + // Definition of PyUciMove #[pyclass(get_all, set_all, module = "rust_pgn_reader_python_binding")] #[derive(Clone, Debug)] @@ -142,6 +149,13 @@ pub struct MoveExtractor { position_status: Option, pos: Chess, + + // Board state tracking for flat output (not directly exposed to Python) + board_states: Vec, // Flattened: 64 bytes per position + en_passant_states: Vec, // Per position: -1 or file 0-7 + halfmove_clocks: Vec, // Per position + turn_states: Vec, // Per position: true=white + castling_states: Vec, // Flattened: 4 bools per position [K,Q,k,q] } #[pymethods] @@ -163,6 +177,11 @@ impl MoveExtractor { headers: Vec::with_capacity(10), castling_rights: Vec::with_capacity(100), position_status: None, + board_states: Vec::with_capacity(100 * 64), + en_passant_states: Vec::with_capacity(100), + halfmove_clocks: Vec::with_capacity(100), + turn_states: Vec::with_capacity(100), + castling_states: Vec::with_capacity(100 * 4), } } @@ -209,6 +228,17 @@ impl MoveExtractor { } } + /// Record current board state to flat arrays for ParsedGames output. + fn push_board_state(&mut self) { + self.board_states + .extend_from_slice(&serialize_board(&self.pos)); + self.en_passant_states.push(get_en_passant_file(&self.pos)); + self.halfmove_clocks.push(get_halfmove_clock(&self.pos)); + self.turn_states.push(get_turn(&self.pos)); + let castling = get_castling_rights(&self.pos); + self.castling_states.extend_from_slice(&castling); + } + fn update_position_status(&mut self) { // TODO this checks legal_moves() a bunch of times self.position_status = Some(PositionStatus { @@ -281,6 +311,11 @@ impl Visitor for MoveExtractor { self.evals.clear(); self.clock_times.clear(); self.castling_rights.clear(); + self.board_states.clear(); + self.en_passant_states.clear(); + self.halfmove_clocks.clear(); + self.turn_states.clear(); + self.castling_states.clear(); // Determine castling mode from Variant header (case-insensitive) let castling_mode = self @@ -328,6 +363,8 @@ impl Visitor for MoveExtractor { if self.store_legal_moves { self.push_legal_moves(); } + // Record initial board state for flat output + self.push_board_state(); ControlFlow::Continue(()) } @@ -345,6 +382,8 @@ impl Visitor for MoveExtractor { if self.store_legal_moves { self.push_legal_moves(); } + // Record board state after move for flat output + self.push_board_state(); let uci_move_obj = UciMove::from_standard(m); match uci_move_obj { @@ -475,6 +514,281 @@ impl Visitor for MoveExtractor { } } +/// Flat array container for parsed chess games, optimized for ML training. +/// +/// # Indexing +/// - `N_games`: Number of games +/// - `N_moves`: Total moves across all games +/// - `N_positions`: Total board positions recorded (varies per game due to initial position + moves) +/// +/// # Board layout +/// Boards use square indexing: a1=0, b1=1, ..., h8=63 +/// Piece encoding: 0=empty, 1-6=white PNBRQK, 7-12=black pnbrqk +#[pyclass] +pub struct ParsedGames { + // === Board state arrays (N_positions) === + /// Board positions, shape (N_positions, 8, 8), dtype uint8 + #[pyo3(get)] + boards: Py, + + /// Castling rights [K,Q,k,q], shape (N_positions, 4), dtype bool + #[pyo3(get)] + castling: Py, + + /// En passant file (-1 if none), shape (N_positions,), dtype int8 + #[pyo3(get)] + en_passant: Py, + + /// Halfmove clock, shape (N_positions,), dtype uint8 + #[pyo3(get)] + halfmove_clock: Py, + + /// Side to move (true=white), shape (N_positions,), dtype bool + #[pyo3(get)] + turn: Py, + + // === Move arrays (N_moves) === + /// From squares, shape (N_moves,), dtype uint8 + #[pyo3(get)] + from_squares: Py, + + /// To squares, shape (N_moves,), dtype uint8 + #[pyo3(get)] + to_squares: Py, + + /// Promotions (-1=none, 2=N, 3=B, 4=R, 5=Q), shape (N_moves,), dtype int8 + #[pyo3(get)] + promotions: Py, + + /// Clock times in seconds (NaN if missing), shape (N_moves,), dtype float32 + #[pyo3(get)] + clocks: Py, + + /// Engine evals (NaN if missing), shape (N_moves,), dtype float32 + #[pyo3(get)] + evals: Py, + + // === Offsets === + /// Move offsets for CSR-style indexing, shape (N_games + 1,), dtype uint32 + /// Game i's moves: move_offsets[i]..move_offsets[i+1] + #[pyo3(get)] + move_offsets: Py, + + /// Position offsets for CSR-style indexing, shape (N_games + 1,), dtype uint32 + /// Game i's positions: position_offsets[i]..position_offsets[i+1] + #[pyo3(get)] + position_offsets: Py, + + // === Final position status (N_games) === + /// Final position is checkmate, shape (N_games,), dtype bool + #[pyo3(get)] + is_checkmate: Py, + + /// Final position is stalemate, shape (N_games,), dtype bool + #[pyo3(get)] + is_stalemate: Py, + + /// Insufficient material (white, black), shape (N_games, 2), dtype bool + #[pyo3(get)] + is_insufficient: Py, + + /// Legal move count in final position, shape (N_games,), dtype uint16 + #[pyo3(get)] + legal_move_count: Py, + + // === Parse status (N_games) === + /// Whether game parsed successfully, shape (N_games,), dtype bool + #[pyo3(get)] + valid: Py, + + // === Raw headers (N_games) === + /// Raw PGN headers as list of dicts + #[pyo3(get)] + headers: Vec>, +} + +#[pyfunction] +#[pyo3(signature = (pgn_chunked_array, num_threads=None))] +fn parse_games_flat( + py: Python<'_>, + pgn_chunked_array: PyChunkedArray, + num_threads: Option, +) -> PyResult { + // 1. Parse all games using existing logic + let extractors = _parse_game_moves_from_arrow_chunks_native( + &pgn_chunked_array, + num_threads, + false, // store_legal_moves = false for performance + ) + .map_err(|e| PyErr::new::(e))?; + + let n_games = extractors.len(); + + // 2. Compute move counts and position counts from actual recorded data + let move_counts: Vec = extractors.iter().map(|e| e.moves.len() as u32).collect(); + + // Position counts derived from actual board_states data + let position_counts: Vec = extractors + .iter() + .map(|e| (e.board_states.len() / 64) as u32) + .collect(); + + // Build move offsets + let mut move_offsets_vec: Vec = Vec::with_capacity(n_games + 1); + move_offsets_vec.push(0); + for &count in &move_counts { + move_offsets_vec.push(move_offsets_vec.last().unwrap() + count); + } + + // Build position offsets + let mut position_offsets_vec: Vec = Vec::with_capacity(n_games + 1); + position_offsets_vec.push(0); + for &count in &position_counts { + position_offsets_vec.push(position_offsets_vec.last().unwrap() + count); + } + + let total_moves = *move_offsets_vec.last().unwrap() as usize; + let total_positions = *position_offsets_vec.last().unwrap() as usize; + + // 3. Pre-allocate flat vectors + let mut boards_vec: Vec = Vec::with_capacity(total_positions * 64); + let mut castling_vec: Vec = Vec::with_capacity(total_positions * 4); + let mut en_passant_vec: Vec = Vec::with_capacity(total_positions); + let mut halfmove_clock_vec: Vec = Vec::with_capacity(total_positions); + let mut turn_vec: Vec = Vec::with_capacity(total_positions); + + let mut from_squares_vec: Vec = Vec::with_capacity(total_moves); + let mut to_squares_vec: Vec = Vec::with_capacity(total_moves); + let mut promotions_vec: Vec = Vec::with_capacity(total_moves); + let mut clocks_vec: Vec = Vec::with_capacity(total_moves); + let mut evals_vec: Vec = Vec::with_capacity(total_moves); + + let mut is_checkmate_vec: Vec = Vec::with_capacity(n_games); + let mut is_stalemate_vec: Vec = Vec::with_capacity(n_games); + let mut is_insufficient_vec: Vec = Vec::with_capacity(n_games * 2); + let mut legal_move_count_vec: Vec = Vec::with_capacity(n_games); + let mut valid_vec: Vec = Vec::with_capacity(n_games); + let mut headers_vec: Vec> = Vec::with_capacity(n_games); + + // 4. Copy data from each extractor + for extractor in &extractors { + // Board states + boards_vec.extend_from_slice(&extractor.board_states); + castling_vec.extend(extractor.castling_states.iter().copied()); + en_passant_vec.extend_from_slice(&extractor.en_passant_states); + halfmove_clock_vec.extend_from_slice(&extractor.halfmove_clocks); + turn_vec.extend(extractor.turn_states.iter().copied()); + + // Moves + for m in &extractor.moves { + from_squares_vec.push(m.from_square); + to_squares_vec.push(m.to_square); + promotions_vec.push(m.promotion.map(|p| p as i8).unwrap_or(-1)); + } + + // Clocks (convert to seconds) + for clock in &extractor.clock_times { + clocks_vec.push( + clock + .map(|(h, m, s)| h as f32 * 3600.0 + m as f32 * 60.0 + s as f32) + .unwrap_or(f32::NAN), + ); + } + + // Evals (convert mate values to large numbers) + for eval in &extractor.evals { + evals_vec.push(eval.map(|e| e as f32).unwrap_or(f32::NAN)); + } + + // Final position status + if let Some(ref status) = extractor.position_status { + is_checkmate_vec.push(status.is_checkmate); + is_stalemate_vec.push(status.is_stalemate); + is_insufficient_vec.push(status.insufficient_material.0); + is_insufficient_vec.push(status.insufficient_material.1); + legal_move_count_vec.push(status.legal_move_count as u16); + } else { + // No status computed - use defaults + is_checkmate_vec.push(false); + is_stalemate_vec.push(false); + is_insufficient_vec.push(false); + is_insufficient_vec.push(false); + legal_move_count_vec.push(0); + } + + // Valid flag + valid_vec.push(extractor.valid_moves); + + // Headers as HashMap + let header_map: HashMap = extractor + .headers + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + headers_vec.push(header_map); + } + + // 5. Convert to numpy arrays + // Boards: reshape from flat to (N_positions, 8, 8) + let boards_array = PyArray1::from_vec(py, boards_vec); + let boards_reshaped = boards_array + .reshape([total_positions, 8, 8]) + .map_err(|e| PyErr::new::(e.to_string()))?; + + // Castling: reshape from flat to (N_positions, 4) + let castling_array = PyArray1::from_vec(py, castling_vec); + let castling_reshaped = castling_array + .reshape([total_positions, 4]) + .map_err(|e| PyErr::new::(e.to_string()))?; + + // 1D arrays + let en_passant_array = PyArray1::from_vec(py, en_passant_vec); + let halfmove_clock_array = PyArray1::from_vec(py, halfmove_clock_vec); + let turn_array = PyArray1::from_vec(py, turn_vec); + + let from_squares_array = PyArray1::from_vec(py, from_squares_vec); + let to_squares_array = PyArray1::from_vec(py, to_squares_vec); + let promotions_array = PyArray1::from_vec(py, promotions_vec); + let clocks_array = PyArray1::from_vec(py, clocks_vec); + let evals_array = PyArray1::from_vec(py, evals_vec); + + let move_offsets_array = PyArray1::from_vec(py, move_offsets_vec); + let position_offsets_array = PyArray1::from_vec(py, position_offsets_vec); + + let is_checkmate_array = PyArray1::from_vec(py, is_checkmate_vec); + let is_stalemate_array = PyArray1::from_vec(py, is_stalemate_vec); + + // is_insufficient: reshape to (N_games, 2) + let is_insufficient_array = PyArray1::from_vec(py, is_insufficient_vec); + let is_insufficient_reshaped = is_insufficient_array + .reshape([n_games, 2]) + .map_err(|e| PyErr::new::(e.to_string()))?; + + let legal_move_count_array = PyArray1::from_vec(py, legal_move_count_vec); + let valid_array = PyArray1::from_vec(py, valid_vec); + + Ok(ParsedGames { + boards: boards_reshaped.unbind().into_any(), + castling: castling_reshaped.unbind().into_any(), + en_passant: en_passant_array.unbind().into_any(), + halfmove_clock: halfmove_clock_array.unbind().into_any(), + turn: turn_array.unbind().into_any(), + from_squares: from_squares_array.unbind().into_any(), + to_squares: to_squares_array.unbind().into_any(), + promotions: promotions_array.unbind().into_any(), + clocks: clocks_array.unbind().into_any(), + evals: evals_array.unbind().into_any(), + move_offsets: move_offsets_array.unbind().into_any(), + position_offsets: position_offsets_array.unbind().into_any(), + is_checkmate: is_checkmate_array.unbind().into_any(), + is_stalemate: is_stalemate_array.unbind().into_any(), + is_insufficient: is_insufficient_reshaped.unbind().into_any(), + legal_move_count: legal_move_count_array.unbind().into_any(), + valid: valid_array.unbind().into_any(), + headers: headers_vec, + }) +} + // --- Native Rust versions (no PyResult) --- pub fn parse_single_game_native( pgn: &str, @@ -594,9 +908,11 @@ fn rust_pgn_reader_python_binding(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(parse_game, m)?)?; m.add_function(wrap_pyfunction!(parse_games, m)?)?; m.add_function(wrap_pyfunction!(parse_game_moves_arrow_chunked_array, m)?)?; + m.add_function(wrap_pyfunction!(parse_games_flat, m)?)?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/src/test.py b/src/test.py index a4decf2..592f1e3 100644 --- a/src/test.py +++ b/src/test.py @@ -1,7 +1,15 @@ import unittest +import sys +import os +import numpy as np + import rust_pgn_reader_python_binding import pyarrow as pa +# Add python directory to path for wrapper imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "python")) +from wrapper import add_ergonomic_methods, GameView + class TestPgnExtraction(unittest.TestCase): def run_extractor(self, pgn_string): @@ -680,5 +688,295 @@ def test_parse_game_moves_arrow_chunked_array(self): ) +class TestParsedGamesFlat(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Patch ParsedGames with ergonomic methods once.""" + add_ergonomic_methods(rust_pgn_reader_python_binding.ParsedGames) + + def test_basic_structure(self): + """Test basic flat parsing returns correct structure.""" + pgns = [ + "1. e4 e5 2. Nf3 Nc6 3. Bb5 1-0", + "1. d4 d5 2. c4 e6 0-1", + ] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + # Check game count + self.assertEqual(len(result), 2) + + # Check move offsets + self.assertEqual(len(result.move_offsets), 3) + self.assertEqual(result.move_offsets[0], 0) + self.assertEqual(result.move_offsets[1], 5) # Game 1: 5 half-moves + self.assertEqual(result.move_offsets[2], 9) # Game 2: 4 half-moves + + # Check shapes + total_moves = 9 + total_positions = 9 + 2 # moves + initial positions + + self.assertEqual(result.boards.shape, (total_positions, 8, 8)) + self.assertEqual(result.castling.shape, (total_positions, 4)) + self.assertEqual(result.en_passant.shape, (total_positions,)) + self.assertEqual(result.from_squares.shape, (total_moves,)) + self.assertEqual(result.valid.shape, (2,)) + + def test_initial_board_encoding(self): + """Test initial board state encoding.""" + pgns = ["1. e4 1-0"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + initial = result.boards[0] # First position + + # Encoding: 0=empty, 1=P, 2=N, 3=B, 4=R, 5=Q, 6=K, +6 for black + # Square indexing: a1=0, b1=1, ..., h1=7, a2=8, ... + + # a1 (index 0) = white rook = 4 + self.assertEqual(initial.flat[0], 4) + # b1 (index 1) = white knight = 2 + self.assertEqual(initial.flat[1], 2) + # e1 (index 4) = white king = 6 + self.assertEqual(initial.flat[4], 6) + # e2 (index 12) = white pawn = 1 + self.assertEqual(initial.flat[12], 1) + # e4 (index 28) = empty = 0 + self.assertEqual(initial.flat[28], 0) + # e7 (index 52) = black pawn = 7 + self.assertEqual(initial.flat[52], 7) + # e8 (index 60) = black king = 12 + self.assertEqual(initial.flat[60], 12) + + def test_board_after_move(self): + """Test board state updates correctly after move.""" + pgns = ["1. e4 1-0"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + # Position 0: initial, Position 1: after e4 + after_e4 = result.boards[1] + + # e2 (index 12) should be empty + self.assertEqual(after_e4.flat[12], 0) + # e4 (index 28) should have white pawn + self.assertEqual(after_e4.flat[28], 1) + + def test_en_passant_tracking(self): + """Test en passant square tracking.""" + pgns = ["1. e4 1-0"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + # Initial: no en passant + self.assertEqual(result.en_passant[0], -1) + # After e4: en passant on e-file (file index 4) + self.assertEqual(result.en_passant[1], 4) + + def test_castling_rights(self): + """Test castling rights tracking.""" + # White moves rook, losing kingside castling + pgns = ["1. e4 e5 2. Nf3 Nc6 3. Rg1 1-0"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + # Initial: all castling [K, Q, k, q] = [True, True, True, True] + self.assertTrue(all(result.castling[0])) + + # After Rg1 (position 5): white kingside lost + # Castling order: [K, Q, k, q] + self.assertFalse(result.castling[5, 0]) # White K + self.assertTrue(result.castling[5, 1]) # White Q + self.assertTrue(result.castling[5, 2]) # Black k + self.assertTrue(result.castling[5, 3]) # Black q + + def test_turn_tracking(self): + """Test side-to-move tracking.""" + pgns = ["1. e4 e5 2. Nf3 1-0"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + # Initial: white to move + self.assertTrue(result.turn[0]) + # After e4: black to move + self.assertFalse(result.turn[1]) + # After e5: white to move + self.assertTrue(result.turn[2]) + + def test_game_view_access(self): + """Test GameView provides correct slices.""" + pgns = ["1. e4 e5 2. Nf3 1-0", "1. d4 d5 0-1"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + game0 = result[0] + self.assertEqual(len(game0), 3) + self.assertEqual(game0.num_positions, 4) + self.assertEqual(game0.boards.shape, (4, 8, 8)) + self.assertTrue(game0.is_valid) + + game1 = result[1] + self.assertEqual(len(game1), 2) + self.assertEqual(game1.num_positions, 3) + + def test_game_view_move_uci(self): + """Test GameView UCI move conversion.""" + pgns = ["1. e4 e5 2. Nf3 1-0"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + game = result[0] + self.assertEqual(game.move_uci(0), "e2e4") + self.assertEqual(game.move_uci(1), "e7e5") + self.assertEqual(game.move_uci(2), "g1f3") + + self.assertEqual(game.moves_uci(), ["e2e4", "e7e5", "g1f3"]) + + def test_iteration(self): + """Test iteration over games.""" + pgns = ["1. e4 1-0", "1. d4 0-1", "1. c4 1/2-1/2"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + games = list(result) + self.assertEqual(len(games), 3) + self.assertIsInstance(games[0], GameView) + + def test_slicing(self): + """Test slicing returns BatchSlice.""" + pgns = ["1. e4 1-0", "1. d4 0-1", "1. c4 1/2-1/2"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + sliced = result[1:3] + self.assertEqual(len(sliced), 2) + games = list(sliced) + self.assertEqual(len(games[0]), 1) # d4 game + + def test_position_to_game_mapping(self): + """Test position to game index mapping.""" + pgns = ["1. e4 e5 1-0", "1. d4 0-1"] # 2 moves (3 pos), 1 move (2 pos) + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + # Positions: 0,1,2 (game 0), 3,4 (game 1) + pos_indices = np.array([0, 1, 2, 3, 4]) + game_indices = result.position_to_game(pos_indices) + + np.testing.assert_array_equal(game_indices, [0, 0, 0, 1, 1]) + + def test_move_to_game_mapping(self): + """Test move to game index mapping.""" + pgns = ["1. e4 e5 1-0", "1. d4 0-1"] # 2 moves, 1 move + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + move_indices = np.array([0, 1, 2]) + game_indices = result.move_to_game(move_indices) + + np.testing.assert_array_equal(game_indices, [0, 0, 1]) + + def test_clocks_and_evals(self): + """Test clock and eval parsing.""" + pgn = """1. e4 { [%eval 0.17] [%clk 0:00:30] } 1... e5 { [%eval 0.19] [%clk 0:00:29] } 1-0""" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + self.assertAlmostEqual(result.evals[0], 0.17, places=2) + self.assertAlmostEqual(result.evals[1], 0.19, places=2) + self.assertAlmostEqual(result.clocks[0], 30.0, places=1) + self.assertAlmostEqual(result.clocks[1], 29.0, places=1) + + def test_missing_clocks_evals_are_nan(self): + """Test missing clocks/evals are NaN.""" + pgns = ["1. e4 e5 1-0"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + self.assertTrue(np.isnan(result.clocks[0])) + self.assertTrue(np.isnan(result.evals[0])) + + def test_headers_preserved(self): + """Test headers are preserved as dicts.""" + pgn = """[White "Player1"] +[Black "Player2"] +[WhiteElo "1500"] + +1. e4 1-0""" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + game = result[0] + self.assertEqual(game.headers["White"], "Player1") + self.assertEqual(game.headers["Black"], "Player2") + self.assertEqual(game.headers["WhiteElo"], "1500") + + def test_invalid_game_flagged(self): + """Test invalid games are flagged but don't break structure.""" + pgns = [ + "1. e4 e5 1-0", # Valid + "1. e4 Qxd7 1-0", # Invalid move + "1. d4 d5 0-1", # Valid + ] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + self.assertEqual(len(result), 3) + self.assertTrue(result.valid[0]) + self.assertFalse(result.valid[1]) + self.assertTrue(result.valid[2]) + + def test_checkmate_detection(self): + """Test checkmate is detected.""" + # Scholar's mate + pgn = "1. e4 e5 2. Qh5 Nc6 3. Bc4 Nf6 4. Qxf7# 1-0" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + game = result[0] + self.assertTrue(game.is_checkmate) + self.assertFalse(game.is_stalemate) + self.assertEqual(game.legal_move_count, 0) + + def test_promotion(self): + """Test promotion encoding.""" + # Simplified position reaching promotion + pgn = """[FEN "8/P7/8/8/8/8/8/4K2k w - - 0 1"] + +1. a8=Q 1-0""" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + # Promotion to queen = 5 + self.assertEqual(result.promotions[0], 5) + + game = result[0] + self.assertEqual(game.move_uci(0), "a7a8q") + + def test_num_properties(self): + """Test num_games, num_moves, num_positions properties.""" + pgns = ["1. e4 e5 1-0", "1. d4 0-1"] # 2 + 1 = 3 moves, 3 + 2 = 5 positions + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + self.assertEqual(result.num_games, 2) + self.assertEqual(result.num_moves, 3) + self.assertEqual(result.num_positions, 5) + + def test_negative_indexing(self): + """Test negative index access.""" + pgns = ["1. e4 1-0", "1. d4 0-1", "1. c4 1/2-1/2"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + # -1 should be the last game + last_game = result[-1] + self.assertEqual(len(last_game), 1) + + # Check that from_squares for c4 is c2 + # c2 = file 2 (c) + rank 1 (2nd rank) * 8 = 2 + 8 = 10 + self.assertEqual(last_game.from_squares[0], 10) # c2 + + if __name__ == "__main__": unittest.main() From de6b9afd6cecbf5b9052d582531bbf73d3bc0aa4 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Wed, 4 Feb 2026 21:08:55 -0500 Subject: [PATCH 02/31] Move the wrapper from Python into Rust --- rust_pgn_reader_python_binding.pyi | 339 ++++++++++++++++++- src/example_parse_games_flat.py | 12 +- src/lib.rs | 506 ++++++++++++++++++++++++++++- 3 files changed, 844 insertions(+), 13 deletions(-) diff --git a/rust_pgn_reader_python_binding.pyi b/rust_pgn_reader_python_binding.pyi index d4745eb..86c4166 100644 --- a/rust_pgn_reader_python_binding.pyi +++ b/rust_pgn_reader_python_binding.pyi @@ -1,5 +1,7 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Dict, Iterator, Union, overload import pyarrow +import numpy as np +from numpy.typing import NDArray class PyUciMove: from_square: int @@ -43,6 +45,324 @@ class MoveExtractor: @property def legal_moves(self) -> List[List[PyUciMove]]: ... +class PyGameView: + """Zero-copy view into a single game's data within a ParsedGames result. + + Board indexing note: Boards use square indexing (a1=0, h8=63). + To convert to rank/file: + rank = square // 8 + file = square % 8 + """ + + def __len__(self) -> int: + """Number of moves in this game.""" + ... + + @property + def num_positions(self) -> int: + """Number of positions recorded for this game.""" + ... + + # === Board state views === + + @property + def boards(self) -> NDArray[np.uint8]: + """Board positions, shape (num_positions, 8, 8).""" + ... + + @property + def initial_board(self) -> NDArray[np.uint8]: + """Initial board position, shape (8, 8).""" + ... + + @property + def final_board(self) -> NDArray[np.uint8]: + """Final board position, shape (8, 8).""" + ... + + @property + def castling(self) -> NDArray[np.bool_]: + """Castling rights [K,Q,k,q], shape (num_positions, 4).""" + ... + + @property + def en_passant(self) -> NDArray[np.int8]: + """En passant file (-1 if none), shape (num_positions,).""" + ... + + @property + def halfmove_clock(self) -> NDArray[np.uint8]: + """Halfmove clock, shape (num_positions,).""" + ... + + @property + def turn(self) -> NDArray[np.bool_]: + """Side to move (True=white), shape (num_positions,).""" + ... + + # === Move views === + + @property + def from_squares(self) -> NDArray[np.uint8]: + """From squares, shape (num_moves,).""" + ... + + @property + def to_squares(self) -> NDArray[np.uint8]: + """To squares, shape (num_moves,).""" + ... + + @property + def promotions(self) -> NDArray[np.int8]: + """Promotions (-1=none, 2=N, 3=B, 4=R, 5=Q), shape (num_moves,).""" + ... + + @property + def clocks(self) -> NDArray[np.float32]: + """Clock times in seconds (NaN if missing), shape (num_moves,).""" + ... + + @property + def evals(self) -> NDArray[np.float32]: + """Engine evals (NaN if missing), shape (num_moves,).""" + ... + + # === Per-game metadata === + + @property + def headers(self) -> Dict[str, str]: + """Raw PGN headers as dict.""" + ... + + @property + def is_checkmate(self) -> bool: + """Final position is checkmate.""" + ... + + @property + def is_stalemate(self) -> bool: + """Final position is stalemate.""" + ... + + @property + def is_insufficient(self) -> Tuple[bool, bool]: + """Insufficient material (white, black).""" + ... + + @property + def legal_move_count(self) -> int: + """Legal move count in final position.""" + ... + + @property + def is_valid(self) -> bool: + """Whether game parsed successfully.""" + ... + + # === Convenience methods === + + def move_uci(self, move_idx: int) -> str: + """Get UCI string for move at index.""" + ... + + def moves_uci(self) -> List[str]: + """Get all moves as UCI strings.""" + ... + + def __repr__(self) -> str: ... + +class ParsedGamesIter: + """Iterator over games in a ParsedGames result.""" + + def __iter__(self) -> "ParsedGamesIter": ... + def __next__(self) -> PyGameView: ... + +class ParsedGames: + """Flat array container for parsed chess games, optimized for ML training. + + Indexing: + - N_games: Number of games + - N_moves: Total moves across all games + - N_positions: Total board positions recorded + + Board layout: + Boards use square indexing: a1=0, b1=1, ..., h8=63 + Piece encoding: 0=empty, 1-6=white PNBRQK, 7-12=black pnbrqk + """ + + # === Board state arrays (N_positions) === + + @property + def boards(self) -> NDArray[np.uint8]: + """Board positions, shape (N_positions, 8, 8), dtype uint8.""" + ... + + @property + def castling(self) -> NDArray[np.bool_]: + """Castling rights [K,Q,k,q], shape (N_positions, 4), dtype bool.""" + ... + + @property + def en_passant(self) -> NDArray[np.int8]: + """En passant file (-1 if none), shape (N_positions,), dtype int8.""" + ... + + @property + def halfmove_clock(self) -> NDArray[np.uint8]: + """Halfmove clock, shape (N_positions,), dtype uint8.""" + ... + + @property + def turn(self) -> NDArray[np.bool_]: + """Side to move (True=white), shape (N_positions,), dtype bool.""" + ... + + # === Move arrays (N_moves) === + + @property + def from_squares(self) -> NDArray[np.uint8]: + """From squares, shape (N_moves,), dtype uint8.""" + ... + + @property + def to_squares(self) -> NDArray[np.uint8]: + """To squares, shape (N_moves,), dtype uint8.""" + ... + + @property + def promotions(self) -> NDArray[np.int8]: + """Promotions (-1=none, 2=N, 3=B, 4=R, 5=Q), shape (N_moves,), dtype int8.""" + ... + + @property + def clocks(self) -> NDArray[np.float32]: + """Clock times in seconds (NaN if missing), shape (N_moves,), dtype float32.""" + ... + + @property + def evals(self) -> NDArray[np.float32]: + """Engine evals (NaN if missing), shape (N_moves,), dtype float32.""" + ... + + # === Offsets === + + @property + def move_offsets(self) -> NDArray[np.uint32]: + """Move offsets for CSR-style indexing, shape (N_games + 1,), dtype uint32. + + Game i's moves: move_offsets[i]..move_offsets[i+1] + """ + ... + + @property + def position_offsets(self) -> NDArray[np.uint32]: + """Position offsets for CSR-style indexing, shape (N_games + 1,), dtype uint32. + + Game i's positions: position_offsets[i]..position_offsets[i+1] + """ + ... + + # === Final position status (N_games) === + + @property + def is_checkmate(self) -> NDArray[np.bool_]: + """Final position is checkmate, shape (N_games,), dtype bool.""" + ... + + @property + def is_stalemate(self) -> NDArray[np.bool_]: + """Final position is stalemate, shape (N_games,), dtype bool.""" + ... + + @property + def is_insufficient(self) -> NDArray[np.bool_]: + """Insufficient material (white, black), shape (N_games, 2), dtype bool.""" + ... + + @property + def legal_move_count(self) -> NDArray[np.uint16]: + """Legal move count in final position, shape (N_games,), dtype uint16.""" + ... + + # === Parse status (N_games) === + + @property + def valid(self) -> NDArray[np.bool_]: + """Whether game parsed successfully, shape (N_games,), dtype bool.""" + ... + + # === Raw headers (N_games) === + + @property + def headers(self) -> List[Dict[str, str]]: + """Raw PGN headers as list of dicts.""" + ... + + # === Computed properties === + + @property + def num_games(self) -> int: + """Number of games in the result.""" + ... + + @property + def num_moves(self) -> int: + """Total number of moves across all games.""" + ... + + @property + def num_positions(self) -> int: + """Total number of board positions recorded.""" + ... + + # === Sequence protocol === + + def __len__(self) -> int: + """Number of games in the result.""" + ... + + @overload + def __getitem__(self, idx: int) -> PyGameView: ... + @overload + def __getitem__(self, idx: slice) -> List[PyGameView]: ... + def __getitem__( + self, idx: Union[int, slice] + ) -> Union[PyGameView, List[PyGameView]]: + """Access game(s) by index or slice.""" + ... + + def __iter__(self) -> ParsedGamesIter: + """Iterate over all games.""" + ... + + # === Mapping utilities === + + def position_to_game( + self, position_indices: NDArray[np.int64] + ) -> NDArray[np.int64]: + """Map position indices to game indices. + + Useful after shuffling/sampling positions to look up game metadata. + + Args: + position_indices: Array of indices into boards array + + Returns: + Array of game indices (same shape as input) + """ + ... + + def move_to_game(self, move_indices: NDArray[np.int64]) -> NDArray[np.int64]: + """Map move indices to game indices. + + Args: + move_indices: Array of indices into from_squares, to_squares, etc. + + Returns: + Array of game indices (same shape as input) + """ + ... + def parse_game(pgn: str, store_legal_moves: bool = False) -> MoveExtractor: ... def parse_games( pgns: List[str], num_threads: Optional[int] = None, store_legal_moves: bool = False @@ -52,3 +372,20 @@ def parse_game_moves_arrow_chunked_array( num_threads: Optional[int] = None, store_legal_moves: bool = False, ) -> List[MoveExtractor]: ... +def parse_games_flat( + pgn_chunked_array: pyarrow.ChunkedArray, + num_threads: Optional[int] = None, +) -> ParsedGames: + """Parse chess games from a PyArrow ChunkedArray into flat NumPy arrays. + + This API is optimized for ML training pipelines, returning flat NumPy arrays + that can be efficiently batched and processed. + + Args: + pgn_chunked_array: PyArrow ChunkedArray containing PGN strings + num_threads: Number of threads for parallel parsing (default: all CPUs) + + Returns: + ParsedGames object containing flat arrays and iteration support + """ + ... diff --git a/src/example_parse_games_flat.py b/src/example_parse_games_flat.py index 58fadd2..61c5b36 100644 --- a/src/example_parse_games_flat.py +++ b/src/example_parse_games_flat.py @@ -5,21 +5,11 @@ in machine learning pipelines. """ -import sys -import os - import numpy as np import pyarrow as pa import rust_pgn_reader_python_binding as pgn -# Add python directory to path for wrapper imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "python")) -from wrapper import add_ergonomic_methods, GameView - -# Patch ParsedGames with ergonomic methods (do this once at startup) -add_ergonomic_methods(pgn.ParsedGames) - # Sample PGN games with annotations sample_pgns = [ @@ -139,7 +129,7 @@ def main(): # === Position-to-Game Mapping === print(f"\n--- Position-to-Game Mapping ---") # Useful for shuffling positions while keeping track of game metadata - sample_positions = np.array([0, 5, 10, 15, 20]) + sample_positions = np.array([0, 5, 10, 15, 20], dtype=np.int64) game_indices = result.position_to_game(sample_positions) print(f"Position indices: {sample_positions}") print(f"Game indices: {game_indices}") diff --git a/src/lib.rs b/src/lib.rs index 3eb20f3..c2e4683 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,9 @@ use crate::comment_parsing::{parse_comments, CommentContent, ParsedTag}; use arrow_array::{Array, LargeStringArray, StringArray}; -use numpy::{PyArray1, PyArrayMethods}; +use numpy::{PyArray1, PyArray2, PyArrayMethods, PyUntypedArrayMethods}; use pgn_reader::{KnownOutcome, Outcome, RawComment, RawTag, Reader, SanPlus, Skip, Visitor}; use pyo3::prelude::*; +use pyo3::types::PySlice; use pyo3_arrow::PyChunkedArray; use rayon::prelude::*; use rayon::ThreadPoolBuilder; @@ -607,6 +608,507 @@ pub struct ParsedGames { headers: Vec>, } +#[pymethods] +impl ParsedGames { + /// Number of games in the result. + #[getter] + fn num_games(&self) -> usize { + self.headers.len() + } + + /// Total number of moves across all games. + #[getter] + fn num_moves(&self, py: Python<'_>) -> PyResult { + let offsets = self.move_offsets.bind(py); + let offsets: &Bound<'_, PyArray1> = offsets.downcast()?; + let len = offsets.len(); + if len == 0 { + return Ok(0); + } + // SAFETY: We just checked len > 0 + let last = unsafe { *offsets.uget([len - 1]) }; + Ok(last as usize) + } + + /// Total number of board positions recorded. + #[getter] + fn num_positions(&self, py: Python<'_>) -> PyResult { + let offsets = self.position_offsets.bind(py); + let offsets: &Bound<'_, PyArray1> = offsets.downcast()?; + let len = offsets.len(); + if len == 0 { + return Ok(0); + } + // SAFETY: We just checked len > 0 + let last = unsafe { *offsets.uget([len - 1]) }; + Ok(last as usize) + } + + fn __len__(&self) -> usize { + self.headers.len() + } + + fn __getitem__(slf: Py, py: Python<'_>, idx: &Bound<'_, PyAny>) -> PyResult> { + let n_games = slf.borrow(py).headers.len(); + + // Handle integer index + if let Ok(mut i) = idx.extract::() { + // Handle negative indexing + if i < 0 { + i += n_games as isize; + } + if i < 0 || i >= n_games as isize { + return Err(pyo3::exceptions::PyIndexError::new_err(format!( + "Game index {} out of range [0, {})", + i, n_games + ))); + } + let game_view = PyGameView::new(py, slf.clone_ref(py), i as usize)?; + return Ok(Py::new(py, game_view)?.into_any()); + } + + // Handle slice + if let Ok(slice) = idx.downcast::() { + let indices = slice.indices(n_games as isize)?; + let start = indices.start as usize; + let stop = indices.stop as usize; + let step = indices.step as usize; + + // For simplicity, we return a list of PyGameView objects + let mut views: Vec> = Vec::new(); + let mut i = start; + while i < stop { + let game_view = PyGameView::new(py, slf.clone_ref(py), i)?; + views.push(Py::new(py, game_view)?); + i += step; + } + return Ok(pyo3::types::PyList::new(py, views)?.into_any().unbind()); + } + + Err(pyo3::exceptions::PyTypeError::new_err(format!( + "Invalid index type: expected int or slice, got {}", + idx.get_type().name()? + ))) + } + + fn __iter__(slf: Py, py: Python<'_>) -> PyResult { + let n_games = slf.borrow(py).headers.len(); + Ok(ParsedGamesIter { + data: slf, + index: 0, + length: n_games, + }) + } + + /// Map position indices to game indices. + /// + /// Useful after shuffling/sampling positions to look up game metadata. + /// + /// Args: + /// position_indices: Array of indices into boards array + /// + /// Returns: + /// Array of game indices (same shape as input) + fn position_to_game<'py>( + &self, + py: Python<'py>, + position_indices: &Bound<'py, PyArray1>, + ) -> PyResult>> { + let offsets = self.position_offsets.bind(py); + let offsets: &Bound<'_, PyArray1> = offsets.downcast()?; + + // Get numpy module for searchsorted + let numpy = py.import("numpy")?; + + // offsets[:-1] - all but last element + let len = offsets.len(); + let slice_obj = PySlice::new(py, 0, (len - 1) as isize, 1); + let offsets_slice = offsets.call_method1("__getitem__", (slice_obj,))?; + + // searchsorted(offsets[:-1], position_indices, side='right') - 1 + let result = numpy.call_method1( + "searchsorted", + ( + offsets_slice, + position_indices, + pyo3::types::PyString::new(py, "right"), + ), + )?; + + // Subtract 1 + let one = 1i64.into_pyobject(py)?; + let result = result.call_method1("__sub__", (one,))?; + + Ok(result.extract()?) + } + + /// Map move indices to game indices. + /// + /// Args: + /// move_indices: Array of indices into from_squares, to_squares, etc. + /// + /// Returns: + /// Array of game indices (same shape as input) + fn move_to_game<'py>( + &self, + py: Python<'py>, + move_indices: &Bound<'py, PyArray1>, + ) -> PyResult>> { + let offsets = self.move_offsets.bind(py); + let offsets: &Bound<'_, PyArray1> = offsets.downcast()?; + + let numpy = py.import("numpy")?; + let len = offsets.len(); + let slice_obj = PySlice::new(py, 0, (len - 1) as isize, 1); + let offsets_slice = offsets.call_method1("__getitem__", (slice_obj,))?; + + let result = numpy.call_method1( + "searchsorted", + ( + offsets_slice, + move_indices, + pyo3::types::PyString::new(py, "right"), + ), + )?; + + let one = 1i64.into_pyobject(py)?; + let result = result.call_method1("__sub__", (one,))?; + + Ok(result.extract()?) + } +} + +/// Iterator over games in a ParsedGames result. +#[pyclass] +pub struct ParsedGamesIter { + data: Py, + index: usize, + length: usize, +} + +#[pymethods] +impl ParsedGamesIter { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(mut slf: PyRefMut<'_, Self>, py: Python<'_>) -> PyResult> { + if slf.index >= slf.length { + return Ok(None); + } + let game_view = PyGameView::new(py, slf.data.clone_ref(py), slf.index)?; + slf.index += 1; + Ok(Some(game_view)) + } +} + +/// Zero-copy view into a single game's data within a ParsedGames result. +/// +/// Board indexing note: Boards use square indexing (a1=0, h8=63). +/// To convert to rank/file: +/// rank = square // 8 +/// file = square % 8 +#[pyclass] +pub struct PyGameView { + data: Py, + idx: usize, + move_start: usize, + move_end: usize, + pos_start: usize, + pos_end: usize, +} + +impl PyGameView { + fn new(py: Python<'_>, data: Py, idx: usize) -> PyResult { + let borrowed = data.borrow(py); + + let move_offsets = borrowed.move_offsets.bind(py); + let move_offsets: &Bound<'_, PyArray1> = move_offsets.downcast()?; + let pos_offsets = borrowed.position_offsets.bind(py); + let pos_offsets: &Bound<'_, PyArray1> = pos_offsets.downcast()?; + + // SAFETY: idx is validated by caller, and idx+1 is within bounds due to offset array structure + let move_start = unsafe { *move_offsets.uget([idx]) } as usize; + let move_end = unsafe { *move_offsets.uget([idx + 1]) } as usize; + let pos_start = unsafe { *pos_offsets.uget([idx]) } as usize; + let pos_end = unsafe { *pos_offsets.uget([idx + 1]) } as usize; + + drop(borrowed); + + Ok(Self { + data, + idx, + move_start, + move_end, + pos_start, + pos_end, + }) + } +} + +#[pymethods] +impl PyGameView { + /// Number of moves in this game. + fn __len__(&self) -> usize { + self.move_end - self.move_start + } + + /// Number of positions recorded for this game. + #[getter] + fn num_positions(&self) -> usize { + self.pos_end - self.pos_start + } + + // === Board state views === + + /// Board positions, shape (num_positions, 8, 8). + #[getter] + fn boards<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let boards = borrowed.boards.bind(py); + let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); + let slice = boards.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Initial board position, shape (8, 8). + #[getter] + fn initial_board<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let boards = borrowed.boards.bind(py); + let slice = boards.call_method1("__getitem__", (self.pos_start,))?; + Ok(slice.unbind()) + } + + /// Final board position, shape (8, 8). + #[getter] + fn final_board<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let boards = borrowed.boards.bind(py); + let slice = boards.call_method1("__getitem__", (self.pos_end - 1,))?; + Ok(slice.unbind()) + } + + /// Castling rights [K,Q,k,q], shape (num_positions, 4). + #[getter] + fn castling<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.castling.bind(py); + let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// En passant file (-1 if none), shape (num_positions,). + #[getter] + fn en_passant<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.en_passant.bind(py); + let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Halfmove clock, shape (num_positions,). + #[getter] + fn halfmove_clock<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.halfmove_clock.bind(py); + let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Side to move (True=white), shape (num_positions,). + #[getter] + fn turn<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.turn.bind(py); + let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + // === Move views === + + /// From squares, shape (num_moves,). + #[getter] + fn from_squares<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.from_squares.bind(py); + let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// To squares, shape (num_moves,). + #[getter] + fn to_squares<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.to_squares.bind(py); + let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Promotions (-1=none, 2=N, 3=B, 4=R, 5=Q), shape (num_moves,). + #[getter] + fn promotions<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.promotions.bind(py); + let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Clock times in seconds (NaN if missing), shape (num_moves,). + #[getter] + fn clocks<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.clocks.bind(py); + let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Engine evals (NaN if missing), shape (num_moves,). + #[getter] + fn evals<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.evals.bind(py); + let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + // === Per-game metadata === + + /// Raw PGN headers as dict. + #[getter] + fn headers(&self, py: Python<'_>) -> PyResult> { + let borrowed = self.data.borrow(py); + Ok(borrowed.headers[self.idx].clone()) + } + + /// Final position is checkmate. + #[getter] + fn is_checkmate(&self, py: Python<'_>) -> PyResult { + let borrowed = self.data.borrow(py); + let arr = borrowed.is_checkmate.bind(py); + let arr: &Bound<'_, PyArray1> = arr.downcast()?; + // SAFETY: idx is validated during construction + Ok(unsafe { *arr.uget([self.idx]) }) + } + + /// Final position is stalemate. + #[getter] + fn is_stalemate(&self, py: Python<'_>) -> PyResult { + let borrowed = self.data.borrow(py); + let arr = borrowed.is_stalemate.bind(py); + let arr: &Bound<'_, PyArray1> = arr.downcast()?; + Ok(unsafe { *arr.uget([self.idx]) }) + } + + /// Insufficient material (white, black). + #[getter] + fn is_insufficient(&self, py: Python<'_>) -> PyResult<(bool, bool)> { + let borrowed = self.data.borrow(py); + let arr = borrowed.is_insufficient.bind(py); + let arr: &Bound<'_, PyArray2> = arr.downcast()?; + // SAFETY: idx is validated during construction + let white = unsafe { *arr.uget([self.idx, 0]) }; + let black = unsafe { *arr.uget([self.idx, 1]) }; + Ok((white, black)) + } + + /// Legal move count in final position. + #[getter] + fn legal_move_count(&self, py: Python<'_>) -> PyResult { + let borrowed = self.data.borrow(py); + let arr = borrowed.legal_move_count.bind(py); + let arr: &Bound<'_, PyArray1> = arr.downcast()?; + Ok(unsafe { *arr.uget([self.idx]) }) + } + + /// Whether game parsed successfully. + #[getter] + fn is_valid(&self, py: Python<'_>) -> PyResult { + let borrowed = self.data.borrow(py); + let arr = borrowed.valid.bind(py); + let arr: &Bound<'_, PyArray1> = arr.downcast()?; + Ok(unsafe { *arr.uget([self.idx]) }) + } + + // === Convenience methods === + + /// Get UCI string for move at index. + fn move_uci(&self, py: Python<'_>, move_idx: usize) -> PyResult { + if move_idx >= self.move_end - self.move_start { + return Err(pyo3::exceptions::PyIndexError::new_err(format!( + "Move index {} out of range [0, {})", + move_idx, + self.move_end - self.move_start + ))); + } + + let borrowed = self.data.borrow(py); + let from_arr = borrowed.from_squares.bind(py); + let from_arr: &Bound<'_, PyArray1> = from_arr.downcast()?; + let to_arr = borrowed.to_squares.bind(py); + let to_arr: &Bound<'_, PyArray1> = to_arr.downcast()?; + let promo_arr = borrowed.promotions.bind(py); + let promo_arr: &Bound<'_, PyArray1> = promo_arr.downcast()?; + + let abs_idx = self.move_start + move_idx; + // SAFETY: we validated move_idx above and abs_idx is within bounds + let from_sq = unsafe { *from_arr.uget([abs_idx]) }; + let to_sq = unsafe { *to_arr.uget([abs_idx]) }; + let promo = unsafe { *promo_arr.uget([abs_idx]) }; + + let files = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']; + let ranks = ['1', '2', '3', '4', '5', '6', '7', '8']; + + let mut uci = format!( + "{}{}{}{}", + files[(from_sq % 8) as usize], + ranks[(from_sq / 8) as usize], + files[(to_sq % 8) as usize], + ranks[(to_sq / 8) as usize] + ); + + if promo >= 0 { + let promo_chars = ['_', '_', 'n', 'b', 'r', 'q']; // 2=N, 3=B, 4=R, 5=Q + if (promo as usize) < promo_chars.len() { + uci.push(promo_chars[promo as usize]); + } + } + + Ok(uci) + } + + /// Get all moves as UCI strings. + fn moves_uci(&self, py: Python<'_>) -> PyResult> { + let n_moves = self.move_end - self.move_start; + let mut result = Vec::with_capacity(n_moves); + for i in 0..n_moves { + result.push(self.move_uci(py, i)?); + } + Ok(result) + } + + fn __repr__(&self, py: Python<'_>) -> PyResult { + let headers = self.headers(py)?; + let white = headers.get("White").map(|s| s.as_str()).unwrap_or("?"); + let black = headers.get("Black").map(|s| s.as_str()).unwrap_or("?"); + let n_moves = self.move_end - self.move_start; + let is_valid = self.is_valid(py)?; + Ok(format!( + "", + white, black, n_moves, is_valid + )) + } +} + #[pyfunction] #[pyo3(signature = (pgn_chunked_array, num_threads=None))] fn parse_games_flat( @@ -913,6 +1415,8 @@ fn rust_pgn_reader_python_binding(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } From 9ad6e925bcad6d840cbfdd05fc82dcf19f8a4d31 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Wed, 4 Feb 2026 21:15:32 -0500 Subject: [PATCH 03/31] Fix casts, remove wrapper from bench --- src/bench_parse_games_flat.py | 8 -------- src/lib.rs | 38 +++++++++++++++++------------------ 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/src/bench_parse_games_flat.py b/src/bench_parse_games_flat.py index 10c0f9d..b5b848a 100644 --- a/src/bench_parse_games_flat.py +++ b/src/bench_parse_games_flat.py @@ -13,7 +13,6 @@ """ import sys -import os import time import argparse from typing import Optional @@ -23,13 +22,6 @@ import rust_pgn_reader_python_binding as pgn -# Add python directory to path for wrapper imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "python")) -from wrapper import add_ergonomic_methods - -# Patch ParsedGames with ergonomic methods -add_ergonomic_methods(pgn.ParsedGames) - def generate_synthetic_pgns(num_games: int, moves_per_game: int = 40) -> list[str]: """Generate synthetic PGN games for benchmarking.""" diff --git a/src/lib.rs b/src/lib.rs index c2e4683..617be60 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,16 @@ -use crate::comment_parsing::{parse_comments, CommentContent, ParsedTag}; +use crate::comment_parsing::{CommentContent, ParsedTag, parse_comments}; use arrow_array::{Array, LargeStringArray, StringArray}; use numpy::{PyArray1, PyArray2, PyArrayMethods, PyUntypedArrayMethods}; use pgn_reader::{KnownOutcome, Outcome, RawComment, RawTag, Reader, SanPlus, Skip, Visitor}; use pyo3::prelude::*; use pyo3::types::PySlice; use pyo3_arrow::PyChunkedArray; -use rayon::prelude::*; use rayon::ThreadPoolBuilder; -use shakmaty::fen::Fen; +use rayon::prelude::*; use shakmaty::CastlingMode; use shakmaty::Color; -use shakmaty::{uci::UciMove, Chess, Position, Role, Square}; +use shakmaty::fen::Fen; +use shakmaty::{Chess, Position, Role, Square, uci::UciMove}; use std::collections::HashMap; use std::io::Cursor; use std::ops::ControlFlow; @@ -620,7 +620,7 @@ impl ParsedGames { #[getter] fn num_moves(&self, py: Python<'_>) -> PyResult { let offsets = self.move_offsets.bind(py); - let offsets: &Bound<'_, PyArray1> = offsets.downcast()?; + let offsets: &Bound<'_, PyArray1> = offsets.cast()?; let len = offsets.len(); if len == 0 { return Ok(0); @@ -634,7 +634,7 @@ impl ParsedGames { #[getter] fn num_positions(&self, py: Python<'_>) -> PyResult { let offsets = self.position_offsets.bind(py); - let offsets: &Bound<'_, PyArray1> = offsets.downcast()?; + let offsets: &Bound<'_, PyArray1> = offsets.cast()?; let len = offsets.len(); if len == 0 { return Ok(0); @@ -668,7 +668,7 @@ impl ParsedGames { } // Handle slice - if let Ok(slice) = idx.downcast::() { + if let Ok(slice) = idx.cast::() { let indices = slice.indices(n_games as isize)?; let start = indices.start as usize; let stop = indices.stop as usize; @@ -715,7 +715,7 @@ impl ParsedGames { position_indices: &Bound<'py, PyArray1>, ) -> PyResult>> { let offsets = self.position_offsets.bind(py); - let offsets: &Bound<'_, PyArray1> = offsets.downcast()?; + let offsets: &Bound<'_, PyArray1> = offsets.cast()?; // Get numpy module for searchsorted let numpy = py.import("numpy")?; @@ -755,7 +755,7 @@ impl ParsedGames { move_indices: &Bound<'py, PyArray1>, ) -> PyResult>> { let offsets = self.move_offsets.bind(py); - let offsets: &Bound<'_, PyArray1> = offsets.downcast()?; + let offsets: &Bound<'_, PyArray1> = offsets.cast()?; let numpy = py.import("numpy")?; let len = offsets.len(); @@ -823,9 +823,9 @@ impl PyGameView { let borrowed = data.borrow(py); let move_offsets = borrowed.move_offsets.bind(py); - let move_offsets: &Bound<'_, PyArray1> = move_offsets.downcast()?; + let move_offsets: &Bound<'_, PyArray1> = move_offsets.cast()?; let pos_offsets = borrowed.position_offsets.bind(py); - let pos_offsets: &Bound<'_, PyArray1> = pos_offsets.downcast()?; + let pos_offsets: &Bound<'_, PyArray1> = pos_offsets.cast()?; // SAFETY: idx is validated by caller, and idx+1 is within bounds due to offset array structure let move_start = unsafe { *move_offsets.uget([idx]) } as usize; @@ -995,7 +995,7 @@ impl PyGameView { fn is_checkmate(&self, py: Python<'_>) -> PyResult { let borrowed = self.data.borrow(py); let arr = borrowed.is_checkmate.bind(py); - let arr: &Bound<'_, PyArray1> = arr.downcast()?; + let arr: &Bound<'_, PyArray1> = arr.cast()?; // SAFETY: idx is validated during construction Ok(unsafe { *arr.uget([self.idx]) }) } @@ -1005,7 +1005,7 @@ impl PyGameView { fn is_stalemate(&self, py: Python<'_>) -> PyResult { let borrowed = self.data.borrow(py); let arr = borrowed.is_stalemate.bind(py); - let arr: &Bound<'_, PyArray1> = arr.downcast()?; + let arr: &Bound<'_, PyArray1> = arr.cast()?; Ok(unsafe { *arr.uget([self.idx]) }) } @@ -1014,7 +1014,7 @@ impl PyGameView { fn is_insufficient(&self, py: Python<'_>) -> PyResult<(bool, bool)> { let borrowed = self.data.borrow(py); let arr = borrowed.is_insufficient.bind(py); - let arr: &Bound<'_, PyArray2> = arr.downcast()?; + let arr: &Bound<'_, PyArray2> = arr.cast()?; // SAFETY: idx is validated during construction let white = unsafe { *arr.uget([self.idx, 0]) }; let black = unsafe { *arr.uget([self.idx, 1]) }; @@ -1026,7 +1026,7 @@ impl PyGameView { fn legal_move_count(&self, py: Python<'_>) -> PyResult { let borrowed = self.data.borrow(py); let arr = borrowed.legal_move_count.bind(py); - let arr: &Bound<'_, PyArray1> = arr.downcast()?; + let arr: &Bound<'_, PyArray1> = arr.cast()?; Ok(unsafe { *arr.uget([self.idx]) }) } @@ -1035,7 +1035,7 @@ impl PyGameView { fn is_valid(&self, py: Python<'_>) -> PyResult { let borrowed = self.data.borrow(py); let arr = borrowed.valid.bind(py); - let arr: &Bound<'_, PyArray1> = arr.downcast()?; + let arr: &Bound<'_, PyArray1> = arr.cast()?; Ok(unsafe { *arr.uget([self.idx]) }) } @@ -1053,11 +1053,11 @@ impl PyGameView { let borrowed = self.data.borrow(py); let from_arr = borrowed.from_squares.bind(py); - let from_arr: &Bound<'_, PyArray1> = from_arr.downcast()?; + let from_arr: &Bound<'_, PyArray1> = from_arr.cast()?; let to_arr = borrowed.to_squares.bind(py); - let to_arr: &Bound<'_, PyArray1> = to_arr.downcast()?; + let to_arr: &Bound<'_, PyArray1> = to_arr.cast()?; let promo_arr = borrowed.promotions.bind(py); - let promo_arr: &Bound<'_, PyArray1> = promo_arr.downcast()?; + let promo_arr: &Bound<'_, PyArray1> = promo_arr.cast()?; let abs_idx = self.move_start + move_idx; // SAFETY: we validated move_idx above and abs_idx is within bounds From 2b662e72ef988836160c80d33aa949333f18a1af Mon Sep 17 00:00:00 2001 From: vladkvit Date: Wed, 4 Feb 2026 21:16:16 -0500 Subject: [PATCH 04/31] Fix typing in test.py --- src/test.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/test.py b/src/test.py index 592f1e3..32abe36 100644 --- a/src/test.py +++ b/src/test.py @@ -4,11 +4,9 @@ import numpy as np import rust_pgn_reader_python_binding -import pyarrow as pa +from rust_pgn_reader_python_binding import PyGameView -# Add python directory to path for wrapper imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "python")) -from wrapper import add_ergonomic_methods, GameView +import pyarrow as pa class TestPgnExtraction(unittest.TestCase): @@ -689,11 +687,6 @@ def test_parse_game_moves_arrow_chunked_array(self): class TestParsedGamesFlat(unittest.TestCase): - @classmethod - def setUpClass(cls): - """Patch ParsedGames with ergonomic methods once.""" - add_ergonomic_methods(rust_pgn_reader_python_binding.ParsedGames) - def test_basic_structure(self): """Test basic flat parsing returns correct structure.""" pgns = [ @@ -840,7 +833,7 @@ def test_iteration(self): games = list(result) self.assertEqual(len(games), 3) - self.assertIsInstance(games[0], GameView) + self.assertIsInstance(games[0], PyGameView) def test_slicing(self): """Test slicing returns BatchSlice.""" From e87c9040ec360f421a84d11f93ee236076928730 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Wed, 4 Feb 2026 21:24:05 -0500 Subject: [PATCH 05/31] Remove unsafe's --- src/lib.rs | 126 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 87 insertions(+), 39 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 617be60..612de6e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,16 @@ -use crate::comment_parsing::{CommentContent, ParsedTag, parse_comments}; +use crate::comment_parsing::{parse_comments, CommentContent, ParsedTag}; use arrow_array::{Array, LargeStringArray, StringArray}; use numpy::{PyArray1, PyArray2, PyArrayMethods, PyUntypedArrayMethods}; use pgn_reader::{KnownOutcome, Outcome, RawComment, RawTag, Reader, SanPlus, Skip, Visitor}; use pyo3::prelude::*; use pyo3::types::PySlice; use pyo3_arrow::PyChunkedArray; -use rayon::ThreadPoolBuilder; use rayon::prelude::*; +use rayon::ThreadPoolBuilder; +use shakmaty::fen::Fen; use shakmaty::CastlingMode; use shakmaty::Color; -use shakmaty::fen::Fen; -use shakmaty::{Chess, Position, Role, Square, uci::UciMove}; +use shakmaty::{uci::UciMove, Chess, Position, Role, Square}; use std::collections::HashMap; use std::io::Cursor; use std::ops::ControlFlow; @@ -621,13 +621,9 @@ impl ParsedGames { fn num_moves(&self, py: Python<'_>) -> PyResult { let offsets = self.move_offsets.bind(py); let offsets: &Bound<'_, PyArray1> = offsets.cast()?; - let len = offsets.len(); - if len == 0 { - return Ok(0); - } - // SAFETY: We just checked len > 0 - let last = unsafe { *offsets.uget([len - 1]) }; - Ok(last as usize) + let readonly = offsets.readonly(); + let slice = readonly.as_slice()?; + Ok(slice.last().copied().unwrap_or(0) as usize) } /// Total number of board positions recorded. @@ -635,13 +631,9 @@ impl ParsedGames { fn num_positions(&self, py: Python<'_>) -> PyResult { let offsets = self.position_offsets.bind(py); let offsets: &Bound<'_, PyArray1> = offsets.cast()?; - let len = offsets.len(); - if len == 0 { - return Ok(0); - } - // SAFETY: We just checked len > 0 - let last = unsafe { *offsets.uget([len - 1]) }; - Ok(last as usize) + let readonly = offsets.readonly(); + let slice = readonly.as_slice()?; + Ok(slice.last().copied().unwrap_or(0) as usize) } fn __len__(&self) -> usize { @@ -827,21 +819,33 @@ impl PyGameView { let pos_offsets = borrowed.position_offsets.bind(py); let pos_offsets: &Bound<'_, PyArray1> = pos_offsets.cast()?; - // SAFETY: idx is validated by caller, and idx+1 is within bounds due to offset array structure - let move_start = unsafe { *move_offsets.uget([idx]) } as usize; - let move_end = unsafe { *move_offsets.uget([idx + 1]) } as usize; - let pos_start = unsafe { *pos_offsets.uget([idx]) } as usize; - let pos_end = unsafe { *pos_offsets.uget([idx + 1]) } as usize; + let move_offsets_ro = move_offsets.readonly(); + let move_offsets_slice = move_offsets_ro.as_slice()?; + let pos_offsets_ro = pos_offsets.readonly(); + let pos_offsets_slice = pos_offsets_ro.as_slice()?; + + let move_start = move_offsets_slice + .get(idx) + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + let move_end = move_offsets_slice + .get(idx + 1) + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + let pos_start = pos_offsets_slice + .get(idx) + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + let pos_end = pos_offsets_slice + .get(idx + 1) + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; drop(borrowed); Ok(Self { data, idx, - move_start, - move_end, - pos_start, - pos_end, + move_start: *move_start as usize, + move_end: *move_end as usize, + pos_start: *pos_start as usize, + pos_end: *pos_end as usize, }) } } @@ -996,8 +1000,12 @@ impl PyGameView { let borrowed = self.data.borrow(py); let arr = borrowed.is_checkmate.bind(py); let arr: &Bound<'_, PyArray1> = arr.cast()?; - // SAFETY: idx is validated during construction - Ok(unsafe { *arr.uget([self.idx]) }) + let readonly = arr.readonly(); + let slice = readonly.as_slice()?; + slice + .get(self.idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) } /// Final position is stalemate. @@ -1006,7 +1014,12 @@ impl PyGameView { let borrowed = self.data.borrow(py); let arr = borrowed.is_stalemate.bind(py); let arr: &Bound<'_, PyArray1> = arr.cast()?; - Ok(unsafe { *arr.uget([self.idx]) }) + let readonly = arr.readonly(); + let slice = readonly.as_slice()?; + slice + .get(self.idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) } /// Insufficient material (white, black). @@ -1015,9 +1028,18 @@ impl PyGameView { let borrowed = self.data.borrow(py); let arr = borrowed.is_insufficient.bind(py); let arr: &Bound<'_, PyArray2> = arr.cast()?; - // SAFETY: idx is validated during construction - let white = unsafe { *arr.uget([self.idx, 0]) }; - let black = unsafe { *arr.uget([self.idx, 1]) }; + let readonly = arr.readonly(); + let slice = readonly.as_slice()?; + // Array is shape (n_games, 2), so index is idx * 2 for white, idx * 2 + 1 for black + let base = self.idx * 2; + let white = slice + .get(base) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + let black = slice + .get(base + 1) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; Ok((white, black)) } @@ -1027,7 +1049,12 @@ impl PyGameView { let borrowed = self.data.borrow(py); let arr = borrowed.legal_move_count.bind(py); let arr: &Bound<'_, PyArray1> = arr.cast()?; - Ok(unsafe { *arr.uget([self.idx]) }) + let readonly = arr.readonly(); + let slice = readonly.as_slice()?; + slice + .get(self.idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) } /// Whether game parsed successfully. @@ -1036,7 +1063,12 @@ impl PyGameView { let borrowed = self.data.borrow(py); let arr = borrowed.valid.bind(py); let arr: &Bound<'_, PyArray1> = arr.cast()?; - Ok(unsafe { *arr.uget([self.idx]) }) + let readonly = arr.readonly(); + let slice = readonly.as_slice()?; + slice + .get(self.idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) } // === Convenience methods === @@ -1059,11 +1091,27 @@ impl PyGameView { let promo_arr = borrowed.promotions.bind(py); let promo_arr: &Bound<'_, PyArray1> = promo_arr.cast()?; + let from_ro = from_arr.readonly(); + let to_ro = to_arr.readonly(); + let promo_ro = promo_arr.readonly(); + + let from_slice = from_ro.as_slice()?; + let to_slice = to_ro.as_slice()?; + let promo_slice = promo_ro.as_slice()?; + let abs_idx = self.move_start + move_idx; - // SAFETY: we validated move_idx above and abs_idx is within bounds - let from_sq = unsafe { *from_arr.uget([abs_idx]) }; - let to_sq = unsafe { *to_arr.uget([abs_idx]) }; - let promo = unsafe { *promo_arr.uget([abs_idx]) }; + let from_sq = from_slice + .get(abs_idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid move index"))?; + let to_sq = to_slice + .get(abs_idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid move index"))?; + let promo = promo_slice + .get(abs_idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid move index"))?; let files = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']; let ranks = ['1', '2', '3', '4', '5', '6', '7', '8']; From 077a3eddfd242242e0097e4a77ac614bbf60aa14 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Wed, 4 Feb 2026 21:39:58 -0500 Subject: [PATCH 06/31] Fix up dtypes, remove old wrapper code --- python/wrapper.py | 301 ----------------------------- rust_pgn_reader_python_binding.pyi | 11 +- src/bench_parse_games_flat.py | 2 +- src/lib.rs | 31 ++- src/test.py | 24 +++ 5 files changed, 58 insertions(+), 311 deletions(-) delete mode 100644 python/wrapper.py diff --git a/python/wrapper.py b/python/wrapper.py deleted file mode 100644 index fc1ba45..0000000 --- a/python/wrapper.py +++ /dev/null @@ -1,301 +0,0 @@ -""" -Ergonomic Python wrappers for rust_pgn_reader_python_binding.ParsedGames. - -Usage: - import rust_pgn_reader_python_binding as pgn - from python.wrapper import add_ergonomic_methods, GameView - - result = pgn.parse_games_flat(chunked_array) - add_ergonomic_methods(type(result)) - - for game in result: - print(game.headers) -""" - -from __future__ import annotations - -import numpy as np -from typing import Iterator, Dict, TYPE_CHECKING - -if TYPE_CHECKING: - from rust_pgn_reader_python_binding import ParsedGames - - -class GameView: - """ - Zero-copy view into a single game's data within a ParsedGames result. - - Board indexing note: Boards use square indexing (a1=0, h8=63). - To convert to rank/file array indexing used by some Python code: - rank = square // 8 - file = square % 8 - # For [7-rank, file] layout: board_2d[7 - rank, file] - """ - - __slots__ = ( - "_data", - "_idx", - "_move_start", - "_move_end", - "_pos_start", - "_pos_end", - ) - - def __init__(self, data: "ParsedGames", idx: int): - self._data = data - self._idx = idx - self._move_start = int(data.move_offsets[idx]) - self._move_end = int(data.move_offsets[idx + 1]) - self._pos_start = int(data.position_offsets[idx]) - self._pos_end = int(data.position_offsets[idx + 1]) - - def __len__(self) -> int: - """Number of moves in this game.""" - return self._move_end - self._move_start - - @property - def num_positions(self) -> int: - """Number of positions recorded for this game.""" - return self._pos_end - self._pos_start - - # === Board state views === - - @property - def boards(self) -> np.ndarray: - """Board positions, shape (num_positions, 8, 8).""" - return self._data.boards[self._pos_start : self._pos_end] - - @property - def initial_board(self) -> np.ndarray: - """Initial board, shape (8, 8).""" - return self._data.boards[self._pos_start] - - @property - def final_board(self) -> np.ndarray: - """Final board, shape (8, 8).""" - return self._data.boards[self._pos_end - 1] - - @property - def castling(self) -> np.ndarray: - """Castling rights [K,Q,k,q], shape (num_positions, 4).""" - return self._data.castling[self._pos_start : self._pos_end] - - @property - def en_passant(self) -> np.ndarray: - """En passant file (-1 if none), shape (num_positions,).""" - return self._data.en_passant[self._pos_start : self._pos_end] - - @property - def halfmove_clock(self) -> np.ndarray: - """Halfmove clock, shape (num_positions,).""" - return self._data.halfmove_clock[self._pos_start : self._pos_end] - - @property - def turn(self) -> np.ndarray: - """Side to move (True=white), shape (num_positions,).""" - return self._data.turn[self._pos_start : self._pos_end] - - # === Move views === - - @property - def from_squares(self) -> np.ndarray: - """From squares, shape (num_moves,).""" - return self._data.from_squares[self._move_start : self._move_end] - - @property - def to_squares(self) -> np.ndarray: - """To squares, shape (num_moves,).""" - return self._data.to_squares[self._move_start : self._move_end] - - @property - def promotions(self) -> np.ndarray: - """Promotions (-1=none), shape (num_moves,).""" - return self._data.promotions[self._move_start : self._move_end] - - @property - def clocks(self) -> np.ndarray: - """Clock times in seconds (NaN if missing), shape (num_moves,).""" - return self._data.clocks[self._move_start : self._move_end] - - @property - def evals(self) -> np.ndarray: - """Engine evals (NaN if missing), shape (num_moves,).""" - return self._data.evals[self._move_start : self._move_end] - - # === Per-game metadata === - - @property - def headers(self) -> Dict[str, str]: - """Raw PGN headers as dict.""" - return self._data.headers[self._idx] - - @property - def is_checkmate(self) -> bool: - """Final position is checkmate.""" - return bool(self._data.is_checkmate[self._idx]) - - @property - def is_stalemate(self) -> bool: - """Final position is stalemate.""" - return bool(self._data.is_stalemate[self._idx]) - - @property - def is_insufficient(self) -> tuple: - """Insufficient material (white, black).""" - return ( - bool(self._data.is_insufficient[self._idx, 0]), - bool(self._data.is_insufficient[self._idx, 1]), - ) - - @property - def legal_move_count(self) -> int: - """Legal moves in final position.""" - return int(self._data.legal_move_count[self._idx]) - - @property - def is_valid(self) -> bool: - """Whether game parsed successfully.""" - return bool(self._data.valid[self._idx]) - - # === Convenience methods === - - def move_uci(self, move_idx: int) -> str: - """Get UCI string for move at index.""" - files = "abcdefgh" - ranks = "12345678" - from_sq = int(self.from_squares[move_idx]) - to_sq = int(self.to_squares[move_idx]) - promo = int(self.promotions[move_idx]) - - uci = f"{files[from_sq % 8]}{ranks[from_sq // 8]}{files[to_sq % 8]}{ranks[to_sq // 8]}" - if promo >= 0: - promo_chars = {2: "n", 3: "b", 4: "r", 5: "q"} - uci += promo_chars.get(promo, "") - return uci - - def moves_uci(self) -> list: - """Get all moves as UCI strings.""" - return [self.move_uci(i) for i in range(len(self))] - - def __repr__(self) -> str: - white = self.headers.get("White", "?") - black = self.headers.get("Black", "?") - return ( - f"" - ) - - -class BatchSlice: - """Lazy iterator over a slice of games.""" - - __slots__ = ("_data", "_indices") - - def __init__(self, data: "ParsedGames", indices: range): - self._data = data - self._indices = indices - - def __iter__(self) -> Iterator[GameView]: - for i in self._indices: - yield GameView(self._data, i) - - def __len__(self) -> int: - return len(self._indices) - - def __repr__(self) -> str: - return f"" - - -# === Functions to add ergonomic methods to ParsedGames === - - -def _parsed_games_len(self) -> int: - """Number of games in result.""" - return len(self.move_offsets) - 1 - - -def _parsed_games_getitem(self, idx): - """Access game(s) by index or slice.""" - n_games = len(self.move_offsets) - 1 - if isinstance(idx, int): - if idx < 0: - idx += n_games - if not 0 <= idx < n_games: - raise IndexError(f"Game index {idx} out of range [0, {n_games})") - return GameView(self, idx) - elif isinstance(idx, slice): - start, stop, step = idx.indices(n_games) - return BatchSlice(self, range(start, stop, step)) - raise TypeError(f"Invalid index type: {type(idx)}") - - -def _parsed_games_iter(self) -> Iterator[GameView]: - """Iterate over all games.""" - for i in range(len(self.move_offsets) - 1): - yield GameView(self, i) - - -def _position_to_game(self, position_indices: np.ndarray) -> np.ndarray: - """ - Map position indices to game indices. - - Useful after shuffling/sampling positions to look up game metadata. - - Args: - position_indices: Array of indices into boards array - - Returns: - Array of game indices (same shape as input) - """ - return ( - np.searchsorted(self.position_offsets[:-1], position_indices, side="right") - 1 - ) - - -def _move_to_game(self, move_indices: np.ndarray) -> np.ndarray: - """ - Map move indices to game indices. - - Args: - move_indices: Array of indices into from_squares, to_squares, etc. - - Returns: - Array of game indices (same shape as input) - """ - return np.searchsorted(self.move_offsets[:-1], move_indices, side="right") - 1 - - -@property -def _num_games(self) -> int: - """Number of games.""" - return len(self.move_offsets) - 1 - - -@property -def _num_moves(self) -> int: - """Total moves across all games.""" - return int(self.move_offsets[-1]) - - -@property -def _num_positions(self) -> int: - """Total positions recorded.""" - return int(self.position_offsets[-1]) - - -def add_ergonomic_methods(parsed_games_class): - """ - Add ergonomic methods to the ParsedGames class. - - Call once after importing the module: - import rust_pgn_reader_python_binding as pgn - from python.wrapper import add_ergonomic_methods - add_ergonomic_methods(pgn.ParsedGames) - """ - parsed_games_class.__len__ = _parsed_games_len - parsed_games_class.__getitem__ = _parsed_games_getitem - parsed_games_class.__iter__ = _parsed_games_iter - parsed_games_class.position_to_game = _position_to_game - parsed_games_class.move_to_game = _move_to_game - parsed_games_class.num_games = _num_games - parsed_games_class.num_moves = _num_moves - parsed_games_class.num_positions = _num_positions diff --git a/rust_pgn_reader_python_binding.pyi b/rust_pgn_reader_python_binding.pyi index 86c4166..d75d603 100644 --- a/rust_pgn_reader_python_binding.pyi +++ b/rust_pgn_reader_python_binding.pyi @@ -1,6 +1,7 @@ from typing import List, Optional, Tuple, Dict, Iterator, Union, overload import pyarrow import numpy as np +import numpy.typing as npt from numpy.typing import NDArray class PyUciMove: @@ -337,26 +338,26 @@ class ParsedGames: # === Mapping utilities === - def position_to_game( - self, position_indices: NDArray[np.int64] - ) -> NDArray[np.int64]: + def position_to_game(self, position_indices: npt.ArrayLike) -> NDArray[np.int64]: """Map position indices to game indices. Useful after shuffling/sampling positions to look up game metadata. Args: - position_indices: Array of indices into boards array + position_indices: Array of indices into boards array. + Accepts any integer dtype; int64 is optimal (avoids conversion). Returns: Array of game indices (same shape as input) """ ... - def move_to_game(self, move_indices: NDArray[np.int64]) -> NDArray[np.int64]: + def move_to_game(self, move_indices: npt.ArrayLike) -> NDArray[np.int64]: """Map move indices to game indices. Args: move_indices: Array of indices into from_squares, to_squares, etc. + Accepts any integer dtype; int64 is optimal (avoids conversion). Returns: Array of game indices (same shape as input) diff --git a/src/bench_parse_games_flat.py b/src/bench_parse_games_flat.py index b5b848a..e92c739 100644 --- a/src/bench_parse_games_flat.py +++ b/src/bench_parse_games_flat.py @@ -167,7 +167,7 @@ def benchmark_data_access_flat(result) -> dict: _ = result.to_squares.sum() # Random position access - indices = np.random.randint(0, result.num_positions, size=1000) + indices = np.random.randint(0, result.num_positions, size=1000, dtype=np.int64) _ = result.boards[indices] # Position-to-game mapping diff --git a/src/lib.rs b/src/lib.rs index 612de6e..b2a57f1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,7 @@ use arrow_array::{Array, LargeStringArray, StringArray}; use numpy::{PyArray1, PyArray2, PyArrayMethods, PyUntypedArrayMethods}; use pgn_reader::{KnownOutcome, Outcome, RawComment, RawTag, Reader, SanPlus, Skip, Visitor}; use pyo3::prelude::*; -use pyo3::types::PySlice; +use pyo3::types::{IntoPyDict, PySlice}; use pyo3_arrow::PyChunkedArray; use rayon::prelude::*; use rayon::ThreadPoolBuilder; @@ -697,14 +697,15 @@ impl ParsedGames { /// Useful after shuffling/sampling positions to look up game metadata. /// /// Args: - /// position_indices: Array of indices into boards array + /// position_indices: Array of indices into boards array. + /// Accepts any integer dtype; int64 is optimal (avoids conversion). /// /// Returns: /// Array of game indices (same shape as input) fn position_to_game<'py>( &self, py: Python<'py>, - position_indices: &Bound<'py, PyArray1>, + position_indices: &Bound<'py, PyAny>, ) -> PyResult>> { let offsets = self.position_offsets.bind(py); let offsets: &Bound<'_, PyArray1> = offsets.cast()?; @@ -712,6 +713,16 @@ impl ParsedGames { // Get numpy module for searchsorted let numpy = py.import("numpy")?; + // Convert input to int64 array (no-op if already int64) + let int64_dtype = numpy.getattr("int64")?; + let position_indices = numpy + .call_method1("asarray", (position_indices,))? + .call_method( + "astype", + (int64_dtype,), + Some(&[("copy", false)].into_py_dict(py)?), + )?; + // offsets[:-1] - all but last element let len = offsets.len(); let slice_obj = PySlice::new(py, 0, (len - 1) as isize, 1); @@ -738,18 +749,30 @@ impl ParsedGames { /// /// Args: /// move_indices: Array of indices into from_squares, to_squares, etc. + /// Accepts any integer dtype; int64 is optimal (avoids conversion). /// /// Returns: /// Array of game indices (same shape as input) fn move_to_game<'py>( &self, py: Python<'py>, - move_indices: &Bound<'py, PyArray1>, + move_indices: &Bound<'py, PyAny>, ) -> PyResult>> { let offsets = self.move_offsets.bind(py); let offsets: &Bound<'_, PyArray1> = offsets.cast()?; let numpy = py.import("numpy")?; + + // Convert input to int64 array (no-op if already int64) + let int64_dtype = numpy.getattr("int64")?; + let move_indices = numpy + .call_method1("asarray", (move_indices,))? + .call_method( + "astype", + (int64_dtype,), + Some(&[("copy", false)].into_py_dict(py)?), + )?; + let len = offsets.len(); let slice_obj = PySlice::new(py, 0, (len - 1) as isize, 1); let offsets_slice = offsets.call_method1("__getitem__", (slice_obj,))?; diff --git a/src/test.py b/src/test.py index 32abe36..a189e55 100644 --- a/src/test.py +++ b/src/test.py @@ -869,6 +869,30 @@ def test_move_to_game_mapping(self): np.testing.assert_array_equal(game_indices, [0, 0, 1]) + def test_position_to_game_accepts_various_dtypes(self): + """Test position_to_game accepts various integer dtypes.""" + pgns = ["1. e4 e5 1-0", "1. d4 0-1"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + # Test various integer dtypes (int64 is optimal, others are converted) + for dtype in [np.int32, np.int64, np.uint32, np.uint64]: + pos_indices = np.array([0, 1, 2, 3, 4], dtype=dtype) + game_indices = result.position_to_game(pos_indices) + np.testing.assert_array_equal(game_indices, [0, 0, 0, 1, 1]) + + def test_move_to_game_accepts_various_dtypes(self): + """Test move_to_game accepts various integer dtypes.""" + pgns = ["1. e4 e5 1-0", "1. d4 0-1"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + + # Test various integer dtypes (int64 is optimal, others are converted) + for dtype in [np.int32, np.int64, np.uint32, np.uint64]: + move_indices = np.array([0, 1, 2], dtype=dtype) + game_indices = result.move_to_game(move_indices) + np.testing.assert_array_equal(game_indices, [0, 0, 1]) + def test_clocks_and_evals(self): """Test clock and eval parsing.""" pgn = """1. e4 { [%eval 0.17] [%clk 0:00:30] } 1... e5 { [%eval 0.19] [%clk 0:00:29] } 1-0""" From c6ed5a0fb0eac80f4dee07f587dbfa22fec4bc0b Mon Sep 17 00:00:00 2001 From: vladkvit Date: Wed, 4 Feb 2026 21:45:01 -0500 Subject: [PATCH 07/31] Delete unnecessary file --- python/__init__.py | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 python/__init__.py diff --git a/python/__init__.py b/python/__init__.py deleted file mode 100644 index 74ad9ae..0000000 --- a/python/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Python wrapper utilities for rust_pgn_reader_python_binding.""" - -from .wrapper import add_ergonomic_methods, GameView, BatchSlice - -__all__ = ["add_ergonomic_methods", "GameView", "BatchSlice"] From 476732b9e4696f9759d56a18796cadc8301aa57a Mon Sep 17 00:00:00 2001 From: vladkvit Date: Wed, 4 Feb 2026 21:51:09 -0500 Subject: [PATCH 08/31] Tweak imports --- src/lib.rs | 9 +++------ src/test.py | 4 +--- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b2a57f1..55271ef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,13 @@ -use crate::comment_parsing::{parse_comments, CommentContent, ParsedTag}; +use crate::comment_parsing::{CommentContent, ParsedTag, parse_comments}; use arrow_array::{Array, LargeStringArray, StringArray}; use numpy::{PyArray1, PyArray2, PyArrayMethods, PyUntypedArrayMethods}; use pgn_reader::{KnownOutcome, Outcome, RawComment, RawTag, Reader, SanPlus, Skip, Visitor}; use pyo3::prelude::*; use pyo3::types::{IntoPyDict, PySlice}; use pyo3_arrow::PyChunkedArray; -use rayon::prelude::*; use rayon::ThreadPoolBuilder; -use shakmaty::fen::Fen; -use shakmaty::CastlingMode; -use shakmaty::Color; -use shakmaty::{uci::UciMove, Chess, Position, Role, Square}; +use rayon::prelude::*; +use shakmaty::{CastlingMode, Chess, Color, Position, Role, Square, fen::Fen, uci::UciMove}; use std::collections::HashMap; use std::io::Cursor; use std::ops::ControlFlow; diff --git a/src/test.py b/src/test.py index a189e55..9aa41ad 100644 --- a/src/test.py +++ b/src/test.py @@ -1,10 +1,8 @@ import unittest -import sys -import os import numpy as np import rust_pgn_reader_python_binding -from rust_pgn_reader_python_binding import PyGameView +from rust_pgn_reader_python_binding import PyGameView # for a typing check import pyarrow as pa From cfb1e227a873bad82ed3ce87ff745a78fa751aa4 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Wed, 4 Feb 2026 22:16:42 -0500 Subject: [PATCH 09/31] Split up lib.rs into smaller files --- src/lib.rs | 1241 +--------------------------------------- src/python_bindings.rs | 825 ++++++++++++++++++++++++++ src/visitor.rs | 409 +++++++++++++ 3 files changed, 1241 insertions(+), 1234 deletions(-) create mode 100644 src/python_bindings.rs create mode 100644 src/visitor.rs diff --git a/src/lib.rs b/src/lib.rs index 55271ef..94e0b46 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,1181 +1,20 @@ -use crate::comment_parsing::{CommentContent, ParsedTag, parse_comments}; use arrow_array::{Array, LargeStringArray, StringArray}; -use numpy::{PyArray1, PyArray2, PyArrayMethods, PyUntypedArrayMethods}; -use pgn_reader::{KnownOutcome, Outcome, RawComment, RawTag, Reader, SanPlus, Skip, Visitor}; +use numpy::{PyArray1, PyArrayMethods}; +use pgn_reader::Reader; use pyo3::prelude::*; -use pyo3::types::{IntoPyDict, PySlice}; use pyo3_arrow::PyChunkedArray; -use rayon::ThreadPoolBuilder; use rayon::prelude::*; -use shakmaty::{CastlingMode, Chess, Color, Position, Role, Square, fen::Fen, uci::UciMove}; +use rayon::ThreadPoolBuilder; use std::collections::HashMap; use std::io::Cursor; -use std::ops::ControlFlow; mod board_serialization; mod comment_parsing; +mod python_bindings; +mod visitor; -use board_serialization::{ - get_castling_rights, get_en_passant_file, get_halfmove_clock, get_turn, serialize_board, -}; - -// Definition of PyUciMove -#[pyclass(get_all, set_all, module = "rust_pgn_reader_python_binding")] -#[derive(Clone, Debug)] -pub struct PyUciMove { - pub from_square: u8, - pub to_square: u8, - pub promotion: Option, -} - -#[pymethods] -impl PyUciMove { - #[new] - fn new(from_square: u8, to_square: u8, promotion: Option) -> Self { - PyUciMove { - from_square, - to_square, - promotion, - } - } - - #[getter] - fn get_from_square_name(&self) -> String { - Square::new(self.from_square as u32).to_string() - } - - #[getter] - fn get_to_square_name(&self) -> String { - Square::new(self.to_square as u32).to_string() - } - - #[getter] - fn get_promotion_name(&self) -> Option { - self.promotion.and_then(|p_u8| { - Role::try_from(p_u8) - .map(|role| format!("{:?}", role)) // Get the debug representation (e.g., "Queen") - .ok() - }) - } - - // __str__ method for Python representation - fn __str__(&self) -> String { - let promo_str = self.promotion.map_or("".to_string(), |p_u8| { - Role::try_from(p_u8) - .map(|role| role.char().to_string()) - .unwrap_or_else(|_| "".to_string()) // Handle potential error if u8 is not a valid Role - }); - format!( - "{}{}{}", - Square::new(self.from_square as u32), - Square::new(self.to_square as u32), - promo_str - ) - } - - // __repr__ for a more developer-friendly representation - fn __repr__(&self) -> String { - let promo_repr = self.promotion.map_or("None".to_string(), |p_u8| { - Role::try_from(p_u8) - .map(|role| format!("Some('{}')", role.char())) - .unwrap_or_else(|_| format!("Some(InvalidRole({}))", p_u8)) - }); - format!( - "PyUciMove(from_square={}, to_square={}, promotion={})", - Square::new(self.from_square as u32), - Square::new(self.to_square as u32), - promo_repr - ) - } -} - -#[pyclass] -/// Holds the status of a chess position. -#[derive(Clone)] -pub struct PositionStatus { - #[pyo3(get)] - is_checkmate: bool, - - #[pyo3(get)] - is_stalemate: bool, - - #[pyo3(get)] - legal_move_count: usize, - - #[pyo3(get)] - is_game_over: bool, - - #[pyo3(get)] - insufficient_material: (bool, bool), - - #[pyo3(get)] - turn: bool, -} - -#[pyclass] -/// A Visitor to extract SAN moves and comments from PGN movetext -pub struct MoveExtractor { - #[pyo3(get)] - moves: Vec, - - store_legal_moves: bool, - flat_legal_moves: Vec, - legal_moves_offsets: Vec, - - #[pyo3(get)] - valid_moves: bool, - - #[pyo3(get)] - comments: Vec>, - - #[pyo3(get)] - evals: Vec>, - - #[pyo3(get)] - clock_times: Vec>, - - #[pyo3(get)] - outcome: Option, - - #[pyo3(get)] - headers: Vec<(String, String)>, - - #[pyo3(get)] - castling_rights: Vec>, - - #[pyo3(get)] - position_status: Option, - - pos: Chess, - - // Board state tracking for flat output (not directly exposed to Python) - board_states: Vec, // Flattened: 64 bytes per position - en_passant_states: Vec, // Per position: -1 or file 0-7 - halfmove_clocks: Vec, // Per position - turn_states: Vec, // Per position: true=white - castling_states: Vec, // Flattened: 4 bools per position [K,Q,k,q] -} - -#[pymethods] -impl MoveExtractor { - #[new] - #[pyo3(signature = (store_legal_moves = false))] - fn new(store_legal_moves: bool) -> MoveExtractor { - MoveExtractor { - moves: Vec::with_capacity(100), - store_legal_moves, - flat_legal_moves: Vec::with_capacity(if store_legal_moves { 100 * 30 } else { 0 }), // Pre-allocate for moves - legal_moves_offsets: Vec::with_capacity(if store_legal_moves { 100 } else { 0 }), // Pre-allocate for offsets - pos: Chess::default(), - valid_moves: true, - comments: Vec::with_capacity(100), - evals: Vec::with_capacity(100), - clock_times: Vec::with_capacity(100), - outcome: None, - headers: Vec::with_capacity(10), - castling_rights: Vec::with_capacity(100), - position_status: None, - board_states: Vec::with_capacity(100 * 64), - en_passant_states: Vec::with_capacity(100), - halfmove_clocks: Vec::with_capacity(100), - turn_states: Vec::with_capacity(100), - castling_states: Vec::with_capacity(100 * 4), - } - } - - fn turn(&self) -> bool { - match self.pos.turn() { - Color::White => true, - Color::Black => false, - } - } - - fn push_castling_bitboards(&mut self) { - let castling_bitboard = self.pos.castles().castling_rights(); - let castling_rights = ( - castling_bitboard.contains(shakmaty::Square::A1), - castling_bitboard.contains(shakmaty::Square::H1), - castling_bitboard.contains(shakmaty::Square::A8), - castling_bitboard.contains(shakmaty::Square::H8), - ); - - self.castling_rights.push(Some(castling_rights)); - } - - fn push_legal_moves(&mut self) { - // Record the starting offset for the current position's legal moves. - self.legal_moves_offsets.push(self.flat_legal_moves.len()); - - let legal_moves_for_pos = self.pos.legal_moves(); - self.flat_legal_moves.reserve(legal_moves_for_pos.len()); - - for m in legal_moves_for_pos { - let uci_move_obj = UciMove::from_standard(m); - if let UciMove::Normal { - from, - to, - promotion: promo_opt, - } = uci_move_obj - { - self.flat_legal_moves.push(PyUciMove { - from_square: from as u8, - to_square: to as u8, - promotion: promo_opt.map(|p_role| p_role as u8), - }); - } - } - } - - /// Record current board state to flat arrays for ParsedGames output. - fn push_board_state(&mut self) { - self.board_states - .extend_from_slice(&serialize_board(&self.pos)); - self.en_passant_states.push(get_en_passant_file(&self.pos)); - self.halfmove_clocks.push(get_halfmove_clock(&self.pos)); - self.turn_states.push(get_turn(&self.pos)); - let castling = get_castling_rights(&self.pos); - self.castling_states.extend_from_slice(&castling); - } - - fn update_position_status(&mut self) { - // TODO this checks legal_moves() a bunch of times - self.position_status = Some(PositionStatus { - is_checkmate: self.pos.is_checkmate(), - is_stalemate: self.pos.is_stalemate(), - legal_move_count: self.pos.legal_moves().len(), - is_game_over: self.pos.is_game_over(), - insufficient_material: ( - self.pos.has_insufficient_material(Color::White), - self.pos.has_insufficient_material(Color::Black), - ), - turn: match self.pos.turn() { - Color::White => true, - Color::Black => false, - }, - }); - } - - #[getter] - fn legal_moves(&self) -> Vec> { - let mut result = Vec::with_capacity(self.legal_moves_offsets.len()); - if self.legal_moves_offsets.is_empty() { - return result; - } - - for i in 0..self.legal_moves_offsets.len() - 1 { - let start = self.legal_moves_offsets[i]; - let end = self.legal_moves_offsets[i + 1]; - result.push(self.flat_legal_moves[start..end].to_vec()); - } - - // Handle the last chunk - if let Some(&start) = self.legal_moves_offsets.last() { - result.push(self.flat_legal_moves[start..].to_vec()); - } - - result - } -} - -impl Visitor for MoveExtractor { - type Tags = Vec<(String, String)>; - type Movetext = (); - type Output = bool; - - fn begin_tags(&mut self) -> ControlFlow { - self.headers.clear(); - ControlFlow::Continue(Vec::with_capacity(10)) - } - - fn tag( - &mut self, - tags: &mut Self::Tags, - key: &[u8], - value: RawTag<'_>, - ) -> ControlFlow { - let key_str = String::from_utf8_lossy(key).into_owned(); - let value_str = String::from_utf8_lossy(value.as_bytes()).into_owned(); - tags.push((key_str, value_str)); - ControlFlow::Continue(()) - } - - fn begin_movetext(&mut self, tags: Self::Tags) -> ControlFlow { - self.headers = tags; - self.moves.clear(); - self.flat_legal_moves.clear(); - self.legal_moves_offsets.clear(); - self.valid_moves = true; - self.comments.clear(); - self.evals.clear(); - self.clock_times.clear(); - self.castling_rights.clear(); - self.board_states.clear(); - self.en_passant_states.clear(); - self.halfmove_clocks.clear(); - self.turn_states.clear(); - self.castling_states.clear(); - - // Determine castling mode from Variant header (case-insensitive) - let castling_mode = self - .headers - .iter() - .find(|(k, _)| k.eq_ignore_ascii_case("Variant")) - .and_then(|(_, v)| { - let v_lower = v.to_lowercase(); - if v_lower == "chess960" { - Some(CastlingMode::Chess960) - } else { - None - } - }) - .unwrap_or(CastlingMode::Standard); - - // Try to parse FEN from headers, fall back to default position - let fen_header = self - .headers - .iter() - .find(|(k, _)| k.eq_ignore_ascii_case("FEN")) - .map(|(_, v)| v.as_str()); - - if let Some(fen_str) = fen_header { - match fen_str.parse::() { - Ok(fen) => match fen.into_position(castling_mode) { - Ok(pos) => self.pos = pos, - Err(e) => { - eprintln!("invalid FEN position: {}", e); - self.pos = Chess::default(); - self.valid_moves = false; - } - }, - Err(e) => { - eprintln!("failed to parse FEN: {}", e); - self.pos = Chess::default(); - self.valid_moves = false; - } - } - } else { - self.pos = Chess::default(); - } - - self.push_castling_bitboards(); - if self.store_legal_moves { - self.push_legal_moves(); - } - // Record initial board state for flat output - self.push_board_state(); - ControlFlow::Continue(()) - } - - // Roughly half the time during parsing is spent here in san() - fn san( - &mut self, - _movetext: &mut Self::Movetext, - san_plus: SanPlus, - ) -> ControlFlow { - if self.valid_moves { - // Most of the function time is spent calculating to_move() - match san_plus.san.to_move(&self.pos) { - Ok(m) => { - self.pos.play_unchecked(m); - if self.store_legal_moves { - self.push_legal_moves(); - } - // Record board state after move for flat output - self.push_board_state(); - let uci_move_obj = UciMove::from_standard(m); - - match uci_move_obj { - UciMove::Normal { - from, - to, - promotion: promo_opt, - } => { - let py_uci_move = PyUciMove { - from_square: from as u8, - to_square: to as u8, - promotion: promo_opt.map(|p_role| p_role as u8), - }; - self.moves.push(py_uci_move); - self.push_castling_bitboards(); - - // Push placeholders to keep vectors in sync - self.comments.push(None); - self.evals.push(None); - self.clock_times.push(None); - } - _ => { - // This case handles UciMove::Put and UciMove::Null, - // which are not expected from standard PGN moves - // that PyUciMove is designed to represent. - eprintln!( - "Unexpected UCI move type from standard PGN move: {:?}. Game moves might be invalid.", - uci_move_obj - ); - self.valid_moves = false; - } - } - } - Err(err) => { - eprintln!("error in game: {} {}", err, san_plus); - self.valid_moves = false; - } - } - } - ControlFlow::Continue(()) - } - - fn comment( - &mut self, - _movetext: &mut Self::Movetext, - _comment: RawComment<'_>, - ) -> ControlFlow { - match parse_comments(_comment.as_bytes()) { - Ok((remaining_input, parsed_comments)) => { - if !remaining_input.is_empty() { - eprintln!("Unparsed remaining input: {:?}", remaining_input); - return ControlFlow::Continue(()); - } - - let mut move_comments = String::new(); - - for content in parsed_comments { - match content { - CommentContent::Text(text) => { - if !text.trim().is_empty() { - if !move_comments.is_empty() { - move_comments.push(' '); - } - move_comments.push_str(&text); - } - } - CommentContent::Tag(tag_content) => match tag_content { - ParsedTag::Eval(eval_value) => { - if let Some(last_eval) = self.evals.last_mut() { - *last_eval = Some(eval_value); - } - } - ParsedTag::Mate(mate_value) => { - if !move_comments.is_empty() && !move_comments.ends_with(' ') { - move_comments.push(' '); - } - move_comments.push_str(&format!("[Mate {}]", mate_value)); - } - ParsedTag::ClkTime { - hours, - minutes, - seconds, - } => { - if let Some(last_clk) = self.clock_times.last_mut() { - *last_clk = Some((hours, minutes, seconds)); - } - } - }, - } - } - - if let Some(last_comment) = self.comments.last_mut() { - *last_comment = Some(move_comments); - } - } - Err(e) => { - eprintln!("Error parsing comment: {:?}", e); - } - } - ControlFlow::Continue(()) - } - - fn begin_variation( - &mut self, - _movetext: &mut Self::Movetext, - ) -> ControlFlow { - ControlFlow::Continue(Skip(true)) // stay in the mainline - } - - fn outcome( - &mut self, - _movetext: &mut Self::Movetext, - _outcome: Outcome, - ) -> ControlFlow { - self.outcome = Some(match _outcome { - Outcome::Known(known) => match known { - KnownOutcome::Decisive { winner } => format!("{:?}", winner), - KnownOutcome::Draw => "Draw".to_string(), - }, - Outcome::Unknown => "Unknown".to_string(), - }); - self.update_position_status(); - ControlFlow::Continue(()) - } - - fn end_game(&mut self, _movetext: Self::Movetext) -> Self::Output { - self.valid_moves - } -} - -/// Flat array container for parsed chess games, optimized for ML training. -/// -/// # Indexing -/// - `N_games`: Number of games -/// - `N_moves`: Total moves across all games -/// - `N_positions`: Total board positions recorded (varies per game due to initial position + moves) -/// -/// # Board layout -/// Boards use square indexing: a1=0, b1=1, ..., h8=63 -/// Piece encoding: 0=empty, 1-6=white PNBRQK, 7-12=black pnbrqk -#[pyclass] -pub struct ParsedGames { - // === Board state arrays (N_positions) === - /// Board positions, shape (N_positions, 8, 8), dtype uint8 - #[pyo3(get)] - boards: Py, - - /// Castling rights [K,Q,k,q], shape (N_positions, 4), dtype bool - #[pyo3(get)] - castling: Py, - - /// En passant file (-1 if none), shape (N_positions,), dtype int8 - #[pyo3(get)] - en_passant: Py, - - /// Halfmove clock, shape (N_positions,), dtype uint8 - #[pyo3(get)] - halfmove_clock: Py, - - /// Side to move (true=white), shape (N_positions,), dtype bool - #[pyo3(get)] - turn: Py, - - // === Move arrays (N_moves) === - /// From squares, shape (N_moves,), dtype uint8 - #[pyo3(get)] - from_squares: Py, - - /// To squares, shape (N_moves,), dtype uint8 - #[pyo3(get)] - to_squares: Py, - - /// Promotions (-1=none, 2=N, 3=B, 4=R, 5=Q), shape (N_moves,), dtype int8 - #[pyo3(get)] - promotions: Py, - - /// Clock times in seconds (NaN if missing), shape (N_moves,), dtype float32 - #[pyo3(get)] - clocks: Py, - - /// Engine evals (NaN if missing), shape (N_moves,), dtype float32 - #[pyo3(get)] - evals: Py, - - // === Offsets === - /// Move offsets for CSR-style indexing, shape (N_games + 1,), dtype uint32 - /// Game i's moves: move_offsets[i]..move_offsets[i+1] - #[pyo3(get)] - move_offsets: Py, - - /// Position offsets for CSR-style indexing, shape (N_games + 1,), dtype uint32 - /// Game i's positions: position_offsets[i]..position_offsets[i+1] - #[pyo3(get)] - position_offsets: Py, - - // === Final position status (N_games) === - /// Final position is checkmate, shape (N_games,), dtype bool - #[pyo3(get)] - is_checkmate: Py, - - /// Final position is stalemate, shape (N_games,), dtype bool - #[pyo3(get)] - is_stalemate: Py, - - /// Insufficient material (white, black), shape (N_games, 2), dtype bool - #[pyo3(get)] - is_insufficient: Py, - - /// Legal move count in final position, shape (N_games,), dtype uint16 - #[pyo3(get)] - legal_move_count: Py, - - // === Parse status (N_games) === - /// Whether game parsed successfully, shape (N_games,), dtype bool - #[pyo3(get)] - valid: Py, - - // === Raw headers (N_games) === - /// Raw PGN headers as list of dicts - #[pyo3(get)] - headers: Vec>, -} - -#[pymethods] -impl ParsedGames { - /// Number of games in the result. - #[getter] - fn num_games(&self) -> usize { - self.headers.len() - } - - /// Total number of moves across all games. - #[getter] - fn num_moves(&self, py: Python<'_>) -> PyResult { - let offsets = self.move_offsets.bind(py); - let offsets: &Bound<'_, PyArray1> = offsets.cast()?; - let readonly = offsets.readonly(); - let slice = readonly.as_slice()?; - Ok(slice.last().copied().unwrap_or(0) as usize) - } - - /// Total number of board positions recorded. - #[getter] - fn num_positions(&self, py: Python<'_>) -> PyResult { - let offsets = self.position_offsets.bind(py); - let offsets: &Bound<'_, PyArray1> = offsets.cast()?; - let readonly = offsets.readonly(); - let slice = readonly.as_slice()?; - Ok(slice.last().copied().unwrap_or(0) as usize) - } - - fn __len__(&self) -> usize { - self.headers.len() - } - - fn __getitem__(slf: Py, py: Python<'_>, idx: &Bound<'_, PyAny>) -> PyResult> { - let n_games = slf.borrow(py).headers.len(); - - // Handle integer index - if let Ok(mut i) = idx.extract::() { - // Handle negative indexing - if i < 0 { - i += n_games as isize; - } - if i < 0 || i >= n_games as isize { - return Err(pyo3::exceptions::PyIndexError::new_err(format!( - "Game index {} out of range [0, {})", - i, n_games - ))); - } - let game_view = PyGameView::new(py, slf.clone_ref(py), i as usize)?; - return Ok(Py::new(py, game_view)?.into_any()); - } - - // Handle slice - if let Ok(slice) = idx.cast::() { - let indices = slice.indices(n_games as isize)?; - let start = indices.start as usize; - let stop = indices.stop as usize; - let step = indices.step as usize; - - // For simplicity, we return a list of PyGameView objects - let mut views: Vec> = Vec::new(); - let mut i = start; - while i < stop { - let game_view = PyGameView::new(py, slf.clone_ref(py), i)?; - views.push(Py::new(py, game_view)?); - i += step; - } - return Ok(pyo3::types::PyList::new(py, views)?.into_any().unbind()); - } - - Err(pyo3::exceptions::PyTypeError::new_err(format!( - "Invalid index type: expected int or slice, got {}", - idx.get_type().name()? - ))) - } - - fn __iter__(slf: Py, py: Python<'_>) -> PyResult { - let n_games = slf.borrow(py).headers.len(); - Ok(ParsedGamesIter { - data: slf, - index: 0, - length: n_games, - }) - } - - /// Map position indices to game indices. - /// - /// Useful after shuffling/sampling positions to look up game metadata. - /// - /// Args: - /// position_indices: Array of indices into boards array. - /// Accepts any integer dtype; int64 is optimal (avoids conversion). - /// - /// Returns: - /// Array of game indices (same shape as input) - fn position_to_game<'py>( - &self, - py: Python<'py>, - position_indices: &Bound<'py, PyAny>, - ) -> PyResult>> { - let offsets = self.position_offsets.bind(py); - let offsets: &Bound<'_, PyArray1> = offsets.cast()?; - - // Get numpy module for searchsorted - let numpy = py.import("numpy")?; - - // Convert input to int64 array (no-op if already int64) - let int64_dtype = numpy.getattr("int64")?; - let position_indices = numpy - .call_method1("asarray", (position_indices,))? - .call_method( - "astype", - (int64_dtype,), - Some(&[("copy", false)].into_py_dict(py)?), - )?; - - // offsets[:-1] - all but last element - let len = offsets.len(); - let slice_obj = PySlice::new(py, 0, (len - 1) as isize, 1); - let offsets_slice = offsets.call_method1("__getitem__", (slice_obj,))?; - - // searchsorted(offsets[:-1], position_indices, side='right') - 1 - let result = numpy.call_method1( - "searchsorted", - ( - offsets_slice, - position_indices, - pyo3::types::PyString::new(py, "right"), - ), - )?; - - // Subtract 1 - let one = 1i64.into_pyobject(py)?; - let result = result.call_method1("__sub__", (one,))?; - - Ok(result.extract()?) - } - - /// Map move indices to game indices. - /// - /// Args: - /// move_indices: Array of indices into from_squares, to_squares, etc. - /// Accepts any integer dtype; int64 is optimal (avoids conversion). - /// - /// Returns: - /// Array of game indices (same shape as input) - fn move_to_game<'py>( - &self, - py: Python<'py>, - move_indices: &Bound<'py, PyAny>, - ) -> PyResult>> { - let offsets = self.move_offsets.bind(py); - let offsets: &Bound<'_, PyArray1> = offsets.cast()?; - - let numpy = py.import("numpy")?; - - // Convert input to int64 array (no-op if already int64) - let int64_dtype = numpy.getattr("int64")?; - let move_indices = numpy - .call_method1("asarray", (move_indices,))? - .call_method( - "astype", - (int64_dtype,), - Some(&[("copy", false)].into_py_dict(py)?), - )?; - - let len = offsets.len(); - let slice_obj = PySlice::new(py, 0, (len - 1) as isize, 1); - let offsets_slice = offsets.call_method1("__getitem__", (slice_obj,))?; - - let result = numpy.call_method1( - "searchsorted", - ( - offsets_slice, - move_indices, - pyo3::types::PyString::new(py, "right"), - ), - )?; - - let one = 1i64.into_pyobject(py)?; - let result = result.call_method1("__sub__", (one,))?; - - Ok(result.extract()?) - } -} - -/// Iterator over games in a ParsedGames result. -#[pyclass] -pub struct ParsedGamesIter { - data: Py, - index: usize, - length: usize, -} - -#[pymethods] -impl ParsedGamesIter { - fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { - slf - } - - fn __next__(mut slf: PyRefMut<'_, Self>, py: Python<'_>) -> PyResult> { - if slf.index >= slf.length { - return Ok(None); - } - let game_view = PyGameView::new(py, slf.data.clone_ref(py), slf.index)?; - slf.index += 1; - Ok(Some(game_view)) - } -} - -/// Zero-copy view into a single game's data within a ParsedGames result. -/// -/// Board indexing note: Boards use square indexing (a1=0, h8=63). -/// To convert to rank/file: -/// rank = square // 8 -/// file = square % 8 -#[pyclass] -pub struct PyGameView { - data: Py, - idx: usize, - move_start: usize, - move_end: usize, - pos_start: usize, - pos_end: usize, -} - -impl PyGameView { - fn new(py: Python<'_>, data: Py, idx: usize) -> PyResult { - let borrowed = data.borrow(py); - - let move_offsets = borrowed.move_offsets.bind(py); - let move_offsets: &Bound<'_, PyArray1> = move_offsets.cast()?; - let pos_offsets = borrowed.position_offsets.bind(py); - let pos_offsets: &Bound<'_, PyArray1> = pos_offsets.cast()?; - - let move_offsets_ro = move_offsets.readonly(); - let move_offsets_slice = move_offsets_ro.as_slice()?; - let pos_offsets_ro = pos_offsets.readonly(); - let pos_offsets_slice = pos_offsets_ro.as_slice()?; - - let move_start = move_offsets_slice - .get(idx) - .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; - let move_end = move_offsets_slice - .get(idx + 1) - .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; - let pos_start = pos_offsets_slice - .get(idx) - .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; - let pos_end = pos_offsets_slice - .get(idx + 1) - .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; - - drop(borrowed); - - Ok(Self { - data, - idx, - move_start: *move_start as usize, - move_end: *move_end as usize, - pos_start: *pos_start as usize, - pos_end: *pos_end as usize, - }) - } -} - -#[pymethods] -impl PyGameView { - /// Number of moves in this game. - fn __len__(&self) -> usize { - self.move_end - self.move_start - } - - /// Number of positions recorded for this game. - #[getter] - fn num_positions(&self) -> usize { - self.pos_end - self.pos_start - } - - // === Board state views === - - /// Board positions, shape (num_positions, 8, 8). - #[getter] - fn boards<'py>(&self, py: Python<'py>) -> PyResult> { - let borrowed = self.data.borrow(py); - let boards = borrowed.boards.bind(py); - let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); - let slice = boards.call_method1("__getitem__", (slice_obj,))?; - Ok(slice.unbind()) - } - - /// Initial board position, shape (8, 8). - #[getter] - fn initial_board<'py>(&self, py: Python<'py>) -> PyResult> { - let borrowed = self.data.borrow(py); - let boards = borrowed.boards.bind(py); - let slice = boards.call_method1("__getitem__", (self.pos_start,))?; - Ok(slice.unbind()) - } - - /// Final board position, shape (8, 8). - #[getter] - fn final_board<'py>(&self, py: Python<'py>) -> PyResult> { - let borrowed = self.data.borrow(py); - let boards = borrowed.boards.bind(py); - let slice = boards.call_method1("__getitem__", (self.pos_end - 1,))?; - Ok(slice.unbind()) - } - - /// Castling rights [K,Q,k,q], shape (num_positions, 4). - #[getter] - fn castling<'py>(&self, py: Python<'py>) -> PyResult> { - let borrowed = self.data.borrow(py); - let arr = borrowed.castling.bind(py); - let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); - let slice = arr.call_method1("__getitem__", (slice_obj,))?; - Ok(slice.unbind()) - } - - /// En passant file (-1 if none), shape (num_positions,). - #[getter] - fn en_passant<'py>(&self, py: Python<'py>) -> PyResult> { - let borrowed = self.data.borrow(py); - let arr = borrowed.en_passant.bind(py); - let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); - let slice = arr.call_method1("__getitem__", (slice_obj,))?; - Ok(slice.unbind()) - } - - /// Halfmove clock, shape (num_positions,). - #[getter] - fn halfmove_clock<'py>(&self, py: Python<'py>) -> PyResult> { - let borrowed = self.data.borrow(py); - let arr = borrowed.halfmove_clock.bind(py); - let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); - let slice = arr.call_method1("__getitem__", (slice_obj,))?; - Ok(slice.unbind()) - } - - /// Side to move (True=white), shape (num_positions,). - #[getter] - fn turn<'py>(&self, py: Python<'py>) -> PyResult> { - let borrowed = self.data.borrow(py); - let arr = borrowed.turn.bind(py); - let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); - let slice = arr.call_method1("__getitem__", (slice_obj,))?; - Ok(slice.unbind()) - } - - // === Move views === - - /// From squares, shape (num_moves,). - #[getter] - fn from_squares<'py>(&self, py: Python<'py>) -> PyResult> { - let borrowed = self.data.borrow(py); - let arr = borrowed.from_squares.bind(py); - let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); - let slice = arr.call_method1("__getitem__", (slice_obj,))?; - Ok(slice.unbind()) - } - - /// To squares, shape (num_moves,). - #[getter] - fn to_squares<'py>(&self, py: Python<'py>) -> PyResult> { - let borrowed = self.data.borrow(py); - let arr = borrowed.to_squares.bind(py); - let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); - let slice = arr.call_method1("__getitem__", (slice_obj,))?; - Ok(slice.unbind()) - } - - /// Promotions (-1=none, 2=N, 3=B, 4=R, 5=Q), shape (num_moves,). - #[getter] - fn promotions<'py>(&self, py: Python<'py>) -> PyResult> { - let borrowed = self.data.borrow(py); - let arr = borrowed.promotions.bind(py); - let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); - let slice = arr.call_method1("__getitem__", (slice_obj,))?; - Ok(slice.unbind()) - } - - /// Clock times in seconds (NaN if missing), shape (num_moves,). - #[getter] - fn clocks<'py>(&self, py: Python<'py>) -> PyResult> { - let borrowed = self.data.borrow(py); - let arr = borrowed.clocks.bind(py); - let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); - let slice = arr.call_method1("__getitem__", (slice_obj,))?; - Ok(slice.unbind()) - } - - /// Engine evals (NaN if missing), shape (num_moves,). - #[getter] - fn evals<'py>(&self, py: Python<'py>) -> PyResult> { - let borrowed = self.data.borrow(py); - let arr = borrowed.evals.bind(py); - let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); - let slice = arr.call_method1("__getitem__", (slice_obj,))?; - Ok(slice.unbind()) - } - - // === Per-game metadata === - - /// Raw PGN headers as dict. - #[getter] - fn headers(&self, py: Python<'_>) -> PyResult> { - let borrowed = self.data.borrow(py); - Ok(borrowed.headers[self.idx].clone()) - } - - /// Final position is checkmate. - #[getter] - fn is_checkmate(&self, py: Python<'_>) -> PyResult { - let borrowed = self.data.borrow(py); - let arr = borrowed.is_checkmate.bind(py); - let arr: &Bound<'_, PyArray1> = arr.cast()?; - let readonly = arr.readonly(); - let slice = readonly.as_slice()?; - slice - .get(self.idx) - .copied() - .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) - } - - /// Final position is stalemate. - #[getter] - fn is_stalemate(&self, py: Python<'_>) -> PyResult { - let borrowed = self.data.borrow(py); - let arr = borrowed.is_stalemate.bind(py); - let arr: &Bound<'_, PyArray1> = arr.cast()?; - let readonly = arr.readonly(); - let slice = readonly.as_slice()?; - slice - .get(self.idx) - .copied() - .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) - } - - /// Insufficient material (white, black). - #[getter] - fn is_insufficient(&self, py: Python<'_>) -> PyResult<(bool, bool)> { - let borrowed = self.data.borrow(py); - let arr = borrowed.is_insufficient.bind(py); - let arr: &Bound<'_, PyArray2> = arr.cast()?; - let readonly = arr.readonly(); - let slice = readonly.as_slice()?; - // Array is shape (n_games, 2), so index is idx * 2 for white, idx * 2 + 1 for black - let base = self.idx * 2; - let white = slice - .get(base) - .copied() - .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; - let black = slice - .get(base + 1) - .copied() - .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; - Ok((white, black)) - } - - /// Legal move count in final position. - #[getter] - fn legal_move_count(&self, py: Python<'_>) -> PyResult { - let borrowed = self.data.borrow(py); - let arr = borrowed.legal_move_count.bind(py); - let arr: &Bound<'_, PyArray1> = arr.cast()?; - let readonly = arr.readonly(); - let slice = readonly.as_slice()?; - slice - .get(self.idx) - .copied() - .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) - } - - /// Whether game parsed successfully. - #[getter] - fn is_valid(&self, py: Python<'_>) -> PyResult { - let borrowed = self.data.borrow(py); - let arr = borrowed.valid.bind(py); - let arr: &Bound<'_, PyArray1> = arr.cast()?; - let readonly = arr.readonly(); - let slice = readonly.as_slice()?; - slice - .get(self.idx) - .copied() - .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) - } - - // === Convenience methods === - - /// Get UCI string for move at index. - fn move_uci(&self, py: Python<'_>, move_idx: usize) -> PyResult { - if move_idx >= self.move_end - self.move_start { - return Err(pyo3::exceptions::PyIndexError::new_err(format!( - "Move index {} out of range [0, {})", - move_idx, - self.move_end - self.move_start - ))); - } - - let borrowed = self.data.borrow(py); - let from_arr = borrowed.from_squares.bind(py); - let from_arr: &Bound<'_, PyArray1> = from_arr.cast()?; - let to_arr = borrowed.to_squares.bind(py); - let to_arr: &Bound<'_, PyArray1> = to_arr.cast()?; - let promo_arr = borrowed.promotions.bind(py); - let promo_arr: &Bound<'_, PyArray1> = promo_arr.cast()?; - - let from_ro = from_arr.readonly(); - let to_ro = to_arr.readonly(); - let promo_ro = promo_arr.readonly(); - - let from_slice = from_ro.as_slice()?; - let to_slice = to_ro.as_slice()?; - let promo_slice = promo_ro.as_slice()?; - - let abs_idx = self.move_start + move_idx; - let from_sq = from_slice - .get(abs_idx) - .copied() - .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid move index"))?; - let to_sq = to_slice - .get(abs_idx) - .copied() - .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid move index"))?; - let promo = promo_slice - .get(abs_idx) - .copied() - .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid move index"))?; - - let files = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']; - let ranks = ['1', '2', '3', '4', '5', '6', '7', '8']; - - let mut uci = format!( - "{}{}{}{}", - files[(from_sq % 8) as usize], - ranks[(from_sq / 8) as usize], - files[(to_sq % 8) as usize], - ranks[(to_sq / 8) as usize] - ); - - if promo >= 0 { - let promo_chars = ['_', '_', 'n', 'b', 'r', 'q']; // 2=N, 3=B, 4=R, 5=Q - if (promo as usize) < promo_chars.len() { - uci.push(promo_chars[promo as usize]); - } - } - - Ok(uci) - } - - /// Get all moves as UCI strings. - fn moves_uci(&self, py: Python<'_>) -> PyResult> { - let n_moves = self.move_end - self.move_start; - let mut result = Vec::with_capacity(n_moves); - for i in 0..n_moves { - result.push(self.move_uci(py, i)?); - } - Ok(result) - } - - fn __repr__(&self, py: Python<'_>) -> PyResult { - let headers = self.headers(py)?; - let white = headers.get("White").map(|s| s.as_str()).unwrap_or("?"); - let black = headers.get("Black").map(|s| s.as_str()).unwrap_or("?"); - let n_moves = self.move_end - self.move_start; - let is_valid = self.is_valid(py)?; - Ok(format!( - "", - white, black, n_moves, is_valid - )) - } -} +use python_bindings::{ParsedGames, ParsedGamesIter, PositionStatus, PyGameView, PyUciMove}; +use visitor::MoveExtractor; #[pyfunction] #[pyo3(signature = (pgn_chunked_array, num_threads=None))] @@ -1491,72 +330,6 @@ fn rust_pgn_reader_python_binding(m: &Bound<'_, PyModule>) -> PyResult<()> { #[cfg(test)] mod pyucimove_tests { use super::*; - use shakmaty::{Role, Square}; - - #[test] - fn test_py_uci_move_no_promotion() { - let uci_move = PyUciMove::new(Square::E2 as u8, Square::E4 as u8, None); - assert_eq!(uci_move.from_square, Square::E2 as u8); - assert_eq!(uci_move.to_square, Square::E4 as u8); - assert_eq!(uci_move.promotion, None); - assert_eq!(uci_move.get_from_square_name(), "e2"); - assert_eq!(uci_move.get_to_square_name(), "e4"); - assert_eq!(uci_move.get_promotion_name(), None); - assert_eq!(uci_move.__str__(), "e2e4"); - assert_eq!( - uci_move.__repr__(), - "PyUciMove(from_square=e2, to_square=e4, promotion=None)" - ); - } - - #[test] - fn test_py_uci_move_with_queen_promotion() { - let uci_move = PyUciMove::new(Square::E7 as u8, Square::E8 as u8, Some(Role::Queen as u8)); - assert_eq!(uci_move.from_square, Square::E7 as u8); - assert_eq!(uci_move.to_square, Square::E8 as u8); - assert_eq!(uci_move.promotion, Some(Role::Queen as u8)); - assert_eq!(uci_move.get_from_square_name(), "e7"); - assert_eq!(uci_move.get_to_square_name(), "e8"); - assert_eq!(uci_move.get_promotion_name(), Some("Queen".to_string())); - assert_eq!(uci_move.__str__(), "e7e8q"); - assert_eq!( - uci_move.__repr__(), - "PyUciMove(from_square=e7, to_square=e8, promotion=Some('q'))" - ); - } - - #[test] - fn test_py_uci_move_with_rook_promotion() { - let uci_move = PyUciMove::new(Square::A7 as u8, Square::A8 as u8, Some(Role::Rook as u8)); - assert_eq!(uci_move.from_square, Square::A7 as u8); - assert_eq!(uci_move.to_square, Square::A8 as u8); - assert_eq!(uci_move.promotion, Some(Role::Rook as u8)); - assert_eq!(uci_move.get_from_square_name(), "a7"); - assert_eq!(uci_move.get_to_square_name(), "a8"); - assert_eq!(uci_move.get_promotion_name(), Some("Rook".to_string())); - assert_eq!(uci_move.__str__(), "a7a8r"); - assert_eq!( - uci_move.__repr__(), - "PyUciMove(from_square=a7, to_square=a8, promotion=Some('r'))" - ); - } - - #[test] - fn test_py_uci_move_invalid_promotion_val() { - // Test with a u8 value that doesn't correspond to a valid Role - let uci_move = PyUciMove::new(Square::B7 as u8, Square::B8 as u8, Some(99)); // 99 is not a valid Role - assert_eq!(uci_move.from_square, Square::B7 as u8); - assert_eq!(uci_move.to_square, Square::B8 as u8); - assert_eq!(uci_move.promotion, Some(99)); - assert_eq!(uci_move.get_from_square_name(), "b7"); - assert_eq!(uci_move.get_to_square_name(), "b8"); - assert_eq!(uci_move.get_promotion_name(), None); // Should be None as 99 is invalid - assert_eq!(uci_move.__str__(), "b7b8"); // Should produce no promotion char - assert_eq!( - uci_move.__repr__(), - "PyUciMove(from_square=b7, to_square=b8, promotion=Some(InvalidRole(99)))" - ); - } #[test] fn test_parse_game_without_headers() { diff --git a/src/python_bindings.rs b/src/python_bindings.rs new file mode 100644 index 0000000..82feb14 --- /dev/null +++ b/src/python_bindings.rs @@ -0,0 +1,825 @@ +use numpy::{PyArray1, PyArray2, PyArrayMethods}; +use pyo3::prelude::*; +use pyo3::types::{IntoPyDict, PySlice}; +use shakmaty::{Role, Square}; +use std::collections::HashMap; + +// Definition of PyUciMove +#[pyclass(get_all, set_all, module = "rust_pgn_reader_python_binding")] +#[derive(Clone, Debug)] +pub struct PyUciMove { + pub from_square: u8, + pub to_square: u8, + pub promotion: Option, +} + +#[pymethods] +impl PyUciMove { + #[new] + pub fn new(from_square: u8, to_square: u8, promotion: Option) -> Self { + PyUciMove { + from_square, + to_square, + promotion, + } + } + + #[getter] + fn get_from_square_name(&self) -> String { + Square::new(self.from_square as u32).to_string() + } + + #[getter] + fn get_to_square_name(&self) -> String { + Square::new(self.to_square as u32).to_string() + } + + #[getter] + fn get_promotion_name(&self) -> Option { + self.promotion.and_then(|p_u8| { + Role::try_from(p_u8) + .map(|role| format!("{:?}", role)) // Get the debug representation (e.g., "Queen") + .ok() + }) + } + + // __str__ method for Python representation + fn __str__(&self) -> String { + let promo_str = self.promotion.map_or("".to_string(), |p_u8| { + Role::try_from(p_u8) + .map(|role| role.char().to_string()) + .unwrap_or_else(|_| "".to_string()) // Handle potential error if u8 is not a valid Role + }); + format!( + "{}{}{}", + Square::new(self.from_square as u32), + Square::new(self.to_square as u32), + promo_str + ) + } + + // __repr__ for a more developer-friendly representation + fn __repr__(&self) -> String { + let promo_repr = self.promotion.map_or("None".to_string(), |p_u8| { + Role::try_from(p_u8) + .map(|role| format!("Some('{}')", role.char())) + .unwrap_or_else(|_| format!("Some(InvalidRole({}))", p_u8)) + }); + format!( + "PyUciMove(from_square={}, to_square={}, promotion={})", + Square::new(self.from_square as u32), + Square::new(self.to_square as u32), + promo_repr + ) + } +} + +#[pyclass] +/// Holds the status of a chess position. +#[derive(Clone)] +pub struct PositionStatus { + #[pyo3(get)] + pub is_checkmate: bool, + + #[pyo3(get)] + pub is_stalemate: bool, + + #[pyo3(get)] + pub legal_move_count: usize, + + #[pyo3(get)] + pub is_game_over: bool, + + #[pyo3(get)] + pub insufficient_material: (bool, bool), + + #[pyo3(get)] + pub turn: bool, +} + +/// Flat array container for parsed chess games, optimized for ML training. +#[pyclass] +pub struct ParsedGames { + // === Board state arrays (N_positions) === + /// Board positions, shape (N_positions, 8, 8), dtype uint8 + #[pyo3(get)] + pub boards: Py, + + /// Castling rights [K,Q,k,q], shape (N_positions, 4), dtype bool + #[pyo3(get)] + pub castling: Py, + + /// En passant file (-1 if none), shape (N_positions,), dtype int8 + #[pyo3(get)] + pub en_passant: Py, + + /// Halfmove clock, shape (N_positions,), dtype uint8 + #[pyo3(get)] + pub halfmove_clock: Py, + + /// Side to move (true=white), shape (N_positions,), dtype bool + #[pyo3(get)] + pub turn: Py, + + // === Move arrays (N_moves) === + /// From squares, shape (N_moves,), dtype uint8 + #[pyo3(get)] + pub from_squares: Py, + + /// To squares, shape (N_moves,), dtype uint8 + #[pyo3(get)] + pub to_squares: Py, + + /// Promotions (-1=none, 2=N, 3=B, 4=R, 5=Q), shape (N_moves,), dtype int8 + #[pyo3(get)] + pub promotions: Py, + + /// Clock times in seconds (NaN if missing), shape (N_moves,), dtype float32 + #[pyo3(get)] + pub clocks: Py, + + /// Engine evals (NaN if missing), shape (N_moves,), dtype float32 + #[pyo3(get)] + pub evals: Py, + + // === Offsets === + /// Move offsets for CSR-style indexing, shape (N_games + 1,), dtype uint32 + /// Game i's moves: move_offsets[i]..move_offsets[i+1] + #[pyo3(get)] + pub move_offsets: Py, + + /// Position offsets for CSR-style indexing, shape (N_games + 1,), dtype uint32 + /// Game i's positions: position_offsets[i]..position_offsets[i+1] + #[pyo3(get)] + pub position_offsets: Py, + + // === Final position status (N_games) === + /// Final position is checkmate, shape (N_games,), dtype bool + #[pyo3(get)] + pub is_checkmate: Py, + + /// Final position is stalemate, shape (N_games,), dtype bool + #[pyo3(get)] + pub is_stalemate: Py, + + /// Insufficient material (white, black), shape (N_games, 2), dtype bool + #[pyo3(get)] + pub is_insufficient: Py, + + /// Legal move count in final position, shape (N_games,), dtype uint16 + #[pyo3(get)] + pub legal_move_count: Py, + + // === Parse status (N_games) === + /// Whether game parsed successfully, shape (N_games,), dtype bool + #[pyo3(get)] + pub valid: Py, + + // === Raw headers (N_games) === + /// Raw PGN headers as list of dicts + #[pyo3(get)] + pub headers: Vec>, +} + +#[pymethods] +impl ParsedGames { + /// Number of games in the result. + #[getter] + fn num_games(&self) -> usize { + self.headers.len() + } + + /// Total number of moves across all games. + #[getter] + fn num_moves(&self, py: Python<'_>) -> PyResult { + let offsets = self.move_offsets.bind(py); + let offsets: &Bound<'_, PyArray1> = offsets.cast()?; + let readonly = offsets.readonly(); + let slice = readonly.as_slice()?; + Ok(slice.last().copied().unwrap_or(0) as usize) + } + + /// Total number of board positions recorded. + #[getter] + fn num_positions(&self, py: Python<'_>) -> PyResult { + let offsets = self.position_offsets.bind(py); + let offsets: &Bound<'_, PyArray1> = offsets.cast()?; + let readonly = offsets.readonly(); + let slice = readonly.as_slice()?; + Ok(slice.last().copied().unwrap_or(0) as usize) + } + + fn __len__(&self) -> usize { + self.headers.len() + } + + fn __getitem__(slf: Py, py: Python<'_>, idx: &Bound<'_, PyAny>) -> PyResult> { + let n_games = slf.borrow(py).headers.len(); + + // Handle integer index + if let Ok(mut i) = idx.extract::() { + // Handle negative indexing + if i < 0 { + i += n_games as isize; + } + if i < 0 || i >= n_games as isize { + return Err(pyo3::exceptions::PyIndexError::new_err(format!( + "Game index {} out of range [0, {})", + i, n_games + ))); + } + let game_view = PyGameView::new(py, slf.clone_ref(py), i as usize)?; + return Ok(Py::new(py, game_view)?.into_any()); + } + + // Handle slice + if let Ok(slice) = idx.cast::() { + let indices = slice.indices(n_games as isize)?; + let start = indices.start as usize; + let stop = indices.stop as usize; + let step = indices.step as usize; + + // For simplicity, we return a list of PyGameView objects + let mut views: Vec> = Vec::new(); + let mut i = start; + while i < stop { + let game_view = PyGameView::new(py, slf.clone_ref(py), i)?; + views.push(Py::new(py, game_view)?); + i += step; + } + return Ok(pyo3::types::PyList::new(py, views)?.into_any().unbind()); + } + + Err(pyo3::exceptions::PyTypeError::new_err(format!( + "Invalid index type: expected int or slice, got {}", + idx.get_type().name()? + ))) + } + + fn __iter__(slf: Py, py: Python<'_>) -> PyResult { + let n_games = slf.borrow(py).headers.len(); + Ok(ParsedGamesIter { + data: slf, + index: 0, + length: n_games, + }) + } + + /// Map position indices to game indices. + /// + /// Useful after shuffling/sampling positions to look up game metadata. + /// + /// Args: + /// position_indices: Array of indices into boards array. + /// Accepts any integer dtype; int64 is optimal (avoids conversion). + /// + /// Returns: + /// Array of game indices (same shape as input) + fn position_to_game<'py>( + &self, + py: Python<'py>, + position_indices: &Bound<'py, PyAny>, + ) -> PyResult>> { + let offsets = self.position_offsets.bind(py); + let offsets: &Bound<'_, PyArray1> = offsets.cast()?; + + // Get numpy module for searchsorted + let numpy = py.import("numpy")?; + + // Convert input to int64 array (no-op if already int64) + let int64_dtype = numpy.getattr("int64")?; + let position_indices = numpy + .call_method1("asarray", (position_indices,))? + .call_method( + "astype", + (int64_dtype,), + Some(&[("copy", false)].into_py_dict(py)?), + )?; + + // offsets[:-1] - all but last element + let len = offsets.len()?; + let slice_obj = PySlice::new(py, 0, (len - 1) as isize, 1); + let offsets_slice = offsets.call_method1("__getitem__", (slice_obj,))?; + + // searchsorted(offsets[:-1], position_indices, side='right') - 1 + let result = numpy.call_method1( + "searchsorted", + ( + offsets_slice, + position_indices, + pyo3::types::PyString::new(py, "right"), + ), + )?; + + // Subtract 1 + let one = 1i64.into_pyobject(py)?; + let result = result.call_method1("__sub__", (one,))?; + + Ok(result.extract()?) + } + + /// Map move indices to game indices. + /// + /// Args: + /// move_indices: Array of indices into from_squares, to_squares, etc. + /// Accepts any integer dtype; int64 is optimal (avoids conversion). + /// + /// Returns: + /// Array of game indices (same shape as input) + fn move_to_game<'py>( + &self, + py: Python<'py>, + move_indices: &Bound<'py, PyAny>, + ) -> PyResult>> { + let offsets = self.move_offsets.bind(py); + let offsets: &Bound<'_, PyArray1> = offsets.cast()?; + + let numpy = py.import("numpy")?; + + // Convert input to int64 array (no-op if already int64) + let int64_dtype = numpy.getattr("int64")?; + let move_indices = numpy + .call_method1("asarray", (move_indices,))? + .call_method( + "astype", + (int64_dtype,), + Some(&[("copy", false)].into_py_dict(py)?), + )?; + + let len = offsets.len()?; + let slice_obj = PySlice::new(py, 0, (len - 1) as isize, 1); + let offsets_slice = offsets.call_method1("__getitem__", (slice_obj,))?; + + let result = numpy.call_method1( + "searchsorted", + ( + offsets_slice, + move_indices, + pyo3::types::PyString::new(py, "right"), + ), + )?; + + let one = 1i64.into_pyobject(py)?; + let result = result.call_method1("__sub__", (one,))?; + + Ok(result.extract()?) + } +} + +/// Iterator over games in a ParsedGames result. +#[pyclass] +pub struct ParsedGamesIter { + data: Py, + index: usize, + length: usize, +} + +#[pymethods] +impl ParsedGamesIter { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(mut slf: PyRefMut<'_, Self>, py: Python<'_>) -> PyResult> { + if slf.index >= slf.length { + return Ok(None); + } + let game_view = PyGameView::new(py, slf.data.clone_ref(py), slf.index)?; + slf.index += 1; + Ok(Some(game_view)) + } +} + +/// Zero-copy view into a single game's data within a ParsedGames result. +/// +/// Board indexing note: Boards use square indexing (a1=0, h8=63). +/// To convert to rank/file: +/// rank = square // 8 +/// file = square % 8 +#[pyclass] +pub struct PyGameView { + data: Py, + idx: usize, + move_start: usize, + move_end: usize, + pos_start: usize, + pos_end: usize, +} + +impl PyGameView { + pub fn new(py: Python<'_>, data: Py, idx: usize) -> PyResult { + let borrowed = data.borrow(py); + + let move_offsets = borrowed.move_offsets.bind(py); + let move_offsets: &Bound<'_, PyArray1> = move_offsets.cast()?; + let pos_offsets = borrowed.position_offsets.bind(py); + let pos_offsets: &Bound<'_, PyArray1> = pos_offsets.cast()?; + + let move_offsets_ro = move_offsets.readonly(); + let move_offsets_slice = move_offsets_ro.as_slice()?; + let pos_offsets_ro = pos_offsets.readonly(); + let pos_offsets_slice = pos_offsets_ro.as_slice()?; + + let move_start = move_offsets_slice + .get(idx) + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + let move_end = move_offsets_slice + .get(idx + 1) + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + let pos_start = pos_offsets_slice + .get(idx) + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + let pos_end = pos_offsets_slice + .get(idx + 1) + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + + drop(borrowed); + + Ok(Self { + data, + idx, + move_start: *move_start as usize, + move_end: *move_end as usize, + pos_start: *pos_start as usize, + pos_end: *pos_end as usize, + }) + } +} + +#[pymethods] +impl PyGameView { + /// Number of moves in this game. + fn __len__(&self) -> usize { + self.move_end - self.move_start + } + + /// Number of positions recorded for this game. + #[getter] + fn num_positions(&self) -> usize { + self.pos_end - self.pos_start + } + + // === Board state views === + + /// Board positions, shape (num_positions, 8, 8). + #[getter] + fn boards<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let boards = borrowed.boards.bind(py); + let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); + let slice = boards.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Initial board position, shape (8, 8). + #[getter] + fn initial_board<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let boards = borrowed.boards.bind(py); + let slice = boards.call_method1("__getitem__", (self.pos_start,))?; + Ok(slice.unbind()) + } + + /// Final board position, shape (8, 8). + #[getter] + fn final_board<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let boards = borrowed.boards.bind(py); + let slice = boards.call_method1("__getitem__", (self.pos_end - 1,))?; + Ok(slice.unbind()) + } + + /// Castling rights [K,Q,k,q], shape (num_positions, 4). + #[getter] + fn castling<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.castling.bind(py); + let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// En passant file (-1 if none), shape (num_positions,). + #[getter] + fn en_passant<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.en_passant.bind(py); + let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Halfmove clock, shape (num_positions,). + #[getter] + fn halfmove_clock<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.halfmove_clock.bind(py); + let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Side to move (True=white), shape (num_positions,). + #[getter] + fn turn<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.turn.bind(py); + let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + // === Move views === + + /// From squares, shape (num_moves,). + #[getter] + fn from_squares<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.from_squares.bind(py); + let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// To squares, shape (num_moves,). + #[getter] + fn to_squares<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.to_squares.bind(py); + let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Promotions (-1=none, 2=N, 3=B, 4=R, 5=Q), shape (num_moves,). + #[getter] + fn promotions<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.promotions.bind(py); + let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Clock times in seconds (NaN if missing), shape (num_moves,). + #[getter] + fn clocks<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.clocks.bind(py); + let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + /// Engine evals (NaN if missing), shape (num_moves,). + #[getter] + fn evals<'py>(&self, py: Python<'py>) -> PyResult> { + let borrowed = self.data.borrow(py); + let arr = borrowed.evals.bind(py); + let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); + let slice = arr.call_method1("__getitem__", (slice_obj,))?; + Ok(slice.unbind()) + } + + // === Per-game metadata === + + /// Raw PGN headers as dict. + #[getter] + fn headers(&self, py: Python<'_>) -> PyResult> { + let borrowed = self.data.borrow(py); + Ok(borrowed.headers[self.idx].clone()) + } + + /// Final position is checkmate. + #[getter] + fn is_checkmate(&self, py: Python<'_>) -> PyResult { + let borrowed = self.data.borrow(py); + let arr = borrowed.is_checkmate.bind(py); + let arr: &Bound<'_, PyArray1> = arr.cast()?; + let readonly = arr.readonly(); + let slice = readonly.as_slice()?; + slice + .get(self.idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) + } + + /// Final position is stalemate. + #[getter] + fn is_stalemate(&self, py: Python<'_>) -> PyResult { + let borrowed = self.data.borrow(py); + let arr = borrowed.is_stalemate.bind(py); + let arr: &Bound<'_, PyArray1> = arr.cast()?; + let readonly = arr.readonly(); + let slice = readonly.as_slice()?; + slice + .get(self.idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) + } + + /// Insufficient material (white, black). + #[getter] + fn is_insufficient(&self, py: Python<'_>) -> PyResult<(bool, bool)> { + let borrowed = self.data.borrow(py); + let arr = borrowed.is_insufficient.bind(py); + let arr: &Bound<'_, PyArray2> = arr.cast()?; + let readonly = arr.readonly(); + let slice = readonly.as_slice()?; + // Array is shape (n_games, 2), so index is idx * 2 for white, idx * 2 + 1 for black + let base = self.idx * 2; + let white = slice + .get(base) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + let black = slice + .get(base + 1) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; + Ok((white, black)) + } + + /// Legal move count in final position. + #[getter] + fn legal_move_count(&self, py: Python<'_>) -> PyResult { + let borrowed = self.data.borrow(py); + let arr = borrowed.legal_move_count.bind(py); + let arr: &Bound<'_, PyArray1> = arr.cast()?; + let readonly = arr.readonly(); + let slice = readonly.as_slice()?; + slice + .get(self.idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) + } + + /// Whether game parsed successfully. + #[getter] + fn is_valid(&self, py: Python<'_>) -> PyResult { + let borrowed = self.data.borrow(py); + let arr = borrowed.valid.bind(py); + let arr: &Bound<'_, PyArray1> = arr.cast()?; + let readonly = arr.readonly(); + let slice = readonly.as_slice()?; + slice + .get(self.idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) + } + + // === Convenience methods === + + /// Get UCI string for move at index. + fn move_uci(&self, py: Python<'_>, move_idx: usize) -> PyResult { + if move_idx >= self.move_end - self.move_start { + return Err(pyo3::exceptions::PyIndexError::new_err(format!( + "Move index {} out of range [0, {})", + move_idx, + self.move_end - self.move_start + ))); + } + + let borrowed = self.data.borrow(py); + let from_arr = borrowed.from_squares.bind(py); + let from_arr: &Bound<'_, PyArray1> = from_arr.cast()?; + let to_arr = borrowed.to_squares.bind(py); + let to_arr: &Bound<'_, PyArray1> = to_arr.cast()?; + let promo_arr = borrowed.promotions.bind(py); + let promo_arr: &Bound<'_, PyArray1> = promo_arr.cast()?; + + let from_ro = from_arr.readonly(); + let to_ro = to_arr.readonly(); + let promo_ro = promo_arr.readonly(); + + let from_slice = from_ro.as_slice()?; + let to_slice = to_ro.as_slice()?; + let promo_slice = promo_ro.as_slice()?; + + let abs_idx = self.move_start + move_idx; + let from_sq = from_slice + .get(abs_idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid move index"))?; + let to_sq = to_slice + .get(abs_idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid move index"))?; + let promo = promo_slice + .get(abs_idx) + .copied() + .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid move index"))?; + + let files = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']; + let ranks = ['1', '2', '3', '4', '5', '6', '7', '8']; + + let mut uci = format!( + "{}{}{}{}", + files[(from_sq % 8) as usize], + ranks[(from_sq / 8) as usize], + files[(to_sq % 8) as usize], + ranks[(to_sq / 8) as usize] + ); + + if promo >= 0 { + let promo_chars = ['_', '_', 'n', 'b', 'r', 'q']; // 2=N, 3=B, 4=R, 5=Q + if (promo as usize) < promo_chars.len() { + uci.push(promo_chars[promo as usize]); + } + } + + Ok(uci) + } + + /// Get all moves as UCI strings. + fn moves_uci(&self, py: Python<'_>) -> PyResult> { + let n_moves = self.move_end - self.move_start; + let mut result = Vec::with_capacity(n_moves); + for i in 0..n_moves { + result.push(self.move_uci(py, i)?); + } + Ok(result) + } + + fn __repr__(&self, py: Python<'_>) -> PyResult { + let headers = self.headers(py)?; + let white = headers.get("White").map(|s| s.as_str()).unwrap_or("?"); + let black = headers.get("Black").map(|s| s.as_str()).unwrap_or("?"); + let n_moves = self.move_end - self.move_start; + let is_valid = self.is_valid(py)?; + Ok(format!( + "", + white, black, n_moves, is_valid + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use shakmaty::{Role, Square}; + + #[test] + fn test_py_uci_move_no_promotion() { + let uci_move = PyUciMove::new(Square::E2 as u8, Square::E4 as u8, None); + assert_eq!(uci_move.from_square, Square::E2 as u8); + assert_eq!(uci_move.to_square, Square::E4 as u8); + assert_eq!(uci_move.promotion, None); + assert_eq!(uci_move.get_from_square_name(), "e2"); + assert_eq!(uci_move.get_to_square_name(), "e4"); + assert_eq!(uci_move.get_promotion_name(), None); + assert_eq!(uci_move.__str__(), "e2e4"); + assert_eq!( + uci_move.__repr__(), + "PyUciMove(from_square=e2, to_square=e4, promotion=None)" + ); + } + + #[test] + fn test_py_uci_move_with_queen_promotion() { + let uci_move = PyUciMove::new(Square::E7 as u8, Square::E8 as u8, Some(Role::Queen as u8)); + assert_eq!(uci_move.from_square, Square::E7 as u8); + assert_eq!(uci_move.to_square, Square::E8 as u8); + assert_eq!(uci_move.promotion, Some(Role::Queen as u8)); + assert_eq!(uci_move.get_from_square_name(), "e7"); + assert_eq!(uci_move.get_to_square_name(), "e8"); + assert_eq!(uci_move.get_promotion_name(), Some("Queen".to_string())); + assert_eq!(uci_move.__str__(), "e7e8q"); + assert_eq!( + uci_move.__repr__(), + "PyUciMove(from_square=e7, to_square=e8, promotion=Some('q'))" + ); + } + + #[test] + fn test_py_uci_move_with_rook_promotion() { + let uci_move = PyUciMove::new(Square::A7 as u8, Square::A8 as u8, Some(Role::Rook as u8)); + assert_eq!(uci_move.from_square, Square::A7 as u8); + assert_eq!(uci_move.to_square, Square::A8 as u8); + assert_eq!(uci_move.promotion, Some(Role::Rook as u8)); + assert_eq!(uci_move.get_from_square_name(), "a7"); + assert_eq!(uci_move.get_to_square_name(), "a8"); + assert_eq!(uci_move.get_promotion_name(), Some("Rook".to_string())); + assert_eq!(uci_move.__str__(), "a7a8r"); + assert_eq!( + uci_move.__repr__(), + "PyUciMove(from_square=a7, to_square=a8, promotion=Some('r'))" + ); + } + + #[test] + fn test_py_uci_move_invalid_promotion_val() { + // Test with a u8 value that doesn't correspond to a valid Role + let uci_move = PyUciMove::new(Square::B7 as u8, Square::B8 as u8, Some(99)); // 99 is not a valid Role + assert_eq!(uci_move.from_square, Square::B7 as u8); + assert_eq!(uci_move.to_square, Square::B8 as u8); + assert_eq!(uci_move.promotion, Some(99)); + assert_eq!(uci_move.get_from_square_name(), "b7"); + assert_eq!(uci_move.get_to_square_name(), "b8"); + assert_eq!(uci_move.get_promotion_name(), None); // Should be None as 99 is invalid + assert_eq!(uci_move.__str__(), "b7b8"); // Should produce no promotion char + assert_eq!( + uci_move.__repr__(), + "PyUciMove(from_square=b7, to_square=b8, promotion=Some(InvalidRole(99)))" + ); + } +} diff --git a/src/visitor.rs b/src/visitor.rs new file mode 100644 index 0000000..bc1a703 --- /dev/null +++ b/src/visitor.rs @@ -0,0 +1,409 @@ +use crate::board_serialization::{ + get_castling_rights, get_en_passant_file, get_halfmove_clock, get_turn, serialize_board, +}; +use crate::comment_parsing::{parse_comments, CommentContent, ParsedTag}; +use crate::python_bindings::{PositionStatus, PyUciMove}; +use pgn_reader::{KnownOutcome, Outcome, RawComment, RawTag, SanPlus, Skip, Visitor}; +use pyo3::prelude::*; +use shakmaty::{fen::Fen, uci::UciMove, CastlingMode, Chess, Color, Position}; +use std::ops::ControlFlow; + +#[pyclass] +/// A Visitor to extract SAN moves and comments from PGN movetext +pub struct MoveExtractor { + #[pyo3(get)] + pub moves: Vec, + + pub store_legal_moves: bool, + pub flat_legal_moves: Vec, + pub legal_moves_offsets: Vec, + + #[pyo3(get)] + pub valid_moves: bool, + + #[pyo3(get)] + pub comments: Vec>, + + #[pyo3(get)] + pub evals: Vec>, + + #[pyo3(get)] + pub clock_times: Vec>, + + #[pyo3(get)] + pub outcome: Option, + + #[pyo3(get)] + pub headers: Vec<(String, String)>, + + #[pyo3(get)] + pub castling_rights: Vec>, + + #[pyo3(get)] + pub position_status: Option, + + pub pos: Chess, + + // Board state tracking for flat output (not directly exposed to Python) + pub board_states: Vec, // Flattened: 64 bytes per position + pub en_passant_states: Vec, // Per position: -1 or file 0-7 + pub halfmove_clocks: Vec, // Per position + pub turn_states: Vec, // Per position: true=white + pub castling_states: Vec, // Flattened: 4 bools per position [K,Q,k,q] +} + +#[pymethods] +impl MoveExtractor { + #[new] + #[pyo3(signature = (store_legal_moves = false))] + pub fn new(store_legal_moves: bool) -> MoveExtractor { + MoveExtractor { + moves: Vec::with_capacity(100), + store_legal_moves, + flat_legal_moves: Vec::with_capacity(if store_legal_moves { 100 * 30 } else { 0 }), // Pre-allocate for moves + legal_moves_offsets: Vec::with_capacity(if store_legal_moves { 100 } else { 0 }), // Pre-allocate for offsets + pos: Chess::default(), + valid_moves: true, + comments: Vec::with_capacity(100), + evals: Vec::with_capacity(100), + clock_times: Vec::with_capacity(100), + outcome: None, + headers: Vec::with_capacity(10), + castling_rights: Vec::with_capacity(100), + position_status: None, + board_states: Vec::with_capacity(100 * 64), + en_passant_states: Vec::with_capacity(100), + halfmove_clocks: Vec::with_capacity(100), + turn_states: Vec::with_capacity(100), + castling_states: Vec::with_capacity(100 * 4), + } + } + + fn turn(&self) -> bool { + match self.pos.turn() { + Color::White => true, + Color::Black => false, + } + } + + fn push_castling_bitboards(&mut self) { + let castling_bitboard = self.pos.castles().castling_rights(); + let castling_rights = ( + castling_bitboard.contains(shakmaty::Square::A1), + castling_bitboard.contains(shakmaty::Square::H1), + castling_bitboard.contains(shakmaty::Square::A8), + castling_bitboard.contains(shakmaty::Square::H8), + ); + + self.castling_rights.push(Some(castling_rights)); + } + + fn push_legal_moves(&mut self) { + // Record the starting offset for the current position's legal moves. + self.legal_moves_offsets.push(self.flat_legal_moves.len()); + + let legal_moves_for_pos = self.pos.legal_moves(); + self.flat_legal_moves.reserve(legal_moves_for_pos.len()); + + for m in legal_moves_for_pos { + let uci_move_obj = UciMove::from_standard(m); + if let UciMove::Normal { + from, + to, + promotion: promo_opt, + } = uci_move_obj + { + self.flat_legal_moves.push(PyUciMove { + from_square: from as u8, + to_square: to as u8, + promotion: promo_opt.map(|p_role| p_role as u8), + }); + } + } + } + + /// Record current board state to flat arrays for ParsedGames output. + fn push_board_state(&mut self) { + self.board_states + .extend_from_slice(&serialize_board(&self.pos)); + self.en_passant_states.push(get_en_passant_file(&self.pos)); + self.halfmove_clocks.push(get_halfmove_clock(&self.pos)); + self.turn_states.push(get_turn(&self.pos)); + let castling = get_castling_rights(&self.pos); + self.castling_states.extend_from_slice(&castling); + } + + fn update_position_status(&mut self) { + // TODO this checks legal_moves() a bunch of times + self.position_status = Some(PositionStatus { + is_checkmate: self.pos.is_checkmate(), + is_stalemate: self.pos.is_stalemate(), + legal_move_count: self.pos.legal_moves().len(), + is_game_over: self.pos.is_game_over(), + insufficient_material: ( + self.pos.has_insufficient_material(Color::White), + self.pos.has_insufficient_material(Color::Black), + ), + turn: match self.pos.turn() { + Color::White => true, + Color::Black => false, + }, + }); + } + + #[getter] + fn legal_moves(&self) -> Vec> { + let mut result = Vec::with_capacity(self.legal_moves_offsets.len()); + if self.legal_moves_offsets.is_empty() { + return result; + } + + for i in 0..self.legal_moves_offsets.len() - 1 { + let start = self.legal_moves_offsets[i]; + let end = self.legal_moves_offsets[i + 1]; + result.push(self.flat_legal_moves[start..end].to_vec()); + } + + // Handle the last chunk + if let Some(&start) = self.legal_moves_offsets.last() { + result.push(self.flat_legal_moves[start..].to_vec()); + } + + result + } +} + +impl Visitor for MoveExtractor { + type Tags = Vec<(String, String)>; + type Movetext = (); + type Output = bool; + + fn begin_tags(&mut self) -> ControlFlow { + self.headers.clear(); + ControlFlow::Continue(Vec::with_capacity(10)) + } + + fn tag( + &mut self, + tags: &mut Self::Tags, + key: &[u8], + value: RawTag<'_>, + ) -> ControlFlow { + let key_str = String::from_utf8_lossy(key).into_owned(); + let value_str = String::from_utf8_lossy(value.as_bytes()).into_owned(); + tags.push((key_str, value_str)); + ControlFlow::Continue(()) + } + + fn begin_movetext(&mut self, tags: Self::Tags) -> ControlFlow { + self.headers = tags; + self.moves.clear(); + self.flat_legal_moves.clear(); + self.legal_moves_offsets.clear(); + self.valid_moves = true; + self.comments.clear(); + self.evals.clear(); + self.clock_times.clear(); + self.castling_rights.clear(); + self.board_states.clear(); + self.en_passant_states.clear(); + self.halfmove_clocks.clear(); + self.turn_states.clear(); + self.castling_states.clear(); + + // Determine castling mode from Variant header (case-insensitive) + let castling_mode = self + .headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("Variant")) + .and_then(|(_, v)| { + let v_lower = v.to_lowercase(); + if v_lower == "chess960" { + Some(CastlingMode::Chess960) + } else { + None + } + }) + .unwrap_or(CastlingMode::Standard); + + // Try to parse FEN from headers, fall back to default position + let fen_header = self + .headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("FEN")) + .map(|(_, v)| v.as_str()); + + if let Some(fen_str) = fen_header { + match fen_str.parse::() { + Ok(fen) => match fen.into_position(castling_mode) { + Ok(pos) => self.pos = pos, + Err(e) => { + eprintln!("invalid FEN position: {}", e); + self.pos = Chess::default(); + self.valid_moves = false; + } + }, + Err(e) => { + eprintln!("failed to parse FEN: {}", e); + self.pos = Chess::default(); + self.valid_moves = false; + } + } + } else { + self.pos = Chess::default(); + } + + self.push_castling_bitboards(); + if self.store_legal_moves { + self.push_legal_moves(); + } + // Record initial board state for flat output + self.push_board_state(); + ControlFlow::Continue(()) + } + + // Roughly half the time during parsing is spent here in san() + fn san( + &mut self, + _movetext: &mut Self::Movetext, + san_plus: SanPlus, + ) -> ControlFlow { + if self.valid_moves { + // Most of the function time is spent calculating to_move() + match san_plus.san.to_move(&self.pos) { + Ok(m) => { + self.pos.play_unchecked(m); + if self.store_legal_moves { + self.push_legal_moves(); + } + // Record board state after move for flat output + self.push_board_state(); + let uci_move_obj = UciMove::from_standard(m); + + match uci_move_obj { + UciMove::Normal { + from, + to, + promotion: promo_opt, + } => { + let py_uci_move = PyUciMove { + from_square: from as u8, + to_square: to as u8, + promotion: promo_opt.map(|p_role| p_role as u8), + }; + self.moves.push(py_uci_move); + self.push_castling_bitboards(); + + // Push placeholders to keep vectors in sync + self.comments.push(None); + self.evals.push(None); + self.clock_times.push(None); + } + _ => { + // This case handles UciMove::Put and UciMove::Null, + // which are not expected from standard PGN moves + // that PyUciMove is designed to represent. + eprintln!( + "Unexpected UCI move type from standard PGN move: {:?}. Game moves might be invalid.", + uci_move_obj + ); + self.valid_moves = false; + } + } + } + Err(err) => { + eprintln!("error in game: {} {}", err, san_plus); + self.valid_moves = false; + } + } + } + ControlFlow::Continue(()) + } + + fn comment( + &mut self, + _movetext: &mut Self::Movetext, + _comment: RawComment<'_>, + ) -> ControlFlow { + match parse_comments(_comment.as_bytes()) { + Ok((remaining_input, parsed_comments)) => { + if !remaining_input.is_empty() { + eprintln!("Unparsed remaining input: {:?}", remaining_input); + return ControlFlow::Continue(()); + } + + let mut move_comments = String::new(); + + for content in parsed_comments { + match content { + CommentContent::Text(text) => { + if !text.trim().is_empty() { + if !move_comments.is_empty() { + move_comments.push(' '); + } + move_comments.push_str(&text); + } + } + CommentContent::Tag(tag_content) => match tag_content { + ParsedTag::Eval(eval_value) => { + if let Some(last_eval) = self.evals.last_mut() { + *last_eval = Some(eval_value); + } + } + ParsedTag::Mate(mate_value) => { + if !move_comments.is_empty() && !move_comments.ends_with(' ') { + move_comments.push(' '); + } + move_comments.push_str(&format!("[Mate {}]", mate_value)); + } + ParsedTag::ClkTime { + hours, + minutes, + seconds, + } => { + if let Some(last_clk) = self.clock_times.last_mut() { + *last_clk = Some((hours, minutes, seconds)); + } + } + }, + } + } + + if let Some(last_comment) = self.comments.last_mut() { + *last_comment = Some(move_comments); + } + } + Err(e) => { + eprintln!("Error parsing comment: {:?}", e); + } + } + ControlFlow::Continue(()) + } + + fn begin_variation( + &mut self, + _movetext: &mut Self::Movetext, + ) -> ControlFlow { + ControlFlow::Continue(Skip(true)) // stay in the mainline + } + + fn outcome( + &mut self, + _movetext: &mut Self::Movetext, + _outcome: Outcome, + ) -> ControlFlow { + self.outcome = Some(match _outcome { + Outcome::Known(known) => match known { + KnownOutcome::Decisive { winner } => format!("{:?}", winner), + KnownOutcome::Draw => "Draw".to_string(), + }, + Outcome::Unknown => "Unknown".to_string(), + }); + self.update_position_status(); + ControlFlow::Continue(()) + } + + fn end_game(&mut self, _movetext: Self::Movetext) -> Self::Output { + self.valid_moves + } +} From 08cf9c29eb77231bf472f2add6e25e169bd58b41 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Thu, 5 Feb 2026 09:33:40 -0500 Subject: [PATCH 10/31] reduce copying --- src/flat_visitor.rs | 518 ++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 233 +++++++++----------- 2 files changed, 625 insertions(+), 126 deletions(-) create mode 100644 src/flat_visitor.rs diff --git a/src/flat_visitor.rs b/src/flat_visitor.rs new file mode 100644 index 0000000..983bba7 --- /dev/null +++ b/src/flat_visitor.rs @@ -0,0 +1,518 @@ +//! Flat buffer visitor for direct SoA output. +//! +//! This module provides a memory-efficient parsing approach that writes +//! directly to flat buffers instead of allocating per-game Vec structures. +//! Used by `parse_games_flat` for optimal performance. + +use crate::board_serialization::{ + get_castling_rights, get_en_passant_file, get_halfmove_clock, get_turn, serialize_board, +}; +use crate::comment_parsing::{parse_comments, CommentContent, ParsedTag}; +use pgn_reader::{Outcome, RawComment, RawTag, SanPlus, Skip, Visitor}; +use shakmaty::{fen::Fen, uci::UciMove, CastlingMode, Chess, Color, Position}; +use std::collections::HashMap; +use std::ops::ControlFlow; + +/// Accumulated flat buffers for multiple parsed games. +/// +/// This struct holds all data in a struct-of-arrays layout, optimized for: +/// - Efficient thread-local accumulation during parallel parsing +/// - Fast merging of thread-local buffers via `extend_from_slice` +/// - Direct conversion to NumPy arrays without intermediate allocations +#[derive(Default, Clone)] +pub struct FlatBuffers { + // Board state arrays (one entry per position) + pub boards: Vec, // Flattened: 64 bytes per position + pub castling: Vec, // Flattened: 4 bools per position [K,Q,k,q] + pub en_passant: Vec, // Per position: -1 or file 0-7 + pub halfmove_clock: Vec, // Per position + pub turn: Vec, // Per position: true=white + + // Move arrays (one entry per move) + pub from_squares: Vec, + pub to_squares: Vec, + pub promotions: Vec, // -1 for no promotion + pub clocks: Vec, // NaN for missing + pub evals: Vec, // NaN for missing + + // Per-game data + pub move_counts: Vec, // Number of moves per game + pub position_counts: Vec, // Number of positions per game + pub is_checkmate: Vec, + pub is_stalemate: Vec, + pub is_insufficient: Vec, // Flattened: 2 bools per game [white, black] + pub legal_move_count: Vec, + pub valid: Vec, + pub headers: Vec>, +} + +impl FlatBuffers { + /// Create a new FlatBuffers with pre-allocated capacity. + /// + /// # Arguments + /// * `estimated_games` - Expected number of games + /// * `moves_per_game` - Expected average moves per game (default: 70) + pub fn with_capacity(estimated_games: usize, moves_per_game: usize) -> Self { + let estimated_moves = estimated_games * moves_per_game; + let estimated_positions = estimated_moves + estimated_games; // +1 initial position per game + + FlatBuffers { + // Board state arrays + boards: Vec::with_capacity(estimated_positions * 64), + castling: Vec::with_capacity(estimated_positions * 4), + en_passant: Vec::with_capacity(estimated_positions), + halfmove_clock: Vec::with_capacity(estimated_positions), + turn: Vec::with_capacity(estimated_positions), + + // Move arrays + from_squares: Vec::with_capacity(estimated_moves), + to_squares: Vec::with_capacity(estimated_moves), + promotions: Vec::with_capacity(estimated_moves), + clocks: Vec::with_capacity(estimated_moves), + evals: Vec::with_capacity(estimated_moves), + + // Per-game data + move_counts: Vec::with_capacity(estimated_games), + position_counts: Vec::with_capacity(estimated_games), + is_checkmate: Vec::with_capacity(estimated_games), + is_stalemate: Vec::with_capacity(estimated_games), + is_insufficient: Vec::with_capacity(estimated_games * 2), + legal_move_count: Vec::with_capacity(estimated_games), + valid: Vec::with_capacity(estimated_games), + headers: Vec::with_capacity(estimated_games), + } + } + + /// Merge another FlatBuffers into this one. + /// Used to combine thread-local buffers after parallel parsing. + pub fn merge(&mut self, other: FlatBuffers) { + // Board state arrays + self.boards.extend(other.boards); + self.castling.extend(other.castling); + self.en_passant.extend(other.en_passant); + self.halfmove_clock.extend(other.halfmove_clock); + self.turn.extend(other.turn); + + // Move arrays + self.from_squares.extend(other.from_squares); + self.to_squares.extend(other.to_squares); + self.promotions.extend(other.promotions); + self.clocks.extend(other.clocks); + self.evals.extend(other.evals); + + // Per-game data + self.move_counts.extend(other.move_counts); + self.position_counts.extend(other.position_counts); + self.is_checkmate.extend(other.is_checkmate); + self.is_stalemate.extend(other.is_stalemate); + self.is_insufficient.extend(other.is_insufficient); + self.legal_move_count.extend(other.legal_move_count); + self.valid.extend(other.valid); + self.headers.extend(other.headers); + } + + /// Number of games in this buffer. + pub fn num_games(&self) -> usize { + self.headers.len() + } + + /// Total number of moves across all games. + #[allow(dead_code)] + pub fn total_moves(&self) -> usize { + self.from_squares.len() + } + + /// Total number of positions across all games. + pub fn total_positions(&self) -> usize { + self.boards.len() / 64 + } + + /// Compute CSR-style offsets from counts. + pub fn compute_move_offsets(&self) -> Vec { + let mut offsets = Vec::with_capacity(self.move_counts.len() + 1); + offsets.push(0); + for &count in &self.move_counts { + offsets.push(offsets.last().unwrap() + count); + } + offsets + } + + /// Compute CSR-style offsets from position counts. + pub fn compute_position_offsets(&self) -> Vec { + let mut offsets = Vec::with_capacity(self.position_counts.len() + 1); + offsets.push(0); + for &count in &self.position_counts { + offsets.push(offsets.last().unwrap() + count); + } + offsets + } +} + +/// Visitor that writes directly to FlatBuffers. +/// +/// This visitor does not allocate any per-game Vec structures. +/// All data is appended directly to the shared FlatBuffers. +pub struct FlatVisitor<'a> { + buffers: &'a mut FlatBuffers, + pos: Chess, + valid_moves: bool, + current_headers: Vec<(String, String)>, + // Track counts for current game + current_move_count: u32, + current_position_count: u32, +} + +impl<'a> FlatVisitor<'a> { + pub fn new(buffers: &'a mut FlatBuffers) -> Self { + FlatVisitor { + buffers, + pos: Chess::default(), + valid_moves: true, + current_headers: Vec::with_capacity(10), + current_move_count: 0, + current_position_count: 0, + } + } + + /// Record current board state to flat buffers. + fn push_board_state(&mut self) { + self.buffers + .boards + .extend_from_slice(&serialize_board(&self.pos)); + let castling = get_castling_rights(&self.pos); + self.buffers.castling.extend_from_slice(&castling); + self.buffers.en_passant.push(get_en_passant_file(&self.pos)); + self.buffers + .halfmove_clock + .push(get_halfmove_clock(&self.pos)); + self.buffers.turn.push(get_turn(&self.pos)); + self.current_position_count += 1; + } + + /// Record move data to flat buffers. + fn push_move(&mut self, from: u8, to: u8, promotion: Option) { + self.buffers.from_squares.push(from); + self.buffers.to_squares.push(to); + self.buffers + .promotions + .push(promotion.map(|p| p as i8).unwrap_or(-1)); + // Push placeholders for clock and eval (will be overwritten by comment()) + self.buffers.clocks.push(f32::NAN); + self.buffers.evals.push(f32::NAN); + self.current_move_count += 1; + } + + /// Record final position status. + fn update_position_status(&mut self) { + self.buffers.is_checkmate.push(self.pos.is_checkmate()); + self.buffers.is_stalemate.push(self.pos.is_stalemate()); + self.buffers + .is_insufficient + .push(self.pos.has_insufficient_material(Color::White)); + self.buffers + .is_insufficient + .push(self.pos.has_insufficient_material(Color::Black)); + self.buffers + .legal_move_count + .push(self.pos.legal_moves().len() as u16); + } + + /// Finalize current game - record per-game data. + fn finalize_game(&mut self) { + self.buffers.move_counts.push(self.current_move_count); + self.buffers + .position_counts + .push(self.current_position_count); + self.buffers.valid.push(self.valid_moves); + + // Convert headers to HashMap + let header_map: HashMap = self.current_headers.drain(..).collect(); + self.buffers.headers.push(header_map); + } +} + +impl Visitor for FlatVisitor<'_> { + type Tags = Vec<(String, String)>; + type Movetext = (); + type Output = bool; + + fn begin_tags(&mut self) -> ControlFlow { + self.current_headers.clear(); + ControlFlow::Continue(Vec::with_capacity(10)) + } + + fn tag( + &mut self, + tags: &mut Self::Tags, + key: &[u8], + value: RawTag<'_>, + ) -> ControlFlow { + let key_str = String::from_utf8_lossy(key).into_owned(); + let value_str = String::from_utf8_lossy(value.as_bytes()).into_owned(); + tags.push((key_str, value_str)); + ControlFlow::Continue(()) + } + + fn begin_movetext(&mut self, tags: Self::Tags) -> ControlFlow { + self.current_headers = tags; + self.valid_moves = true; + self.current_move_count = 0; + self.current_position_count = 0; + + // Determine castling mode from Variant header (case-insensitive) + let castling_mode = self + .current_headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("Variant")) + .and_then(|(_, v)| { + let v_lower = v.to_lowercase(); + if v_lower == "chess960" { + Some(CastlingMode::Chess960) + } else { + None + } + }) + .unwrap_or(CastlingMode::Standard); + + // Try to parse FEN from headers, fall back to default position + let fen_header = self + .current_headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("FEN")) + .map(|(_, v)| v.as_str()); + + if let Some(fen_str) = fen_header { + match fen_str.parse::() { + Ok(fen) => match fen.into_position(castling_mode) { + Ok(pos) => self.pos = pos, + Err(e) => { + eprintln!("invalid FEN position: {}", e); + self.pos = Chess::default(); + self.valid_moves = false; + } + }, + Err(e) => { + eprintln!("failed to parse FEN: {}", e); + self.pos = Chess::default(); + self.valid_moves = false; + } + } + } else { + self.pos = Chess::default(); + } + + // Record initial board state + self.push_board_state(); + ControlFlow::Continue(()) + } + + fn san( + &mut self, + _movetext: &mut Self::Movetext, + san_plus: SanPlus, + ) -> ControlFlow { + if self.valid_moves { + match san_plus.san.to_move(&self.pos) { + Ok(m) => { + self.pos.play_unchecked(m); + + // Record board state after move + self.push_board_state(); + + let uci_move_obj = UciMove::from_standard(m); + match uci_move_obj { + UciMove::Normal { + from, + to, + promotion, + } => { + self.push_move(from as u8, to as u8, promotion.map(|p| p as u8)); + } + _ => { + eprintln!( + "Unexpected UCI move type: {:?}. Game moves might be invalid.", + uci_move_obj + ); + self.valid_moves = false; + } + } + } + Err(err) => { + eprintln!("error in game: {} {}", err, san_plus); + self.valid_moves = false; + } + } + } + ControlFlow::Continue(()) + } + + fn comment( + &mut self, + _movetext: &mut Self::Movetext, + comment: RawComment<'_>, + ) -> ControlFlow { + if let Ok((_, parsed_comments)) = parse_comments(comment.as_bytes()) { + for content in parsed_comments { + match content { + CommentContent::Tag(tag_content) => match tag_content { + ParsedTag::Eval(eval_value) => { + // Update the last eval entry + if let Some(last_eval) = self.buffers.evals.last_mut() { + *last_eval = eval_value as f32; + } + } + ParsedTag::ClkTime { + hours, + minutes, + seconds, + } => { + // Convert to seconds and update the last clock entry + if let Some(last_clk) = self.buffers.clocks.last_mut() { + *last_clk = + hours as f32 * 3600.0 + minutes as f32 * 60.0 + seconds as f32; + } + } + ParsedTag::Mate(_) => { + // Mate scores are handled as comments, not numeric evals + } + }, + CommentContent::Text(_) => { + // Text comments are not stored in flat output + } + } + } + } + ControlFlow::Continue(()) + } + + fn begin_variation( + &mut self, + _movetext: &mut Self::Movetext, + ) -> ControlFlow { + ControlFlow::Continue(Skip(true)) // Skip variations, stay in mainline + } + + fn outcome( + &mut self, + _movetext: &mut Self::Movetext, + _outcome: Outcome, + ) -> ControlFlow { + self.update_position_status(); + ControlFlow::Continue(()) + } + + fn end_game(&mut self, _movetext: Self::Movetext) -> Self::Output { + // Handle case where outcome() was not called (e.g., incomplete game) + if self.buffers.is_checkmate.len() < self.buffers.headers.len() + 1 { + self.update_position_status(); + } + self.finalize_game(); + self.valid_moves + } +} + +/// Parse a single game directly into FlatBuffers. +pub fn parse_game_to_flat(pgn: &str, buffers: &mut FlatBuffers) -> Result { + use pgn_reader::Reader; + use std::io::Cursor; + + let mut reader = Reader::new(Cursor::new(pgn)); + let mut visitor = FlatVisitor::new(buffers); + + match reader.read_game(&mut visitor) { + Ok(Some(valid)) => Ok(valid), + Ok(None) => Err("No game found in PGN".to_string()), + Err(err) => Err(format!("Parsing error: {}", err)), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_simple_game() { + let pgn = r#"[Event "Test"] +[White "Player1"] +[Black "Player2"] +[Result "1-0"] + +1. e4 e5 2. Nf3 Nc6 1-0"#; + + let mut buffers = FlatBuffers::with_capacity(1, 70); + let result = parse_game_to_flat(pgn, &mut buffers); + + assert!(result.is_ok()); + assert!(result.unwrap()); // valid game + assert_eq!(buffers.num_games(), 1); + assert_eq!(buffers.move_counts[0], 4); // 4 moves + assert_eq!(buffers.position_counts[0], 5); // 5 positions (initial + 4 moves) + assert_eq!(buffers.total_moves(), 4); + assert_eq!(buffers.total_positions(), 5); + } + + #[test] + fn test_parse_game_with_annotations() { + let pgn = r#"[Event "Test"] +[Result "1-0"] + +1. e4 { [%eval 0.17] [%clk 0:03:00] } 1... e5 { [%eval 0.19] [%clk 0:02:58] } 1-0"#; + + let mut buffers = FlatBuffers::with_capacity(1, 70); + let result = parse_game_to_flat(pgn, &mut buffers); + + assert!(result.is_ok()); + assert_eq!(buffers.total_moves(), 2); + + // Check that evals were parsed + assert!(!buffers.evals[0].is_nan()); + assert!((buffers.evals[0] - 0.17).abs() < 0.01); + assert!(!buffers.evals[1].is_nan()); + assert!((buffers.evals[1] - 0.19).abs() < 0.01); + + // Check that clocks were parsed (3 minutes = 180 seconds) + assert!(!buffers.clocks[0].is_nan()); + assert!((buffers.clocks[0] - 180.0).abs() < 0.01); + } + + #[test] + fn test_merge_buffers() { + let pgn1 = r#"[Event "Game1"] +[Result "1-0"] + +1. e4 e5 1-0"#; + + let pgn2 = r#"[Event "Game2"] +[Result "0-1"] + +1. d4 d5 2. c4 0-1"#; + + let mut buffers1 = FlatBuffers::with_capacity(1, 70); + let mut buffers2 = FlatBuffers::with_capacity(1, 70); + + parse_game_to_flat(pgn1, &mut buffers1).unwrap(); + parse_game_to_flat(pgn2, &mut buffers2).unwrap(); + + assert_eq!(buffers1.num_games(), 1); + assert_eq!(buffers2.num_games(), 1); + + buffers1.merge(buffers2); + + assert_eq!(buffers1.num_games(), 2); + assert_eq!(buffers1.total_moves(), 5); // 2 + 3 moves + assert_eq!(buffers1.move_counts, vec![2, 3]); + } + + #[test] + fn test_compute_offsets() { + let mut buffers = FlatBuffers::default(); + buffers.move_counts = vec![4, 6, 3]; + buffers.position_counts = vec![5, 7, 4]; + + let move_offsets = buffers.compute_move_offsets(); + assert_eq!(move_offsets, vec![0, 4, 10, 13]); + + let pos_offsets = buffers.compute_position_offsets(); + assert_eq!(pos_offsets, vec![0, 5, 12, 16]); + } +} diff --git a/src/lib.rs b/src/lib.rs index 94e0b46..c9cfbb4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,17 +5,24 @@ use pyo3::prelude::*; use pyo3_arrow::PyChunkedArray; use rayon::prelude::*; use rayon::ThreadPoolBuilder; -use std::collections::HashMap; + use std::io::Cursor; mod board_serialization; mod comment_parsing; +mod flat_visitor; mod python_bindings; mod visitor; +use flat_visitor::{parse_game_to_flat, FlatBuffers}; use python_bindings::{ParsedGames, ParsedGamesIter, PositionStatus, PyGameView, PyUciMove}; use visitor::MoveExtractor; +/// Parse games from Arrow chunked array into flat NumPy arrays. +/// +/// This implementation uses thread-local FlatBuffers that are merged after +/// parallel parsing, avoiding the overhead of per-game Vec allocations +/// followed by a sequential copy loop. #[pyfunction] #[pyo3(signature = (pgn_chunked_array, num_threads=None))] fn parse_games_flat( @@ -23,158 +30,132 @@ fn parse_games_flat( pgn_chunked_array: PyChunkedArray, num_threads: Option, ) -> PyResult { - // 1. Parse all games using existing logic - let extractors = _parse_game_moves_from_arrow_chunks_native( - &pgn_chunked_array, - num_threads, - false, // store_legal_moves = false for performance - ) - .map_err(|e| PyErr::new::(e))?; - - let n_games = extractors.len(); - - // 2. Compute move counts and position counts from actual recorded data - let move_counts: Vec = extractors.iter().map(|e| e.moves.len() as u32).collect(); - - // Position counts derived from actual board_states data - let position_counts: Vec = extractors - .iter() - .map(|e| (e.board_states.len() / 64) as u32) - .collect(); - - // Build move offsets - let mut move_offsets_vec: Vec = Vec::with_capacity(n_games + 1); - move_offsets_vec.push(0); - for &count in &move_counts { - move_offsets_vec.push(move_offsets_vec.last().unwrap() + count); - } + let num_threads = num_threads.unwrap_or_else(num_cpus::get); - // Build position offsets - let mut position_offsets_vec: Vec = Vec::with_capacity(n_games + 1); - position_offsets_vec.push(0); - for &count in &position_counts { - position_offsets_vec.push(position_offsets_vec.last().unwrap() + count); + // Extract PGN strings from Arrow chunks + let mut num_elements = 0; + for chunk in pgn_chunked_array.chunks() { + num_elements += chunk.len(); } - let total_moves = *move_offsets_vec.last().unwrap() as usize; - let total_positions = *position_offsets_vec.last().unwrap() as usize; - - // 3. Pre-allocate flat vectors - let mut boards_vec: Vec = Vec::with_capacity(total_positions * 64); - let mut castling_vec: Vec = Vec::with_capacity(total_positions * 4); - let mut en_passant_vec: Vec = Vec::with_capacity(total_positions); - let mut halfmove_clock_vec: Vec = Vec::with_capacity(total_positions); - let mut turn_vec: Vec = Vec::with_capacity(total_positions); - - let mut from_squares_vec: Vec = Vec::with_capacity(total_moves); - let mut to_squares_vec: Vec = Vec::with_capacity(total_moves); - let mut promotions_vec: Vec = Vec::with_capacity(total_moves); - let mut clocks_vec: Vec = Vec::with_capacity(total_moves); - let mut evals_vec: Vec = Vec::with_capacity(total_moves); - - let mut is_checkmate_vec: Vec = Vec::with_capacity(n_games); - let mut is_stalemate_vec: Vec = Vec::with_capacity(n_games); - let mut is_insufficient_vec: Vec = Vec::with_capacity(n_games * 2); - let mut legal_move_count_vec: Vec = Vec::with_capacity(n_games); - let mut valid_vec: Vec = Vec::with_capacity(n_games); - let mut headers_vec: Vec> = Vec::with_capacity(n_games); - - // 4. Copy data from each extractor - for extractor in &extractors { - // Board states - boards_vec.extend_from_slice(&extractor.board_states); - castling_vec.extend(extractor.castling_states.iter().copied()); - en_passant_vec.extend_from_slice(&extractor.en_passant_states); - halfmove_clock_vec.extend_from_slice(&extractor.halfmove_clocks); - turn_vec.extend(extractor.turn_states.iter().copied()); - - // Moves - for m in &extractor.moves { - from_squares_vec.push(m.from_square); - to_squares_vec.push(m.to_square); - promotions_vec.push(m.promotion.map(|p| p as i8).unwrap_or(-1)); + let mut pgn_str_slices: Vec<&str> = Vec::with_capacity(num_elements); + for chunk in pgn_chunked_array.chunks() { + if let Some(string_array) = chunk.as_any().downcast_ref::() { + for i in 0..string_array.len() { + if string_array.is_valid(i) { + pgn_str_slices.push(string_array.value(i)); + } + } + } else if let Some(large_string_array) = chunk.as_any().downcast_ref::() { + for i in 0..large_string_array.len() { + if large_string_array.is_valid(i) { + pgn_str_slices.push(large_string_array.value(i)); + } + } + } else { + return Err(PyErr::new::(format!( + "Unsupported array type in ChunkedArray: {:?}", + chunk.data_type() + ))); } + } - // Clocks (convert to seconds) - for clock in &extractor.clock_times { - clocks_vec.push( - clock - .map(|(h, m, s)| h as f32 * 3600.0 + m as f32 * 60.0 + s as f32) - .unwrap_or(f32::NAN), - ); - } + let n_games = pgn_str_slices.len(); - // Evals (convert mate values to large numbers) - for eval in &extractor.evals { - evals_vec.push(eval.map(|e| e as f32).unwrap_or(f32::NAN)); - } + // Build thread pool + let thread_pool = ThreadPoolBuilder::new() + .num_threads(num_threads) + .build() + .map_err(|e| { + PyErr::new::(format!( + "Failed to build thread pool: {}", + e + )) + })?; + + // Estimate capacity: ~70 moves per game + let games_per_thread = (n_games + num_threads - 1) / num_threads; + let moves_per_game = 70; + + // Parse in parallel using fold_with for thread-local buffer accumulation. + // fold_with gives each worker thread its own FlatBuffers instance to accumulate into. + // This is more efficient than fold() which creates buffers per work-stealing chunk. + let thread_results: Vec = thread_pool.install(|| { + pgn_str_slices + .par_iter() + .fold_with( + FlatBuffers::with_capacity(games_per_thread, moves_per_game), + |mut buffers, &pgn| { + let _ = parse_game_to_flat(pgn, &mut buffers); + buffers + }, + ) + .collect() + }); - // Final position status - if let Some(ref status) = extractor.position_status { - is_checkmate_vec.push(status.is_checkmate); - is_stalemate_vec.push(status.is_stalemate); - is_insufficient_vec.push(status.insufficient_material.0); - is_insufficient_vec.push(status.insufficient_material.1); - legal_move_count_vec.push(status.legal_move_count as u16); - } else { - // No status computed - use defaults - is_checkmate_vec.push(false); - is_stalemate_vec.push(false); - is_insufficient_vec.push(false); - is_insufficient_vec.push(false); - legal_move_count_vec.push(0); - } + // Merge all thread-local buffers (this is O(num_threads) not O(num_games)) + let mut combined_buffers = FlatBuffers::default(); + for buf in thread_results { + combined_buffers.merge(buf); + } - // Valid flag - valid_vec.push(extractor.valid_moves); + // Convert FlatBuffers to ParsedGames with NumPy arrays + flat_buffers_to_parsed_games(py, combined_buffers) +} - // Headers as HashMap - let header_map: HashMap = extractor - .headers - .iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect(); - headers_vec.push(header_map); - } +/// Convert FlatBuffers to ParsedGames with NumPy arrays. +fn flat_buffers_to_parsed_games(py: Python<'_>, buffers: FlatBuffers) -> PyResult { + let n_games = buffers.num_games(); + let total_positions = buffers.total_positions(); + + // Compute offsets + let move_offsets_vec = buffers.compute_move_offsets(); + let position_offsets_vec = buffers.compute_position_offsets(); - // 5. Convert to numpy arrays + // Convert to NumPy arrays // Boards: reshape from flat to (N_positions, 8, 8) - let boards_array = PyArray1::from_vec(py, boards_vec); + let boards_array = PyArray1::from_vec(py, buffers.boards); let boards_reshaped = boards_array .reshape([total_positions, 8, 8]) .map_err(|e| PyErr::new::(e.to_string()))?; // Castling: reshape from flat to (N_positions, 4) - let castling_array = PyArray1::from_vec(py, castling_vec); + let castling_array = PyArray1::from_vec(py, buffers.castling); let castling_reshaped = castling_array .reshape([total_positions, 4]) .map_err(|e| PyErr::new::(e.to_string()))?; // 1D arrays - let en_passant_array = PyArray1::from_vec(py, en_passant_vec); - let halfmove_clock_array = PyArray1::from_vec(py, halfmove_clock_vec); - let turn_array = PyArray1::from_vec(py, turn_vec); + let en_passant_array = PyArray1::from_vec(py, buffers.en_passant); + let halfmove_clock_array = PyArray1::from_vec(py, buffers.halfmove_clock); + let turn_array = PyArray1::from_vec(py, buffers.turn); - let from_squares_array = PyArray1::from_vec(py, from_squares_vec); - let to_squares_array = PyArray1::from_vec(py, to_squares_vec); - let promotions_array = PyArray1::from_vec(py, promotions_vec); - let clocks_array = PyArray1::from_vec(py, clocks_vec); - let evals_array = PyArray1::from_vec(py, evals_vec); + let from_squares_array = PyArray1::from_vec(py, buffers.from_squares); + let to_squares_array = PyArray1::from_vec(py, buffers.to_squares); + let promotions_array = PyArray1::from_vec(py, buffers.promotions); + let clocks_array = PyArray1::from_vec(py, buffers.clocks); + let evals_array = PyArray1::from_vec(py, buffers.evals); let move_offsets_array = PyArray1::from_vec(py, move_offsets_vec); let position_offsets_array = PyArray1::from_vec(py, position_offsets_vec); - let is_checkmate_array = PyArray1::from_vec(py, is_checkmate_vec); - let is_stalemate_array = PyArray1::from_vec(py, is_stalemate_vec); + let is_checkmate_array = PyArray1::from_vec(py, buffers.is_checkmate); + let is_stalemate_array = PyArray1::from_vec(py, buffers.is_stalemate); // is_insufficient: reshape to (N_games, 2) - let is_insufficient_array = PyArray1::from_vec(py, is_insufficient_vec); - let is_insufficient_reshaped = is_insufficient_array - .reshape([n_games, 2]) - .map_err(|e| PyErr::new::(e.to_string()))?; - - let legal_move_count_array = PyArray1::from_vec(py, legal_move_count_vec); - let valid_array = PyArray1::from_vec(py, valid_vec); + let is_insufficient_array = PyArray1::from_vec(py, buffers.is_insufficient); + let is_insufficient_reshaped = if n_games > 0 { + is_insufficient_array + .reshape([n_games, 2]) + .map_err(|e| PyErr::new::(e.to_string()))? + } else { + is_insufficient_array + .reshape([0, 2]) + .map_err(|e| PyErr::new::(e.to_string()))? + }; + + let legal_move_count_array = PyArray1::from_vec(py, buffers.legal_move_count); + let valid_array = PyArray1::from_vec(py, buffers.valid); Ok(ParsedGames { boards: boards_reshaped.unbind().into_any(), @@ -194,7 +175,7 @@ fn parse_games_flat( is_insufficient: is_insufficient_reshaped.unbind().into_any(), legal_move_count: legal_move_count_array.unbind().into_any(), valid: valid_array.unbind().into_any(), - headers: headers_vec, + headers: buffers.headers, }) } From 7f0d8a98d592e539ee72faf806f4c22ea1354c29 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Thu, 5 Feb 2026 10:07:51 -0500 Subject: [PATCH 11/31] profiling, fancier merge --- benches/parquet_bench.rs | 89 +++++++++++++++++++++++++++++++++++++++- src/flat_visitor.rs | 74 +++++++++++++++++++++++++++++++++ src/lib.rs | 9 ++-- 3 files changed, 165 insertions(+), 7 deletions(-) diff --git a/benches/parquet_bench.rs b/benches/parquet_bench.rs index fb69596..ec10732 100644 --- a/benches/parquet_bench.rs +++ b/benches/parquet_bench.rs @@ -1,10 +1,14 @@ use arrow::array::{Array, StringArray}; use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; +use rayon::prelude::*; +use rayon::ThreadPoolBuilder; use std::fs::File; use std::path::Path; use std::time::Instant; -use rust_pgn_reader_python_binding::parse_multiple_games_native; +use rust_pgn_reader_python_binding::{ + parse_game_to_flat, parse_multiple_games_native, FlatBuffers, +}; pub fn bench_parquet() { let file_path = "2013-07-train-00000-of-00001.parquet"; @@ -58,6 +62,89 @@ pub fn bench_parquet() { println!("Time after checks: {:?}", duration2); } +/// Read movetexts from a parquet file. +fn read_movetexts_from_parquet(file_path: &str) -> Vec { + let file = File::open(Path::new(file_path)).expect("Unable to open file"); + let builder = ParquetRecordBatchReaderBuilder::try_new(file) + .expect("Failed to create ParquetRecordBatchReaderBuilder"); + let mut reader = builder + .build() + .expect("Failed to build ParquetRecordBatchReader"); + + let mut movetexts = Vec::new(); + while let Some(batch) = reader + .next() + .transpose() + .expect("Error reading record batch") + { + if let Some(array) = batch + .column_by_name("movetext") + .and_then(|col| col.as_any().downcast_ref::()) + { + for i in 0..array.len() { + if array.is_valid(i) { + movetexts.push(array.value(i).to_string()); + } + } + } else { + panic!("movetext column not found or not a StringArray"); + } + } + movetexts +} + +/// Benchmark using the flat API (FlatBuffers + parallel fold_with). +pub fn bench_parquet_flat() { + let file_path = "2013-07-train-00000-of-00001.parquet"; + + let movetexts = read_movetexts_from_parquet(file_path); + println!("Read {} rows.", movetexts.len()); + + let num_threads = num_cpus::get(); + let n_games = movetexts.len(); + let games_per_thread = (n_games + num_threads - 1) / num_threads; + let moves_per_game = 70; + + let thread_pool = ThreadPoolBuilder::new() + .num_threads(num_threads) + .build() + .expect("Failed to build Rayon thread pool"); + + let start = Instant::now(); + + // Parse in parallel using fold_with for thread-local buffer accumulation + let thread_results: Vec = thread_pool.install(|| { + movetexts + .par_iter() + .fold_with( + FlatBuffers::with_capacity(games_per_thread, moves_per_game), + |mut buffers, pgn| { + let _ = parse_game_to_flat(pgn, &mut buffers); + buffers + }, + ) + .collect() + }); + + let duration_parallel = start.elapsed(); + println!("Parallel parsing time: {:?}", duration_parallel); + + // Merge all thread-local buffers with pre-allocation + let combined_buffers = FlatBuffers::merge_all(thread_results); + + let duration_total = start.elapsed(); + println!("Total time (including merge): {:?}", duration_total); + println!( + "Parsed {} games, {} total positions.", + combined_buffers.num_games(), + combined_buffers.total_positions() + ); +} + fn main() { + println!("=== MoveExtractor API ===\n"); bench_parquet(); + + println!("\n=== Flat API ===\n"); + bench_parquet_flat(); } diff --git a/src/flat_visitor.rs b/src/flat_visitor.rs index 983bba7..0d14027 100644 --- a/src/flat_visitor.rs +++ b/src/flat_visitor.rs @@ -85,6 +85,9 @@ impl FlatBuffers { /// Merge another FlatBuffers into this one. /// Used to combine thread-local buffers after parallel parsing. + /// + /// Note: For merging multiple buffers, prefer `merge_all` which pre-allocates + /// to avoid repeated reallocations. pub fn merge(&mut self, other: FlatBuffers) { // Board state arrays self.boards.extend(other.boards); @@ -111,6 +114,77 @@ impl FlatBuffers { self.headers.extend(other.headers); } + /// Merge multiple FlatBuffers efficiently by pre-allocating total capacity. + /// + /// This avoids the repeated reallocations that occur when calling `merge` + /// in a loop starting from an empty buffer. + pub fn merge_all(buffers: Vec) -> FlatBuffers { + if buffers.is_empty() { + return FlatBuffers::default(); + } + if buffers.len() == 1 { + return buffers.into_iter().next().unwrap(); + } + + // Calculate total sizes + let total_games: usize = buffers.iter().map(|b| b.headers.len()).sum(); + let total_positions: usize = buffers.iter().map(|b| b.boards.len() / 64).sum(); + let total_moves: usize = buffers.iter().map(|b| b.from_squares.len()).sum(); + + // Pre-allocate with exact capacity + let mut combined = FlatBuffers { + // Board state arrays + boards: Vec::with_capacity(total_positions * 64), + castling: Vec::with_capacity(total_positions * 4), + en_passant: Vec::with_capacity(total_positions), + halfmove_clock: Vec::with_capacity(total_positions), + turn: Vec::with_capacity(total_positions), + + // Move arrays + from_squares: Vec::with_capacity(total_moves), + to_squares: Vec::with_capacity(total_moves), + promotions: Vec::with_capacity(total_moves), + clocks: Vec::with_capacity(total_moves), + evals: Vec::with_capacity(total_moves), + + // Per-game data + move_counts: Vec::with_capacity(total_games), + position_counts: Vec::with_capacity(total_games), + is_checkmate: Vec::with_capacity(total_games), + is_stalemate: Vec::with_capacity(total_games), + is_insufficient: Vec::with_capacity(total_games * 2), + legal_move_count: Vec::with_capacity(total_games), + valid: Vec::with_capacity(total_games), + headers: Vec::with_capacity(total_games), + }; + + // Now merge - no reallocations will occur + for buf in buffers { + combined.boards.extend(buf.boards); + combined.castling.extend(buf.castling); + combined.en_passant.extend(buf.en_passant); + combined.halfmove_clock.extend(buf.halfmove_clock); + combined.turn.extend(buf.turn); + + combined.from_squares.extend(buf.from_squares); + combined.to_squares.extend(buf.to_squares); + combined.promotions.extend(buf.promotions); + combined.clocks.extend(buf.clocks); + combined.evals.extend(buf.evals); + + combined.move_counts.extend(buf.move_counts); + combined.position_counts.extend(buf.position_counts); + combined.is_checkmate.extend(buf.is_checkmate); + combined.is_stalemate.extend(buf.is_stalemate); + combined.is_insufficient.extend(buf.is_insufficient); + combined.legal_move_count.extend(buf.legal_move_count); + combined.valid.extend(buf.valid); + combined.headers.extend(buf.headers); + } + + combined + } + /// Number of games in this buffer. pub fn num_games(&self) -> usize { self.headers.len() diff --git a/src/lib.rs b/src/lib.rs index c9cfbb4..c67b459 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,7 @@ mod flat_visitor; mod python_bindings; mod visitor; -use flat_visitor::{parse_game_to_flat, FlatBuffers}; +pub use flat_visitor::{parse_game_to_flat, FlatBuffers}; use python_bindings::{ParsedGames, ParsedGamesIter, PositionStatus, PyGameView, PyUciMove}; use visitor::MoveExtractor; @@ -93,11 +93,8 @@ fn parse_games_flat( .collect() }); - // Merge all thread-local buffers (this is O(num_threads) not O(num_games)) - let mut combined_buffers = FlatBuffers::default(); - for buf in thread_results { - combined_buffers.merge(buf); - } + // Merge all thread-local buffers with pre-allocation (avoids repeated reallocations) + let combined_buffers = FlatBuffers::merge_all(thread_results); // Convert FlatBuffers to ParsedGames with NumPy arrays flat_buffers_to_parsed_games(py, combined_buffers) From fecc90b4c190b7818e07f4ef7c7052f57d45a7c4 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Thu, 5 Feb 2026 20:11:38 -0500 Subject: [PATCH 12/31] Tweak benchmark to better match Python code --- Cargo.toml | 8 +- benches/parquet_bench.rs | 173 ++++++++++++++++++++++----------------- src/lib.rs | 2 +- 3 files changed, 107 insertions(+), 76 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index aa9257b..4cd7471 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,15 +19,19 @@ num_cpus = "1.17" arrow-array = "57" pyo3-arrow = "0.15" numpy = "0.27" +parquet = "57" +arrow = "57" [dev-dependencies] criterion = "0.8" -parquet = "57" -arrow = "57" [[bench]] harness = false name = "parquet_bench" +[[bin]] +name = "parquet_bench" +path = "benches/parquet_bench.rs" + [profile.bench] debug = true diff --git a/benches/parquet_bench.rs b/benches/parquet_bench.rs index ec10732..4856818 100644 --- a/benches/parquet_bench.rs +++ b/benches/parquet_bench.rs @@ -1,3 +1,11 @@ +//! Benchmark for PGN parsing APIs, designed to mirror the Python workflow. +//! +//! This benchmark emulates the call graph of: +//! - `parse_game_moves_arrow_chunked_array()` (Arrow API → Vec) +//! - `parse_games_flat()` (Flat API → FlatBuffers with NumPy-like arrays) +//! +//! Both use zero-copy &str slices from Arrow StringArrays, matching Python's behavior. + use arrow::array::{Array, StringArray}; use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; use rayon::prelude::*; @@ -6,14 +14,13 @@ use std::fs::File; use std::path::Path; use std::time::Instant; -use rust_pgn_reader_python_binding::{ - parse_game_to_flat, parse_multiple_games_native, FlatBuffers, -}; +use rust_pgn_reader_python_binding::{parse_game_to_flat, parse_single_game_native, FlatBuffers}; -pub fn bench_parquet() { - let file_path = "2013-07-train-00000-of-00001.parquet"; +const FILE_PATH: &str = "2013-07-train-00000-of-00001.parquet"; - // Open the Parquet file +/// Read parquet file and return the raw Arrow StringArrays. +/// This preserves Arrow's memory layout for zero-copy string access. +fn read_parquet_to_string_arrays(file_path: &str) -> Vec { let file = File::open(Path::new(file_path)).expect("Unable to open file"); let builder = ParquetRecordBatchReaderBuilder::try_new(file) .expect("Failed to create ParquetRecordBatchReaderBuilder"); @@ -21,104 +28,124 @@ pub fn bench_parquet() { .build() .expect("Failed to build ParquetRecordBatchReader"); - // Process record batches - let mut movetexts = Vec::new(); + let mut arrays = Vec::new(); while let Some(batch) = reader .next() .transpose() .expect("Error reading record batch") { - // Extract "movetext" column from the record batch if let Some(array) = batch .column_by_name("movetext") .and_then(|col| col.as_any().downcast_ref::()) { - for i in 0..array.len() { - if array.is_valid(i) { - movetexts.push(array.value(i).to_string()); - } - } + // Clone the StringArray to own it (Arrow uses Arc internally, so this is cheap) + arrays.push(array.clone()); } else { panic!("movetext column not found or not a StringArray"); } } - - println!("Read {} rows.", movetexts.len()); - // Measure start time - let start = Instant::now(); - - let result = parse_multiple_games_native(&movetexts, None, false); - - let duration = start.elapsed(); - println!("Time taken: {:?}", duration); - - match result { - Ok(parsed) => println!("Parsed {} games.", parsed.len()), - Err(err) => eprintln!("Error parsing games: {}", err), - } - - let duration2 = start.elapsed(); - - println!("Time after checks: {:?}", duration2); + arrays } -/// Read movetexts from a parquet file. -fn read_movetexts_from_parquet(file_path: &str) -> Vec { - let file = File::open(Path::new(file_path)).expect("Unable to open file"); - let builder = ParquetRecordBatchReaderBuilder::try_new(file) - .expect("Failed to create ParquetRecordBatchReaderBuilder"); - let mut reader = builder - .build() - .expect("Failed to build ParquetRecordBatchReader"); - - let mut movetexts = Vec::new(); - while let Some(batch) = reader - .next() - .transpose() - .expect("Error reading record batch") - { - if let Some(array) = batch - .column_by_name("movetext") - .and_then(|col| col.as_any().downcast_ref::()) - { - for i in 0..array.len() { - if array.is_valid(i) { - movetexts.push(array.value(i).to_string()); - } +/// Extract &str slices from Arrow StringArrays (zero-copy). +/// This mirrors the extraction logic in `_parse_game_moves_from_arrow_chunks_native` +/// and `parse_games_flat` in lib.rs. +fn extract_str_slices<'a>(arrays: &'a [StringArray]) -> Vec<&'a str> { + let total_len: usize = arrays.iter().map(|a| a.len()).sum(); + let mut slices = Vec::with_capacity(total_len); + + for array in arrays { + for i in 0..array.len() { + if array.is_valid(i) { + slices.push(array.value(i)); } - } else { - panic!("movetext column not found or not a StringArray"); } } - movetexts + slices } -/// Benchmark using the flat API (FlatBuffers + parallel fold_with). -pub fn bench_parquet_flat() { - let file_path = "2013-07-train-00000-of-00001.parquet"; +/// Benchmark the Arrow API workflow. +/// +/// This mirrors `parse_game_moves_arrow_chunked_array()` from Python: +/// 1. Read parquet to Arrow arrays +/// 2. Extract &str slices from StringArray (like the Python-bound function does) +/// 3. Parse each game in parallel → Vec +pub fn bench_arrow_api() { + // Step 1: Read parquet to Arrow StringArrays + let arrays = read_parquet_to_string_arrays(FILE_PATH); + + // Step 2: Extract &str slices (zero-copy, mirrors Arrow chunk iteration) + let pgn_slices = extract_str_slices(&arrays); + println!("Read {} games from parquet.", pgn_slices.len()); + + // Step 3: Build thread pool (same pattern as lib.rs) + let num_threads = num_cpus::get(); + let thread_pool = ThreadPoolBuilder::new() + .num_threads(num_threads) + .build() + .expect("Failed to build Rayon thread pool"); + + // Step 4: Parse in parallel → Vec + // This mirrors _parse_game_moves_from_arrow_chunks_native + let start = Instant::now(); + + let extractors: Vec<_> = thread_pool + .install(|| { + pgn_slices + .par_iter() + .map(|&pgn| parse_single_game_native(pgn, false)) + .collect::, _>>() + }) + .expect("Parsing failed"); - let movetexts = read_movetexts_from_parquet(file_path); - println!("Read {} rows.", movetexts.len()); + let duration = start.elapsed(); + + // Report results + let total_moves: usize = extractors.iter().map(|e| e.moves.len()).sum(); + println!("Parsing time: {:?}", duration); + println!( + "Parsed {} games, {} total moves.", + extractors.len(), + total_moves + ); +} +/// Benchmark the Flat API workflow. +/// +/// This mirrors `parse_games_flat()` from Python: +/// 1. Read parquet to Arrow arrays +/// 2. Extract &str slices from StringArray +/// 3. Parse in parallel with fold_with → thread-local FlatBuffers +/// 4. Merge all FlatBuffers into one +pub fn bench_flat_api() { + // Step 1: Read parquet to Arrow StringArrays + let arrays = read_parquet_to_string_arrays(FILE_PATH); + + // Step 2: Extract &str slices (zero-copy) + let pgn_slices = extract_str_slices(&arrays); + println!("Read {} games from parquet.", pgn_slices.len()); + + // Step 3: Build thread pool and compute capacity estimates let num_threads = num_cpus::get(); - let n_games = movetexts.len(); + let n_games = pgn_slices.len(); let games_per_thread = (n_games + num_threads - 1) / num_threads; - let moves_per_game = 70; + let moves_per_game = 70; // Estimate ~70 moves per game let thread_pool = ThreadPoolBuilder::new() .num_threads(num_threads) .build() .expect("Failed to build Rayon thread pool"); + // Step 4: Parse in parallel using fold_with for thread-local buffer accumulation + // This is exactly the pattern used in parse_games_flat() in lib.rs let start = Instant::now(); - // Parse in parallel using fold_with for thread-local buffer accumulation let thread_results: Vec = thread_pool.install(|| { - movetexts + pgn_slices .par_iter() .fold_with( FlatBuffers::with_capacity(games_per_thread, moves_per_game), - |mut buffers, pgn| { + |mut buffers, &pgn| { let _ = parse_game_to_flat(pgn, &mut buffers); buffers }, @@ -129,7 +156,7 @@ pub fn bench_parquet_flat() { let duration_parallel = start.elapsed(); println!("Parallel parsing time: {:?}", duration_parallel); - // Merge all thread-local buffers with pre-allocation + // Step 5: Merge all thread-local buffers (mirrors parse_games_flat) let combined_buffers = FlatBuffers::merge_all(thread_results); let duration_total = start.elapsed(); @@ -142,9 +169,9 @@ pub fn bench_parquet_flat() { } fn main() { - println!("=== MoveExtractor API ===\n"); - bench_parquet(); + println!("=== Arrow API (MoveExtractor) ===\n"); + bench_arrow_api(); - println!("\n=== Flat API ===\n"); - bench_parquet_flat(); + println!("\n=== Flat API (FlatBuffers) ===\n"); + bench_flat_api(); } diff --git a/src/lib.rs b/src/lib.rs index c67b459..a60b4a5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,7 +16,7 @@ mod visitor; pub use flat_visitor::{parse_game_to_flat, FlatBuffers}; use python_bindings::{ParsedGames, ParsedGamesIter, PositionStatus, PyGameView, PyUciMove}; -use visitor::MoveExtractor; +pub use visitor::MoveExtractor; /// Parse games from Arrow chunked array into flat NumPy arrays. /// From 3617dfdf0a4e240c0c0e60e3b8ad8de161438513 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Thu, 5 Feb 2026 21:18:03 -0500 Subject: [PATCH 13/31] Performance optimization - reduce chunk size for flat API, option to not store board state for regular API --- benches/parquet_bench.rs | 60 ++++++++++++----- src/lib.rs | 137 ++++++++++++++++++++++++++++----------- src/visitor.rs | 44 ++++++++----- 3 files changed, 168 insertions(+), 73 deletions(-) diff --git a/benches/parquet_bench.rs b/benches/parquet_bench.rs index 4856818..8907c9a 100644 --- a/benches/parquet_bench.rs +++ b/benches/parquet_bench.rs @@ -18,6 +18,11 @@ use rust_pgn_reader_python_binding::{parse_game_to_flat, parse_single_game_nativ const FILE_PATH: &str = "2013-07-train-00000-of-00001.parquet"; +/// Chunk multiplier for explicit chunking in Flat API. +/// 1 = exactly num_threads chunks (minimal merge overhead) +/// Higher values provide better load balancing at cost of more buffers to merge. +const CHUNK_MULTIPLIER: usize = 1; + /// Read parquet file and return the raw Arrow StringArrays. /// This preserves Arrow's memory layout for zero-copy string access. fn read_parquet_to_string_arrays(file_path: &str) -> Vec { @@ -70,13 +75,17 @@ fn extract_str_slices<'a>(arrays: &'a [StringArray]) -> Vec<&'a str> { /// 1. Read parquet to Arrow arrays /// 2. Extract &str slices from StringArray (like the Python-bound function does) /// 3. Parse each game in parallel → Vec -pub fn bench_arrow_api() { +/// +/// Args: +/// - store_board_states: Whether to populate board state vectors (for benchmarking overhead) +pub fn bench_arrow_api(store_board_states: bool) { // Step 1: Read parquet to Arrow StringArrays let arrays = read_parquet_to_string_arrays(FILE_PATH); // Step 2: Extract &str slices (zero-copy, mirrors Arrow chunk iteration) let pgn_slices = extract_str_slices(&arrays); println!("Read {} games from parquet.", pgn_slices.len()); + println!("store_board_states: {}", store_board_states); // Step 3: Build thread pool (same pattern as lib.rs) let num_threads = num_cpus::get(); @@ -93,7 +102,7 @@ pub fn bench_arrow_api() { .install(|| { pgn_slices .par_iter() - .map(|&pgn| parse_single_game_native(pgn, false)) + .map(|&pgn| parse_single_game_native(pgn, false, store_board_states)) .collect::, _>>() }) .expect("Parsing failed"); @@ -115,7 +124,7 @@ pub fn bench_arrow_api() { /// This mirrors `parse_games_flat()` from Python: /// 1. Read parquet to Arrow arrays /// 2. Extract &str slices from StringArray -/// 3. Parse in parallel with fold_with → thread-local FlatBuffers +/// 3. Parse in parallel with explicit chunking (par_chunks) → fixed number of FlatBuffers /// 4. Merge all FlatBuffers into one pub fn bench_flat_api() { // Step 1: Read parquet to Arrow StringArrays @@ -128,36 +137,48 @@ pub fn bench_flat_api() { // Step 3: Build thread pool and compute capacity estimates let num_threads = num_cpus::get(); let n_games = pgn_slices.len(); - let games_per_thread = (n_games + num_threads - 1) / num_threads; let moves_per_game = 70; // Estimate ~70 moves per game + // Calculate chunk size for explicit chunking + let num_chunks = num_threads * CHUNK_MULTIPLIER; + let chunk_size = (n_games + num_chunks - 1) / num_chunks; // ceiling division + let chunk_size = chunk_size.max(1); + let games_per_chunk = chunk_size; + + println!( + "Using {} threads, {} chunks, {} games/chunk", + num_threads, num_chunks, games_per_chunk + ); + let thread_pool = ThreadPoolBuilder::new() .num_threads(num_threads) .build() .expect("Failed to build Rayon thread pool"); - // Step 4: Parse in parallel using fold_with for thread-local buffer accumulation - // This is exactly the pattern used in parse_games_flat() in lib.rs + // Step 4: Parse in parallel using par_chunks for explicit, fixed-size chunking + // This creates exactly ceil(n_games / chunk_size) FlatBuffers instances, + // avoiding the allocation storm from Rayon's dynamic work-stealing. let start = Instant::now(); - let thread_results: Vec = thread_pool.install(|| { + let chunk_results: Vec = thread_pool.install(|| { pgn_slices - .par_iter() - .fold_with( - FlatBuffers::with_capacity(games_per_thread, moves_per_game), - |mut buffers, &pgn| { + .par_chunks(chunk_size) + .map(|chunk| { + let mut buffers = FlatBuffers::with_capacity(games_per_chunk, moves_per_game); + for &pgn in chunk { let _ = parse_game_to_flat(pgn, &mut buffers); - buffers - }, - ) + } + buffers + }) .collect() }); let duration_parallel = start.elapsed(); println!("Parallel parsing time: {:?}", duration_parallel); + println!("Created {} FlatBuffers to merge", chunk_results.len()); - // Step 5: Merge all thread-local buffers (mirrors parse_games_flat) - let combined_buffers = FlatBuffers::merge_all(thread_results); + // Step 5: Merge all chunk buffers (mirrors parse_games_flat) + let combined_buffers = FlatBuffers::merge_all(chunk_results); let duration_total = start.elapsed(); println!("Total time (including merge): {:?}", duration_total); @@ -169,8 +190,11 @@ pub fn bench_flat_api() { } fn main() { - println!("=== Arrow API (MoveExtractor) ===\n"); - bench_arrow_api(); + println!("=== Arrow API (MoveExtractor, store_board_states=false) ===\n"); + bench_arrow_api(false); + + println!("\n=== Arrow API (MoveExtractor, store_board_states=true) ===\n"); + bench_arrow_api(true); println!("\n=== Flat API (FlatBuffers) ===\n"); bench_flat_api(); diff --git a/src/lib.rs b/src/lib.rs index a60b4a5..eb27872 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,17 +20,25 @@ pub use visitor::MoveExtractor; /// Parse games from Arrow chunked array into flat NumPy arrays. /// -/// This implementation uses thread-local FlatBuffers that are merged after -/// parallel parsing, avoiding the overhead of per-game Vec allocations -/// followed by a sequential copy loop. +/// This implementation uses explicit chunking with a fixed number of chunks +/// (num_chunks = num_threads * chunk_multiplier) to avoid the allocation storm +/// caused by Rayon's dynamic work-stealing with fold_with. +/// +/// Each chunk gets exactly one FlatBuffers instance, drastically reducing +/// the number of allocations and making merge_all much faster. #[pyfunction] -#[pyo3(signature = (pgn_chunked_array, num_threads=None))] +#[pyo3(signature = (pgn_chunked_array, num_threads=None, chunk_multiplier=None))] fn parse_games_flat( py: Python<'_>, pgn_chunked_array: PyChunkedArray, num_threads: Option, + chunk_multiplier: Option, ) -> PyResult { let num_threads = num_threads.unwrap_or_else(num_cpus::get); + // Default multiplier of 1 means exactly num_threads chunks (one per thread). + // Higher values (e.g., 4) create more chunks for better load balancing + // at the cost of more buffers to merge. + let chunk_multiplier = chunk_multiplier.unwrap_or(1); // Extract PGN strings from Arrow chunks let mut num_elements = 0; @@ -61,6 +69,9 @@ fn parse_games_flat( } let n_games = pgn_str_slices.len(); + if n_games == 0 { + return flat_buffers_to_parsed_games(py, FlatBuffers::default()); + } // Build thread pool let thread_pool = ThreadPoolBuilder::new() @@ -73,28 +84,34 @@ fn parse_games_flat( )) })?; - // Estimate capacity: ~70 moves per game - let games_per_thread = (n_games + num_threads - 1) / num_threads; + // Calculate chunk size for explicit chunking. + // num_chunks = num_threads * chunk_multiplier (e.g., 16 threads * 1 = 16 chunks) + let num_chunks = num_threads * chunk_multiplier; + let chunk_size = (n_games + num_chunks - 1) / num_chunks; // ceiling division + let chunk_size = chunk_size.max(1); // ensure at least 1 game per chunk + + // Estimate capacity per chunk + let games_per_chunk = chunk_size; let moves_per_game = 70; - // Parse in parallel using fold_with for thread-local buffer accumulation. - // fold_with gives each worker thread its own FlatBuffers instance to accumulate into. - // This is more efficient than fold() which creates buffers per work-stealing chunk. - let thread_results: Vec = thread_pool.install(|| { + // Parse in parallel using par_chunks for explicit, fixed-size chunking. + // This creates exactly ceil(n_games / chunk_size) FlatBuffers instances, + // avoiding the allocation storm from Rayon's dynamic work-stealing. + let chunk_results: Vec = thread_pool.install(|| { pgn_str_slices - .par_iter() - .fold_with( - FlatBuffers::with_capacity(games_per_thread, moves_per_game), - |mut buffers, &pgn| { + .par_chunks(chunk_size) + .map(|chunk| { + let mut buffers = FlatBuffers::with_capacity(games_per_chunk, moves_per_game); + for &pgn in chunk { let _ = parse_game_to_flat(pgn, &mut buffers); - buffers - }, - ) + } + buffers + }) .collect() }); - // Merge all thread-local buffers with pre-allocation (avoids repeated reallocations) - let combined_buffers = FlatBuffers::merge_all(thread_results); + // Merge all chunk buffers with pre-allocation (avoids repeated reallocations) + let combined_buffers = FlatBuffers::merge_all(chunk_results); // Convert FlatBuffers to ParsedGames with NumPy arrays flat_buffers_to_parsed_games(py, combined_buffers) @@ -180,9 +197,10 @@ fn flat_buffers_to_parsed_games(py: Python<'_>, buffers: FlatBuffers) -> PyResul pub fn parse_single_game_native( pgn: &str, store_legal_moves: bool, + store_board_states: bool, ) -> Result { let mut reader = Reader::new(Cursor::new(pgn)); - let mut extractor = MoveExtractor::new(store_legal_moves); + let mut extractor = MoveExtractor::new(store_legal_moves, store_board_states); match reader.read_game(&mut extractor) { Ok(Some(_)) => Ok(extractor), Ok(None) => Err("No game found in PGN".to_string()), @@ -194,6 +212,7 @@ pub fn parse_multiple_games_native( pgns: &Vec, num_threads: Option, store_legal_moves: bool, + store_board_states: bool, ) -> Result, String> { let num_threads = num_threads.unwrap_or_else(num_cpus::get); @@ -205,7 +224,7 @@ pub fn parse_multiple_games_native( thread_pool.install(|| { pgns.par_iter() - .map(|pgn| parse_single_game_native(pgn, store_legal_moves)) + .map(|pgn| parse_single_game_native(pgn, store_legal_moves, store_board_states)) .collect() }) } @@ -214,6 +233,7 @@ fn _parse_game_moves_from_arrow_chunks_native( pgn_chunked_array: &PyChunkedArray, num_threads: Option, store_legal_moves: bool, + store_board_states: bool, ) -> Result, String> { let num_threads = num_threads.unwrap_or_else(num_cpus::get); let thread_pool = ThreadPoolBuilder::new() @@ -250,7 +270,7 @@ fn _parse_game_moves_from_arrow_chunks_native( thread_pool.install(|| { pgn_str_slices .par_iter() - .map(|&pgn_s| parse_single_game_native(pgn_s, store_legal_moves)) + .map(|&pgn_s| parse_single_game_native(pgn_s, store_legal_moves, store_board_states)) .collect::, String>>() }) } @@ -259,34 +279,45 @@ fn _parse_game_moves_from_arrow_chunks_native( // TODO check if I can call py.allow_threads and release GIL // see https://docs.rs/pyo3-arrow/0.10.1/pyo3_arrow/ #[pyfunction] -#[pyo3(signature = (pgn, store_legal_moves = false))] +#[pyo3(signature = (pgn, store_legal_moves = false, store_board_states = false))] /// Parses a single PGN game string. -fn parse_game(pgn: &str, store_legal_moves: bool) -> PyResult { - parse_single_game_native(pgn, store_legal_moves) +fn parse_game( + pgn: &str, + store_legal_moves: bool, + store_board_states: bool, +) -> PyResult { + parse_single_game_native(pgn, store_legal_moves, store_board_states) .map_err(pyo3::exceptions::PyValueError::new_err) } /// In parallel, parse a set of games #[pyfunction] -#[pyo3(signature = (pgns, num_threads=None, store_legal_moves=false))] +#[pyo3(signature = (pgns, num_threads=None, store_legal_moves=false, store_board_states=false))] fn parse_games( pgns: Vec, num_threads: Option, store_legal_moves: bool, + store_board_states: bool, ) -> PyResult> { - parse_multiple_games_native(&pgns, num_threads, store_legal_moves) + parse_multiple_games_native(&pgns, num_threads, store_legal_moves, store_board_states) .map_err(pyo3::exceptions::PyValueError::new_err) } #[pyfunction] -#[pyo3(signature = (pgn_chunked_array, num_threads=None, store_legal_moves=false))] +#[pyo3(signature = (pgn_chunked_array, num_threads=None, store_legal_moves=false, store_board_states=false))] fn parse_game_moves_arrow_chunked_array( pgn_chunked_array: PyChunkedArray, num_threads: Option, store_legal_moves: bool, + store_board_states: bool, ) -> PyResult> { - _parse_game_moves_from_arrow_chunks_native(&pgn_chunked_array, num_threads, store_legal_moves) - .map_err(pyo3::exceptions::PyValueError::new_err) + _parse_game_moves_from_arrow_chunks_native( + &pgn_chunked_array, + num_threads, + store_legal_moves, + store_board_states, + ) + .map_err(pyo3::exceptions::PyValueError::new_err) } /// Parser for chess PGN notation @@ -312,7 +343,7 @@ mod pyucimove_tests { #[test] fn test_parse_game_without_headers() { let pgn = "1. Nf3 d5 2. e4 c5 3. exd5 e5 4. dxe6 0-1"; - let result = parse_single_game_native(pgn, false); + let result = parse_single_game_native(pgn, false, false); assert!(result.is_ok()); let extractor = result.unwrap(); assert_eq!(extractor.moves.len(), 7); @@ -325,7 +356,7 @@ mod pyucimove_tests { let pgn = r#"[FEN "r1bqkbnr/pppp1ppp/2n5/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3"] 3. Bb5 a6 4. Ba4 Nf6 1-0"#; - let result = parse_single_game_native(pgn, false); + let result = parse_single_game_native(pgn, false, false); assert!(result.is_ok()); let extractor = result.unwrap(); assert!(extractor.valid_moves, "Moves should be valid"); @@ -339,7 +370,7 @@ mod pyucimove_tests { [FEN "brkrqnnb/pppppppp/8/8/8/8/PPPPPPPP/BRKRQNNB w KQkq - 0 1"] 1. g3 d5 2. d4 g6 3. b3 Nf6 1-0"#; - let result = parse_single_game_native(pgn, false); + let result = parse_single_game_native(pgn, false, false); assert!(result.is_ok()); let extractor = result.unwrap(); assert!( @@ -356,7 +387,7 @@ mod pyucimove_tests { [FEN "brkrqnnb/pppppppp/8/8/8/8/PPPPPPPP/BRKRQNNB w KQkq - 0 1"] 1. g3 d5 1-0"#; - let result = parse_single_game_native(pgn, false); + let result = parse_single_game_native(pgn, false, false); assert!(result.is_ok()); let extractor = result.unwrap(); assert!( @@ -371,7 +402,7 @@ mod pyucimove_tests { let pgn = r#"[FEN "invalid fen string"] 1. e4 e5 1-0"#; - let result = parse_single_game_native(pgn, false); + let result = parse_single_game_native(pgn, false, false); assert!(result.is_ok()); let extractor = result.unwrap(); assert!( @@ -386,7 +417,7 @@ mod pyucimove_tests { let pgn = r#"[fen "r1bqkbnr/pppp1ppp/2n5/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3"] 3. Bb5 1-0"#; - let result = parse_single_game_native(pgn, false); + let result = parse_single_game_native(pgn, false, false); assert!(result.is_ok()); let extractor = result.unwrap(); assert!( @@ -403,7 +434,7 @@ mod pyucimove_tests { [FEN "r1bqkbnr/pppp1ppp/2n5/1B2p3/4P3/5N2/PPPP1PPP/RNBQK2R b KQkq - 3 3"] 3... a6 4. Ba4 Nf6 5. O-O Be7 1-0"#; - let result = parse_single_game_native(pgn, false); + let result = parse_single_game_native(pgn, false, false); assert!(result.is_ok()); let extractor = result.unwrap(); assert!( @@ -412,4 +443,36 @@ mod pyucimove_tests { ); assert_eq!(extractor.moves.len(), 5); // a6, Ba4, Nf6, O-O, Be7 } + + #[test] + fn test_parse_game_with_board_states() { + // Test that board states are populated when enabled + let pgn = "1. e4 e5 2. Nf3 Nc6 1-0"; + let result = parse_single_game_native(pgn, false, true); + assert!(result.is_ok()); + let extractor = result.unwrap(); + assert_eq!(extractor.moves.len(), 4); + // 5 positions: initial + 4 moves + assert_eq!(extractor.board_states.len(), 5 * 64); + assert_eq!(extractor.en_passant_states.len(), 5); + assert_eq!(extractor.halfmove_clocks.len(), 5); + assert_eq!(extractor.turn_states.len(), 5); + assert_eq!(extractor.castling_states.len(), 5 * 4); + } + + #[test] + fn test_parse_game_without_board_states() { + // Test that board states are NOT populated when disabled + let pgn = "1. e4 e5 2. Nf3 Nc6 1-0"; + let result = parse_single_game_native(pgn, false, false); + assert!(result.is_ok()); + let extractor = result.unwrap(); + assert_eq!(extractor.moves.len(), 4); + // Board state vectors should be empty + assert_eq!(extractor.board_states.len(), 0); + assert_eq!(extractor.en_passant_states.len(), 0); + assert_eq!(extractor.halfmove_clocks.len(), 0); + assert_eq!(extractor.turn_states.len(), 0); + assert_eq!(extractor.castling_states.len(), 0); + } } diff --git a/src/visitor.rs b/src/visitor.rs index bc1a703..50c4e13 100644 --- a/src/visitor.rs +++ b/src/visitor.rs @@ -15,6 +15,7 @@ pub struct MoveExtractor { pub moves: Vec, pub store_legal_moves: bool, + pub store_board_states: bool, pub flat_legal_moves: Vec, pub legal_moves_offsets: Vec, @@ -45,6 +46,7 @@ pub struct MoveExtractor { pub pos: Chess, // Board state tracking for flat output (not directly exposed to Python) + // Only populated if store_board_states is true pub board_states: Vec, // Flattened: 64 bytes per position pub en_passant_states: Vec, // Per position: -1 or file 0-7 pub halfmove_clocks: Vec, // Per position @@ -55,27 +57,29 @@ pub struct MoveExtractor { #[pymethods] impl MoveExtractor { #[new] - #[pyo3(signature = (store_legal_moves = false))] - pub fn new(store_legal_moves: bool) -> MoveExtractor { + #[pyo3(signature = (store_legal_moves = false, store_board_states = false))] + pub fn new(store_legal_moves: bool, store_board_states: bool) -> MoveExtractor { MoveExtractor { - moves: Vec::with_capacity(100), + moves: Vec::with_capacity(50), // Tuned: avg game ~40-50 moves store_legal_moves, - flat_legal_moves: Vec::with_capacity(if store_legal_moves { 100 * 30 } else { 0 }), // Pre-allocate for moves - legal_moves_offsets: Vec::with_capacity(if store_legal_moves { 100 } else { 0 }), // Pre-allocate for offsets + store_board_states, + flat_legal_moves: Vec::with_capacity(if store_legal_moves { 50 * 30 } else { 0 }), + legal_moves_offsets: Vec::with_capacity(if store_legal_moves { 50 } else { 0 }), pos: Chess::default(), valid_moves: true, - comments: Vec::with_capacity(100), - evals: Vec::with_capacity(100), - clock_times: Vec::with_capacity(100), + comments: Vec::new(), // Lazy: will grow on demand + evals: Vec::new(), // Lazy: will grow on demand + clock_times: Vec::new(), // Lazy: will grow on demand outcome: None, headers: Vec::with_capacity(10), - castling_rights: Vec::with_capacity(100), + castling_rights: Vec::new(), // Lazy: will grow on demand position_status: None, - board_states: Vec::with_capacity(100 * 64), - en_passant_states: Vec::with_capacity(100), - halfmove_clocks: Vec::with_capacity(100), - turn_states: Vec::with_capacity(100), - castling_states: Vec::with_capacity(100 * 4), + // Only pre-allocate if storing board states + board_states: Vec::with_capacity(if store_board_states { 50 * 64 } else { 0 }), + en_passant_states: Vec::with_capacity(if store_board_states { 50 } else { 0 }), + halfmove_clocks: Vec::with_capacity(if store_board_states { 50 } else { 0 }), + turn_states: Vec::with_capacity(if store_board_states { 50 } else { 0 }), + castling_states: Vec::with_capacity(if store_board_states { 50 * 4 } else { 0 }), } } @@ -257,8 +261,10 @@ impl Visitor for MoveExtractor { if self.store_legal_moves { self.push_legal_moves(); } - // Record initial board state for flat output - self.push_board_state(); + // Record initial board state for flat output (only if enabled) + if self.store_board_states { + self.push_board_state(); + } ControlFlow::Continue(()) } @@ -276,8 +282,10 @@ impl Visitor for MoveExtractor { if self.store_legal_moves { self.push_legal_moves(); } - // Record board state after move for flat output - self.push_board_state(); + // Record board state after move for flat output (only if enabled) + if self.store_board_states { + self.push_board_state(); + } let uci_move_obj = UciMove::from_standard(m); match uci_move_obj { From 945a1045d67ff860564eb37d645c18fd3349a7a2 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Thu, 5 Feb 2026 21:29:16 -0500 Subject: [PATCH 14/31] Restore preallocation --- src/visitor.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/visitor.rs b/src/visitor.rs index 50c4e13..f440a0c 100644 --- a/src/visitor.rs +++ b/src/visitor.rs @@ -60,26 +60,26 @@ impl MoveExtractor { #[pyo3(signature = (store_legal_moves = false, store_board_states = false))] pub fn new(store_legal_moves: bool, store_board_states: bool) -> MoveExtractor { MoveExtractor { - moves: Vec::with_capacity(50), // Tuned: avg game ~40-50 moves + moves: Vec::with_capacity(100), store_legal_moves, store_board_states, - flat_legal_moves: Vec::with_capacity(if store_legal_moves { 50 * 30 } else { 0 }), - legal_moves_offsets: Vec::with_capacity(if store_legal_moves { 50 } else { 0 }), + flat_legal_moves: Vec::with_capacity(if store_legal_moves { 100 * 30 } else { 0 }), + legal_moves_offsets: Vec::with_capacity(if store_legal_moves { 100 } else { 0 }), pos: Chess::default(), valid_moves: true, - comments: Vec::new(), // Lazy: will grow on demand - evals: Vec::new(), // Lazy: will grow on demand - clock_times: Vec::new(), // Lazy: will grow on demand + comments: Vec::with_capacity(100), + evals: Vec::with_capacity(100), + clock_times: Vec::with_capacity(100), outcome: None, headers: Vec::with_capacity(10), - castling_rights: Vec::new(), // Lazy: will grow on demand + castling_rights: Vec::with_capacity(100), position_status: None, // Only pre-allocate if storing board states - board_states: Vec::with_capacity(if store_board_states { 50 * 64 } else { 0 }), - en_passant_states: Vec::with_capacity(if store_board_states { 50 } else { 0 }), - halfmove_clocks: Vec::with_capacity(if store_board_states { 50 } else { 0 }), - turn_states: Vec::with_capacity(if store_board_states { 50 } else { 0 }), - castling_states: Vec::with_capacity(if store_board_states { 50 * 4 } else { 0 }), + board_states: Vec::with_capacity(if store_board_states { 100 * 64 } else { 0 }), + en_passant_states: Vec::with_capacity(if store_board_states { 100 } else { 0 }), + halfmove_clocks: Vec::with_capacity(if store_board_states { 100 } else { 0 }), + turn_states: Vec::with_capacity(if store_board_states { 100 } else { 0 }), + castling_states: Vec::with_capacity(if store_board_states { 100 * 4 } else { 0 }), } } From ad5721abc5a09924c60140fcf3f07c83ec1b196e Mon Sep 17 00:00:00 2001 From: vladkvit Date: Thu, 5 Feb 2026 21:42:12 -0500 Subject: [PATCH 15/31] Fix performance regression by not initializing unused fields --- src/lib.rs | 22 ++++++------- src/visitor.rs | 86 ++++++++++++++++++++++++++++++++++---------------- 2 files changed, 69 insertions(+), 39 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index eb27872..05f7dd6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -453,11 +453,15 @@ mod pyucimove_tests { let extractor = result.unwrap(); assert_eq!(extractor.moves.len(), 4); // 5 positions: initial + 4 moves - assert_eq!(extractor.board_states.len(), 5 * 64); - assert_eq!(extractor.en_passant_states.len(), 5); - assert_eq!(extractor.halfmove_clocks.len(), 5); - assert_eq!(extractor.turn_states.len(), 5); - assert_eq!(extractor.castling_states.len(), 5 * 4); + let data = extractor + .board_state_data + .as_ref() + .expect("board_state_data should be Some"); + assert_eq!(data.board_states.len(), 5 * 64); + assert_eq!(data.en_passant_states.len(), 5); + assert_eq!(data.halfmove_clocks.len(), 5); + assert_eq!(data.turn_states.len(), 5); + assert_eq!(data.castling_states.len(), 5 * 4); } #[test] @@ -468,11 +472,7 @@ mod pyucimove_tests { assert!(result.is_ok()); let extractor = result.unwrap(); assert_eq!(extractor.moves.len(), 4); - // Board state vectors should be empty - assert_eq!(extractor.board_states.len(), 0); - assert_eq!(extractor.en_passant_states.len(), 0); - assert_eq!(extractor.halfmove_clocks.len(), 0); - assert_eq!(extractor.turn_states.len(), 0); - assert_eq!(extractor.castling_states.len(), 0); + // Board state data should be None when disabled + assert!(extractor.board_state_data.is_none()); } } diff --git a/src/visitor.rs b/src/visitor.rs index f440a0c..353a1af 100644 --- a/src/visitor.rs +++ b/src/visitor.rs @@ -8,6 +8,37 @@ use pyo3::prelude::*; use shakmaty::{fen::Fen, uci::UciMove, CastlingMode, Chess, Color, Position}; use std::ops::ControlFlow; +/// Board state tracking for flat output. +/// Only allocated when store_board_states is true to avoid overhead in the common case. +#[derive(Default)] +pub struct BoardStateData { + pub board_states: Vec, // Flattened: 64 bytes per position + pub en_passant_states: Vec, // Per position: -1 or file 0-7 + pub halfmove_clocks: Vec, // Per position + pub turn_states: Vec, // Per position: true=white + pub castling_states: Vec, // Flattened: 4 bools per position [K,Q,k,q] +} + +impl BoardStateData { + pub fn with_capacity(estimated_positions: usize) -> Self { + BoardStateData { + board_states: Vec::with_capacity(estimated_positions * 64), + en_passant_states: Vec::with_capacity(estimated_positions), + halfmove_clocks: Vec::with_capacity(estimated_positions), + turn_states: Vec::with_capacity(estimated_positions), + castling_states: Vec::with_capacity(estimated_positions * 4), + } + } + + pub fn clear(&mut self) { + self.board_states.clear(); + self.en_passant_states.clear(); + self.halfmove_clocks.clear(); + self.turn_states.clear(); + self.castling_states.clear(); + } +} + #[pyclass] /// A Visitor to extract SAN moves and comments from PGN movetext pub struct MoveExtractor { @@ -15,7 +46,6 @@ pub struct MoveExtractor { pub moves: Vec, pub store_legal_moves: bool, - pub store_board_states: bool, pub flat_legal_moves: Vec, pub legal_moves_offsets: Vec, @@ -46,12 +76,8 @@ pub struct MoveExtractor { pub pos: Chess, // Board state tracking for flat output (not directly exposed to Python) - // Only populated if store_board_states is true - pub board_states: Vec, // Flattened: 64 bytes per position - pub en_passant_states: Vec, // Per position: -1 or file 0-7 - pub halfmove_clocks: Vec, // Per position - pub turn_states: Vec, // Per position: true=white - pub castling_states: Vec, // Flattened: 4 bools per position [K,Q,k,q] + // Only allocated if store_board_states is true to avoid overhead + pub board_state_data: Option>, } #[pymethods] @@ -62,7 +88,6 @@ impl MoveExtractor { MoveExtractor { moves: Vec::with_capacity(100), store_legal_moves, - store_board_states, flat_legal_moves: Vec::with_capacity(if store_legal_moves { 100 * 30 } else { 0 }), legal_moves_offsets: Vec::with_capacity(if store_legal_moves { 100 } else { 0 }), pos: Chess::default(), @@ -74,15 +99,20 @@ impl MoveExtractor { headers: Vec::with_capacity(10), castling_rights: Vec::with_capacity(100), position_status: None, - // Only pre-allocate if storing board states - board_states: Vec::with_capacity(if store_board_states { 100 * 64 } else { 0 }), - en_passant_states: Vec::with_capacity(if store_board_states { 100 } else { 0 }), - halfmove_clocks: Vec::with_capacity(if store_board_states { 100 } else { 0 }), - turn_states: Vec::with_capacity(if store_board_states { 100 } else { 0 }), - castling_states: Vec::with_capacity(if store_board_states { 100 * 4 } else { 0 }), + // Only allocate board state data when needed to avoid overhead + board_state_data: if store_board_states { + Some(Box::new(BoardStateData::with_capacity(100))) + } else { + None + }, } } + /// Check if board states are being stored. + pub fn stores_board_states(&self) -> bool { + self.board_state_data.is_some() + } + fn turn(&self) -> bool { match self.pos.turn() { Color::White => true, @@ -128,13 +158,15 @@ impl MoveExtractor { /// Record current board state to flat arrays for ParsedGames output. fn push_board_state(&mut self) { - self.board_states - .extend_from_slice(&serialize_board(&self.pos)); - self.en_passant_states.push(get_en_passant_file(&self.pos)); - self.halfmove_clocks.push(get_halfmove_clock(&self.pos)); - self.turn_states.push(get_turn(&self.pos)); - let castling = get_castling_rights(&self.pos); - self.castling_states.extend_from_slice(&castling); + if let Some(ref mut data) = self.board_state_data { + data.board_states + .extend_from_slice(&serialize_board(&self.pos)); + data.en_passant_states.push(get_en_passant_file(&self.pos)); + data.halfmove_clocks.push(get_halfmove_clock(&self.pos)); + data.turn_states.push(get_turn(&self.pos)); + let castling = get_castling_rights(&self.pos); + data.castling_states.extend_from_slice(&castling); + } } fn update_position_status(&mut self) { @@ -209,11 +241,9 @@ impl Visitor for MoveExtractor { self.evals.clear(); self.clock_times.clear(); self.castling_rights.clear(); - self.board_states.clear(); - self.en_passant_states.clear(); - self.halfmove_clocks.clear(); - self.turn_states.clear(); - self.castling_states.clear(); + if let Some(ref mut data) = self.board_state_data { + data.clear(); + } // Determine castling mode from Variant header (case-insensitive) let castling_mode = self @@ -262,7 +292,7 @@ impl Visitor for MoveExtractor { self.push_legal_moves(); } // Record initial board state for flat output (only if enabled) - if self.store_board_states { + if self.board_state_data.is_some() { self.push_board_state(); } ControlFlow::Continue(()) @@ -283,7 +313,7 @@ impl Visitor for MoveExtractor { self.push_legal_moves(); } // Record board state after move for flat output (only if enabled) - if self.store_board_states { + if self.board_state_data.is_some() { self.push_board_state(); } let uci_move_obj = UciMove::from_standard(m); From ad0da353c6307791bce0d89bee1ef3354f2dd9a5 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Thu, 5 Feb 2026 21:46:47 -0500 Subject: [PATCH 16/31] Fix nit about a file being in multiple targets --- Cargo.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4cd7471..7ded532 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,9 +29,5 @@ criterion = "0.8" harness = false name = "parquet_bench" -[[bin]] -name = "parquet_bench" -path = "benches/parquet_bench.rs" - [profile.bench] debug = true From 31701d2fa80ed9d3caa347bde33d9353ef57b448 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Thu, 5 Feb 2026 22:22:32 -0500 Subject: [PATCH 17/31] Add perf comment, adjust bench timing --- benches/parquet_bench.rs | 37 ++++++++++++++++++++++++++++--------- src/board_serialization.rs | 1 + 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/benches/parquet_bench.rs b/benches/parquet_bench.rs index 8907c9a..91ec800 100644 --- a/benches/parquet_bench.rs +++ b/benches/parquet_bench.rs @@ -107,16 +107,23 @@ pub fn bench_arrow_api(store_board_states: bool) { }) .expect("Parsing failed"); - let duration = start.elapsed(); + let parsing_duration = start.elapsed(); // Report results let total_moves: usize = extractors.iter().map(|e| e.moves.len()).sum(); - println!("Parsing time: {:?}", duration); - println!( - "Parsed {} games, {} total moves.", - extractors.len(), - total_moves - ); + let num_games = extractors.len(); + println!("Parsing time: {:?}", parsing_duration); + println!("Parsed {} games, {} total moves.", num_games, total_moves); + + // Explicitly drop to measure cleanup time + // This is the cost Python will pay when the list goes out of scope + let drop_start = Instant::now(); + drop(extractors); + let drop_duration = drop_start.elapsed(); + + let total_duration = start.elapsed(); + println!("Cleanup time (drop): {:?}", drop_duration); + println!("Total time (parsing + cleanup): {:?}", total_duration); } /// Benchmark the Flat API workflow. @@ -180,13 +187,25 @@ pub fn bench_flat_api() { // Step 5: Merge all chunk buffers (mirrors parse_games_flat) let combined_buffers = FlatBuffers::merge_all(chunk_results); - let duration_total = start.elapsed(); - println!("Total time (including merge): {:?}", duration_total); + let duration_with_merge = start.elapsed(); + println!("Total time (parsing + merge): {:?}", duration_with_merge); println!( "Parsed {} games, {} total positions.", combined_buffers.num_games(), combined_buffers.total_positions() ); + + // Measure cleanup time for fair comparison with Arrow API + let drop_start = Instant::now(); + drop(combined_buffers); + let drop_duration = drop_start.elapsed(); + + let total_duration = start.elapsed(); + println!("Cleanup time (drop): {:?}", drop_duration); + println!( + "Total time (parsing + merge + cleanup): {:?}", + total_duration + ); } fn main() { diff --git a/src/board_serialization.rs b/src/board_serialization.rs index b8fd021..381fc88 100644 --- a/src/board_serialization.rs +++ b/src/board_serialization.rs @@ -9,6 +9,7 @@ use shakmaty::{Chess, Color, EnPassantMode, Position, Role, Square}; +/// TODO this is a bottleneck for the multithreaded part of the parser /// Serialize board position to 64-byte array. /// Index mapping: square index (a1=0, h8=63) -> piece value (0-12) pub fn serialize_board(pos: &Chess) -> [u8; 64] { From 2e41ef594f6c116c64d5150113cb24a0eeb92e0f Mon Sep 17 00:00:00 2001 From: vladkvit Date: Thu, 5 Feb 2026 22:38:33 -0500 Subject: [PATCH 18/31] Optimize board serialization --- src/board_serialization.rs | 62 +++++++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 15 deletions(-) diff --git a/src/board_serialization.rs b/src/board_serialization.rs index 381fc88..4ca411d 100644 --- a/src/board_serialization.rs +++ b/src/board_serialization.rs @@ -7,29 +7,61 @@ //! Note: This differs from some Python code that uses [7 - rank, file] indexing. //! The Python wrapper can transpose if needed. -use shakmaty::{Chess, Color, EnPassantMode, Position, Role, Square}; +use shakmaty::{Chess, Color, EnPassantMode, Position, Square}; -/// TODO this is a bottleneck for the multithreaded part of the parser /// Serialize board position to 64-byte array. /// Index mapping: square index (a1=0, h8=63) -> piece value (0-12) +/// +/// Optimized to iterate by piece type using bitboards directly, +/// avoiding expensive piece_at() lookups for all 64 squares. pub fn serialize_board(pos: &Chess) -> [u8; 64] { let mut board = [0u8; 64]; let b = pos.board(); - for sq in Square::ALL { - if let Some(piece) = b.piece_at(sq) { - let piece_val = match piece.role { - Role::Pawn => 1, - Role::Knight => 2, - Role::Bishop => 3, - Role::Rook => 4, - Role::Queen => 5, - Role::King => 6, - }; - let color_offset = if piece.color == Color::White { 0 } else { 6 }; - board[sq as usize] = piece_val + color_offset; - } + // Get color bitboards once + let white = b.white(); + let black = b.black(); + + // White pieces (value 1-6) + for sq in b.pawns() & white { + board[sq as usize] = 1; + } + for sq in b.knights() & white { + board[sq as usize] = 2; + } + for sq in b.bishops() & white { + board[sq as usize] = 3; } + for sq in b.rooks() & white { + board[sq as usize] = 4; + } + for sq in b.queens() & white { + board[sq as usize] = 5; + } + for sq in b.kings() & white { + board[sq as usize] = 6; + } + + // Black pieces (value 7-12) + for sq in b.pawns() & black { + board[sq as usize] = 7; + } + for sq in b.knights() & black { + board[sq as usize] = 8; + } + for sq in b.bishops() & black { + board[sq as usize] = 9; + } + for sq in b.rooks() & black { + board[sq as usize] = 10; + } + for sq in b.queens() & black { + board[sq as usize] = 11; + } + for sq in b.kings() & black { + board[sq as usize] = 12; + } + board } From 16b4b4f26623873c89de2bfedb43c1a8fc01917f Mon Sep 17 00:00:00 2001 From: vladkvit Date: Thu, 5 Feb 2026 22:40:22 -0500 Subject: [PATCH 19/31] Refactor optimized board serialization --- src/board_serialization.rs | 52 +++++++------------------------------- 1 file changed, 9 insertions(+), 43 deletions(-) diff --git a/src/board_serialization.rs b/src/board_serialization.rs index 4ca411d..53146b6 100644 --- a/src/board_serialization.rs +++ b/src/board_serialization.rs @@ -7,7 +7,7 @@ //! Note: This differs from some Python code that uses [7 - rank, file] indexing. //! The Python wrapper can transpose if needed. -use shakmaty::{Chess, Color, EnPassantMode, Position, Square}; +use shakmaty::{Chess, Color, EnPassantMode, Position, Role, Square}; /// Serialize board position to 64-byte array. /// Index mapping: square index (a1=0, h8=63) -> piece value (0-12) @@ -18,48 +18,14 @@ pub fn serialize_board(pos: &Chess) -> [u8; 64] { let mut board = [0u8; 64]; let b = pos.board(); - // Get color bitboards once - let white = b.white(); - let black = b.black(); - - // White pieces (value 1-6) - for sq in b.pawns() & white { - board[sq as usize] = 1; - } - for sq in b.knights() & white { - board[sq as usize] = 2; - } - for sq in b.bishops() & white { - board[sq as usize] = 3; - } - for sq in b.rooks() & white { - board[sq as usize] = 4; - } - for sq in b.queens() & white { - board[sq as usize] = 5; - } - for sq in b.kings() & white { - board[sq as usize] = 6; - } - - // Black pieces (value 7-12) - for sq in b.pawns() & black { - board[sq as usize] = 7; - } - for sq in b.knights() & black { - board[sq as usize] = 8; - } - for sq in b.bishops() & black { - board[sq as usize] = 9; - } - for sq in b.rooks() & black { - board[sq as usize] = 10; - } - for sq in b.queens() & black { - board[sq as usize] = 11; - } - for sq in b.kings() & black { - board[sq as usize] = 12; + for role in Role::ALL { + let piece_val = role as u8; // Role enum is 1-6 + for sq in b.by_role(role) & b.white() { + board[sq as usize] = piece_val; + } + for sq in b.by_role(role) & b.black() { + board[sq as usize] = piece_val + 6; + } } board From b7fc06e96c182ec15f0e9f9afb63840bc1355bb7 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Fri, 6 Feb 2026 08:41:11 -0500 Subject: [PATCH 20/31] Linter rearranged imports --- benches/parquet_bench.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benches/parquet_bench.rs b/benches/parquet_bench.rs index 91ec800..f3bc8e5 100644 --- a/benches/parquet_bench.rs +++ b/benches/parquet_bench.rs @@ -8,13 +8,13 @@ use arrow::array::{Array, StringArray}; use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; -use rayon::prelude::*; use rayon::ThreadPoolBuilder; +use rayon::prelude::*; use std::fs::File; use std::path::Path; use std::time::Instant; -use rust_pgn_reader_python_binding::{parse_game_to_flat, parse_single_game_native, FlatBuffers}; +use rust_pgn_reader_python_binding::{FlatBuffers, parse_game_to_flat, parse_single_game_native}; const FILE_PATH: &str = "2013-07-train-00000-of-00001.parquet"; From 62fb05bd659a10de72d5e076da2212fe2e126d6c Mon Sep 17 00:00:00 2001 From: vladkvit Date: Fri, 6 Feb 2026 08:44:03 -0500 Subject: [PATCH 21/31] Parallel vec merge --- src/flat_visitor.rs | 119 +++++++++++++++++++++++--------------------- 1 file changed, 63 insertions(+), 56 deletions(-) diff --git a/src/flat_visitor.rs b/src/flat_visitor.rs index 0d14027..7f177bb 100644 --- a/src/flat_visitor.rs +++ b/src/flat_visitor.rs @@ -9,6 +9,7 @@ use crate::board_serialization::{ }; use crate::comment_parsing::{parse_comments, CommentContent, ParsedTag}; use pgn_reader::{Outcome, RawComment, RawTag, SanPlus, Skip, Visitor}; +use rayon::prelude::*; use shakmaty::{fen::Fen, uci::UciMove, CastlingMode, Chess, Color, Position}; use std::collections::HashMap; use std::ops::ControlFlow; @@ -114,10 +115,10 @@ impl FlatBuffers { self.headers.extend(other.headers); } - /// Merge multiple FlatBuffers efficiently by pre-allocating total capacity. + /// Merge multiple FlatBuffers efficiently using rayon's parallel collect. /// - /// This avoids the repeated reallocations that occur when calling `merge` - /// in a loop starting from an empty buffer. + /// This uses rayon's `into_par_iter().flat_map().collect()` pattern which + /// safely handles parallel collection without unsafe code. pub fn merge_all(buffers: Vec) -> FlatBuffers { if buffers.is_empty() { return FlatBuffers::default(); @@ -126,63 +127,69 @@ impl FlatBuffers { return buffers.into_iter().next().unwrap(); } - // Calculate total sizes - let total_games: usize = buffers.iter().map(|b| b.headers.len()).sum(); - let total_positions: usize = buffers.iter().map(|b| b.boards.len() / 64).sum(); - let total_moves: usize = buffers.iter().map(|b| b.from_squares.len()).sum(); + // Use rayon's parallel flat_map and collect for all fields. + // This is safe and well-tested, using rayon's internal mechanisms + // for efficient parallel collection. + + // For Copy types, we can use flat_map with copied() + macro_rules! par_flatten { + ($field:ident) => { + buffers + .par_iter() + .flat_map(|buf| buf.$field.par_iter().copied()) + .collect() + }; + } - // Pre-allocate with exact capacity - let mut combined = FlatBuffers { - // Board state arrays - boards: Vec::with_capacity(total_positions * 64), - castling: Vec::with_capacity(total_positions * 4), - en_passant: Vec::with_capacity(total_positions), - halfmove_clock: Vec::with_capacity(total_positions), - turn: Vec::with_capacity(total_positions), + // Board state arrays + let boards: Vec = par_flatten!(boards); + let castling: Vec = par_flatten!(castling); + let en_passant: Vec = par_flatten!(en_passant); + let halfmove_clock: Vec = par_flatten!(halfmove_clock); + let turn: Vec = par_flatten!(turn); - // Move arrays - from_squares: Vec::with_capacity(total_moves), - to_squares: Vec::with_capacity(total_moves), - promotions: Vec::with_capacity(total_moves), - clocks: Vec::with_capacity(total_moves), - evals: Vec::with_capacity(total_moves), + // Move arrays + let from_squares: Vec = par_flatten!(from_squares); + let to_squares: Vec = par_flatten!(to_squares); + let promotions: Vec = par_flatten!(promotions); + let clocks: Vec = par_flatten!(clocks); + let evals: Vec = par_flatten!(evals); - // Per-game data - move_counts: Vec::with_capacity(total_games), - position_counts: Vec::with_capacity(total_games), - is_checkmate: Vec::with_capacity(total_games), - is_stalemate: Vec::with_capacity(total_games), - is_insufficient: Vec::with_capacity(total_games * 2), - legal_move_count: Vec::with_capacity(total_games), - valid: Vec::with_capacity(total_games), - headers: Vec::with_capacity(total_games), - }; - - // Now merge - no reallocations will occur - for buf in buffers { - combined.boards.extend(buf.boards); - combined.castling.extend(buf.castling); - combined.en_passant.extend(buf.en_passant); - combined.halfmove_clock.extend(buf.halfmove_clock); - combined.turn.extend(buf.turn); - - combined.from_squares.extend(buf.from_squares); - combined.to_squares.extend(buf.to_squares); - combined.promotions.extend(buf.promotions); - combined.clocks.extend(buf.clocks); - combined.evals.extend(buf.evals); - - combined.move_counts.extend(buf.move_counts); - combined.position_counts.extend(buf.position_counts); - combined.is_checkmate.extend(buf.is_checkmate); - combined.is_stalemate.extend(buf.is_stalemate); - combined.is_insufficient.extend(buf.is_insufficient); - combined.legal_move_count.extend(buf.legal_move_count); - combined.valid.extend(buf.valid); - combined.headers.extend(buf.headers); - } + // Per-game data + let move_counts: Vec = par_flatten!(move_counts); + let position_counts: Vec = par_flatten!(position_counts); + let is_checkmate: Vec = par_flatten!(is_checkmate); + let is_stalemate: Vec = par_flatten!(is_stalemate); + let is_insufficient: Vec = par_flatten!(is_insufficient); + let legal_move_count: Vec = par_flatten!(legal_move_count); + let valid: Vec = par_flatten!(valid); + + // Headers require ownership transfer (non-Copy type) + let headers: Vec> = buffers + .into_par_iter() + .flat_map(|buf| buf.headers) + .collect(); - combined + FlatBuffers { + boards, + castling, + en_passant, + halfmove_clock, + turn, + from_squares, + to_squares, + promotions, + clocks, + evals, + move_counts, + position_counts, + is_checkmate, + is_stalemate, + is_insufficient, + legal_move_count, + valid, + headers, + } } /// Number of games in this buffer. From 293c08e16f0ea698ad1f139d808a4b47cc3c9e0d Mon Sep 17 00:00:00 2001 From: vladkvit Date: Fri, 6 Feb 2026 08:58:02 -0500 Subject: [PATCH 22/31] Revert "Parallel vec merge" This reverts commit 62fb05bd659a10de72d5e076da2212fe2e126d6c. --- src/flat_visitor.rs | 119 +++++++++++++++++++++----------------------- 1 file changed, 56 insertions(+), 63 deletions(-) diff --git a/src/flat_visitor.rs b/src/flat_visitor.rs index 7f177bb..0d14027 100644 --- a/src/flat_visitor.rs +++ b/src/flat_visitor.rs @@ -9,7 +9,6 @@ use crate::board_serialization::{ }; use crate::comment_parsing::{parse_comments, CommentContent, ParsedTag}; use pgn_reader::{Outcome, RawComment, RawTag, SanPlus, Skip, Visitor}; -use rayon::prelude::*; use shakmaty::{fen::Fen, uci::UciMove, CastlingMode, Chess, Color, Position}; use std::collections::HashMap; use std::ops::ControlFlow; @@ -115,10 +114,10 @@ impl FlatBuffers { self.headers.extend(other.headers); } - /// Merge multiple FlatBuffers efficiently using rayon's parallel collect. + /// Merge multiple FlatBuffers efficiently by pre-allocating total capacity. /// - /// This uses rayon's `into_par_iter().flat_map().collect()` pattern which - /// safely handles parallel collection without unsafe code. + /// This avoids the repeated reallocations that occur when calling `merge` + /// in a loop starting from an empty buffer. pub fn merge_all(buffers: Vec) -> FlatBuffers { if buffers.is_empty() { return FlatBuffers::default(); @@ -127,69 +126,63 @@ impl FlatBuffers { return buffers.into_iter().next().unwrap(); } - // Use rayon's parallel flat_map and collect for all fields. - // This is safe and well-tested, using rayon's internal mechanisms - // for efficient parallel collection. - - // For Copy types, we can use flat_map with copied() - macro_rules! par_flatten { - ($field:ident) => { - buffers - .par_iter() - .flat_map(|buf| buf.$field.par_iter().copied()) - .collect() - }; - } - - // Board state arrays - let boards: Vec = par_flatten!(boards); - let castling: Vec = par_flatten!(castling); - let en_passant: Vec = par_flatten!(en_passant); - let halfmove_clock: Vec = par_flatten!(halfmove_clock); - let turn: Vec = par_flatten!(turn); + // Calculate total sizes + let total_games: usize = buffers.iter().map(|b| b.headers.len()).sum(); + let total_positions: usize = buffers.iter().map(|b| b.boards.len() / 64).sum(); + let total_moves: usize = buffers.iter().map(|b| b.from_squares.len()).sum(); - // Move arrays - let from_squares: Vec = par_flatten!(from_squares); - let to_squares: Vec = par_flatten!(to_squares); - let promotions: Vec = par_flatten!(promotions); - let clocks: Vec = par_flatten!(clocks); - let evals: Vec = par_flatten!(evals); + // Pre-allocate with exact capacity + let mut combined = FlatBuffers { + // Board state arrays + boards: Vec::with_capacity(total_positions * 64), + castling: Vec::with_capacity(total_positions * 4), + en_passant: Vec::with_capacity(total_positions), + halfmove_clock: Vec::with_capacity(total_positions), + turn: Vec::with_capacity(total_positions), - // Per-game data - let move_counts: Vec = par_flatten!(move_counts); - let position_counts: Vec = par_flatten!(position_counts); - let is_checkmate: Vec = par_flatten!(is_checkmate); - let is_stalemate: Vec = par_flatten!(is_stalemate); - let is_insufficient: Vec = par_flatten!(is_insufficient); - let legal_move_count: Vec = par_flatten!(legal_move_count); - let valid: Vec = par_flatten!(valid); - - // Headers require ownership transfer (non-Copy type) - let headers: Vec> = buffers - .into_par_iter() - .flat_map(|buf| buf.headers) - .collect(); + // Move arrays + from_squares: Vec::with_capacity(total_moves), + to_squares: Vec::with_capacity(total_moves), + promotions: Vec::with_capacity(total_moves), + clocks: Vec::with_capacity(total_moves), + evals: Vec::with_capacity(total_moves), - FlatBuffers { - boards, - castling, - en_passant, - halfmove_clock, - turn, - from_squares, - to_squares, - promotions, - clocks, - evals, - move_counts, - position_counts, - is_checkmate, - is_stalemate, - is_insufficient, - legal_move_count, - valid, - headers, + // Per-game data + move_counts: Vec::with_capacity(total_games), + position_counts: Vec::with_capacity(total_games), + is_checkmate: Vec::with_capacity(total_games), + is_stalemate: Vec::with_capacity(total_games), + is_insufficient: Vec::with_capacity(total_games * 2), + legal_move_count: Vec::with_capacity(total_games), + valid: Vec::with_capacity(total_games), + headers: Vec::with_capacity(total_games), + }; + + // Now merge - no reallocations will occur + for buf in buffers { + combined.boards.extend(buf.boards); + combined.castling.extend(buf.castling); + combined.en_passant.extend(buf.en_passant); + combined.halfmove_clock.extend(buf.halfmove_clock); + combined.turn.extend(buf.turn); + + combined.from_squares.extend(buf.from_squares); + combined.to_squares.extend(buf.to_squares); + combined.promotions.extend(buf.promotions); + combined.clocks.extend(buf.clocks); + combined.evals.extend(buf.evals); + + combined.move_counts.extend(buf.move_counts); + combined.position_counts.extend(buf.position_counts); + combined.is_checkmate.extend(buf.is_checkmate); + combined.is_stalemate.extend(buf.is_stalemate); + combined.is_insufficient.extend(buf.is_insufficient); + combined.legal_move_count.extend(buf.legal_move_count); + combined.valid.extend(buf.valid); + combined.headers.extend(buf.headers); } + + combined } /// Number of games in this buffer. From 5c7898a95c4c750078a454d91580ebe382af5347 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Fri, 6 Feb 2026 09:00:45 -0500 Subject: [PATCH 23/31] Tweak benchmark comments --- benches/parquet_bench.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/benches/parquet_bench.rs b/benches/parquet_bench.rs index f3bc8e5..2fbfbbf 100644 --- a/benches/parquet_bench.rs +++ b/benches/parquet_bench.rs @@ -1,10 +1,7 @@ //! Benchmark for PGN parsing APIs, designed to mirror the Python workflow. //! -//! This benchmark emulates the call graph of: -//! - `parse_game_moves_arrow_chunked_array()` (Arrow API → Vec) -//! - `parse_games_flat()` (Flat API → FlatBuffers with NumPy-like arrays) -//! -//! Both use zero-copy &str slices from Arrow StringArrays, matching Python's behavior. +//! `cargo bench --bench parquet_bench` +//! `samply record --rate 10000 cargo bench --bench parquet_bench` use arrow::array::{Array, StringArray}; use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; From b2f9380d3c1e51c5348d581d66221438dd0445b1 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Fri, 6 Feb 2026 12:02:34 -0500 Subject: [PATCH 24/31] Rework to not require a vector merge --- benches/parquet_bench.rs | 31 +- rust_pgn_reader_python_binding.pyi | 156 ++++------ src/bench_parse_games_flat.py | 55 ++-- src/example_parse_games_flat.py | 35 ++- src/flat_visitor.rs | 122 +------- src/lib.rs | 119 +++++++- src/python_bindings.rs | 475 ++++++++++++++++++++--------- src/test.py | 88 +++--- 8 files changed, 612 insertions(+), 469 deletions(-) diff --git a/benches/parquet_bench.rs b/benches/parquet_bench.rs index 2fbfbbf..3e5347e 100644 --- a/benches/parquet_bench.rs +++ b/benches/parquet_bench.rs @@ -5,13 +5,13 @@ use arrow::array::{Array, StringArray}; use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; -use rayon::ThreadPoolBuilder; use rayon::prelude::*; +use rayon::ThreadPoolBuilder; use std::fs::File; use std::path::Path; use std::time::Instant; -use rust_pgn_reader_python_binding::{FlatBuffers, parse_game_to_flat, parse_single_game_native}; +use rust_pgn_reader_python_binding::{parse_game_to_flat, parse_single_game_native, FlatBuffers}; const FILE_PATH: &str = "2013-07-train-00000-of-00001.parquet"; @@ -129,7 +129,8 @@ pub fn bench_arrow_api(store_board_states: bool) { /// 1. Read parquet to Arrow arrays /// 2. Extract &str slices from StringArray /// 3. Parse in parallel with explicit chunking (par_chunks) → fixed number of FlatBuffers -/// 4. Merge all FlatBuffers into one +/// +/// No merge step — the chunked architecture keeps per-thread buffers as-is. pub fn bench_flat_api() { // Step 1: Read parquet to Arrow StringArrays let arrays = read_parquet_to_string_arrays(FILE_PATH); @@ -179,30 +180,30 @@ pub fn bench_flat_api() { let duration_parallel = start.elapsed(); println!("Parallel parsing time: {:?}", duration_parallel); - println!("Created {} FlatBuffers to merge", chunk_results.len()); + println!( + "Created {} FlatBuffers chunks (no merge needed)", + chunk_results.len() + ); - // Step 5: Merge all chunk buffers (mirrors parse_games_flat) - let combined_buffers = FlatBuffers::merge_all(chunk_results); + // Compute totals from chunks + let total_games: usize = chunk_results.iter().map(|b| b.num_games()).sum(); + let total_positions: usize = chunk_results.iter().map(|b| b.total_positions()).sum(); - let duration_with_merge = start.elapsed(); - println!("Total time (parsing + merge): {:?}", duration_with_merge); + let duration_total = start.elapsed(); + println!("Total time (parsing, no merge): {:?}", duration_total); println!( "Parsed {} games, {} total positions.", - combined_buffers.num_games(), - combined_buffers.total_positions() + total_games, total_positions ); // Measure cleanup time for fair comparison with Arrow API let drop_start = Instant::now(); - drop(combined_buffers); + drop(chunk_results); let drop_duration = drop_start.elapsed(); let total_duration = start.elapsed(); println!("Cleanup time (drop): {:?}", drop_duration); - println!( - "Total time (parsing + merge + cleanup): {:?}", - total_duration - ); + println!("Total time (parsing + cleanup): {:?}", total_duration); } fn main() { diff --git a/rust_pgn_reader_python_binding.pyi b/rust_pgn_reader_python_binding.pyi index d75d603..0441a0d 100644 --- a/rust_pgn_reader_python_binding.pyi +++ b/rust_pgn_reader_python_binding.pyi @@ -178,126 +178,73 @@ class ParsedGamesIter: def __iter__(self) -> "ParsedGamesIter": ... def __next__(self) -> PyGameView: ... -class ParsedGames: - """Flat array container for parsed chess games, optimized for ML training. - - Indexing: - - N_games: Number of games - - N_moves: Total moves across all games - - N_positions: Total board positions recorded +class PyChunkView: + """View into a single chunk's raw numpy arrays. - Board layout: - Boards use square indexing: a1=0, b1=1, ..., h8=63 - Piece encoding: 0=empty, 1-6=white PNBRQK, 7-12=black pnbrqk + Access via ``parsed_games.chunks[i]``. Each chunk corresponds to one + parsing thread's output. Use this for advanced access patterns like + manual concatenation or custom batching. """ - # === Board state arrays (N_positions) === - + @property + def num_games(self) -> int: ... + @property + def num_moves(self) -> int: ... + @property + def num_positions(self) -> int: ... @property def boards(self) -> NDArray[np.uint8]: """Board positions, shape (N_positions, 8, 8), dtype uint8.""" ... - @property def castling(self) -> NDArray[np.bool_]: """Castling rights [K,Q,k,q], shape (N_positions, 4), dtype bool.""" ... - @property - def en_passant(self) -> NDArray[np.int8]: - """En passant file (-1 if none), shape (N_positions,), dtype int8.""" - ... - + def en_passant(self) -> NDArray[np.int8]: ... @property - def halfmove_clock(self) -> NDArray[np.uint8]: - """Halfmove clock, shape (N_positions,), dtype uint8.""" - ... - + def halfmove_clock(self) -> NDArray[np.uint8]: ... @property - def turn(self) -> NDArray[np.bool_]: - """Side to move (True=white), shape (N_positions,), dtype bool.""" - ... - - # === Move arrays (N_moves) === - + def turn(self) -> NDArray[np.bool_]: ... @property - def from_squares(self) -> NDArray[np.uint8]: - """From squares, shape (N_moves,), dtype uint8.""" - ... - + def from_squares(self) -> NDArray[np.uint8]: ... @property - def to_squares(self) -> NDArray[np.uint8]: - """To squares, shape (N_moves,), dtype uint8.""" - ... - + def to_squares(self) -> NDArray[np.uint8]: ... @property - def promotions(self) -> NDArray[np.int8]: - """Promotions (-1=none, 2=N, 3=B, 4=R, 5=Q), shape (N_moves,), dtype int8.""" - ... - + def promotions(self) -> NDArray[np.int8]: ... @property - def clocks(self) -> NDArray[np.float32]: - """Clock times in seconds (NaN if missing), shape (N_moves,), dtype float32.""" - ... - + def clocks(self) -> NDArray[np.float32]: ... @property - def evals(self) -> NDArray[np.float32]: - """Engine evals (NaN if missing), shape (N_moves,), dtype float32.""" - ... - - # === Offsets === - + def evals(self) -> NDArray[np.float32]: ... @property - def move_offsets(self) -> NDArray[np.uint32]: - """Move offsets for CSR-style indexing, shape (N_games + 1,), dtype uint32. - - Game i's moves: move_offsets[i]..move_offsets[i+1] - """ - ... - + def move_offsets(self) -> NDArray[np.uint32]: ... @property - def position_offsets(self) -> NDArray[np.uint32]: - """Position offsets for CSR-style indexing, shape (N_games + 1,), dtype uint32. - - Game i's positions: position_offsets[i]..position_offsets[i+1] - """ - ... - - # === Final position status (N_games) === - + def position_offsets(self) -> NDArray[np.uint32]: ... @property - def is_checkmate(self) -> NDArray[np.bool_]: - """Final position is checkmate, shape (N_games,), dtype bool.""" - ... - + def is_checkmate(self) -> NDArray[np.bool_]: ... @property - def is_stalemate(self) -> NDArray[np.bool_]: - """Final position is stalemate, shape (N_games,), dtype bool.""" - ... - + def is_stalemate(self) -> NDArray[np.bool_]: ... @property - def is_insufficient(self) -> NDArray[np.bool_]: - """Insufficient material (white, black), shape (N_games, 2), dtype bool.""" - ... - + def is_insufficient(self) -> NDArray[np.bool_]: ... @property - def legal_move_count(self) -> NDArray[np.uint16]: - """Legal move count in final position, shape (N_games,), dtype uint16.""" - ... - - # === Parse status (N_games) === - + def legal_move_count(self) -> NDArray[np.uint16]: ... @property - def valid(self) -> NDArray[np.bool_]: - """Whether game parsed successfully, shape (N_games,), dtype bool.""" - ... + def valid(self) -> NDArray[np.bool_]: ... + @property + def headers(self) -> List[Dict[str, str]]: ... + def __repr__(self) -> str: ... - # === Raw headers (N_games) === +class ParsedGames: + """Chunked container for parsed chess games, optimized for ML training. - @property - def headers(self) -> List[Dict[str, str]]: - """Raw PGN headers as list of dicts.""" - ... + Internally stores data in multiple chunks (one per parsing thread) to + avoid the cost of merging. Per-game access is O(log(num_chunks)) via + binary search on precomputed boundaries. + + Board layout: + Boards use square indexing: a1=0, b1=1, ..., h8=63 + Piece encoding: 0=empty, 1-6=white PNBRQK, 7-12=black pnbrqk + """ # === Computed properties === @@ -316,6 +263,22 @@ class ParsedGames: """Total number of board positions recorded.""" ... + @property + def num_chunks(self) -> int: + """Number of internal chunks.""" + ... + + # === Escape hatch: raw chunk access === + + @property + def chunks(self) -> List[PyChunkView]: + """Access raw per-chunk data. + + Each chunk corresponds to one parsing thread's output. Use this + for advanced access patterns like manual concatenation. + """ + ... + # === Sequence protocol === def __len__(self) -> int: @@ -339,12 +302,12 @@ class ParsedGames: # === Mapping utilities === def position_to_game(self, position_indices: npt.ArrayLike) -> NDArray[np.int64]: - """Map position indices to game indices. + """Map global position indices to game indices. Useful after shuffling/sampling positions to look up game metadata. Args: - position_indices: Array of indices into boards array. + position_indices: Array of indices into the global position space. Accepts any integer dtype; int64 is optimal (avoids conversion). Returns: @@ -353,10 +316,10 @@ class ParsedGames: ... def move_to_game(self, move_indices: npt.ArrayLike) -> NDArray[np.int64]: - """Map move indices to game indices. + """Map global move indices to game indices. Args: - move_indices: Array of indices into from_squares, to_squares, etc. + move_indices: Array of indices into the global move space. Accepts any integer dtype; int64 is optimal (avoids conversion). Returns: @@ -376,6 +339,7 @@ def parse_game_moves_arrow_chunked_array( def parse_games_flat( pgn_chunked_array: pyarrow.ChunkedArray, num_threads: Optional[int] = None, + chunk_multiplier: Optional[int] = None, ) -> ParsedGames: """Parse chess games from a PyArrow ChunkedArray into flat NumPy arrays. diff --git a/src/bench_parse_games_flat.py b/src/bench_parse_games_flat.py index e92c739..cf72ac7 100644 --- a/src/bench_parse_games_flat.py +++ b/src/bench_parse_games_flat.py @@ -114,7 +114,7 @@ def benchmark_parse_games_flat( "games_per_second": result.num_games / elapsed, "moves_per_second": result.num_moves / elapsed, "positions_per_second": result.num_positions / elapsed, - "valid_games": int(result.valid.sum()), + "valid_games": int(sum(chunk.valid.sum() for chunk in result.chunks)), "result": result, } @@ -159,18 +159,22 @@ def benchmark_data_access_flat(result) -> dict: """Benchmark data access patterns for parse_games_flat result.""" start = time.perf_counter() - # Simulate ML data loading: access all boards - _ = result.boards.sum() + # Simulate ML data loading: access all boards via chunks + for chunk in result.chunks: + _ = chunk.boards.sum() - # Access moves - _ = result.from_squares.sum() - _ = result.to_squares.sum() + # Access moves via chunks + for chunk in result.chunks: + _ = chunk.from_squares.sum() + _ = chunk.to_squares.sum() - # Random position access - indices = np.random.randint(0, result.num_positions, size=1000, dtype=np.int64) - _ = result.boards[indices] + # Per-game access pattern (iterate and access boards) + for i in range(min(1000, result.num_games)): + game = result[i] + _ = game.boards - # Position-to-game mapping + # Position-to-game mapping (still works globally) + indices = np.random.randint(0, result.num_positions, size=1000, dtype=np.int64) _ = result.position_to_game(indices) elapsed = time.perf_counter() - start @@ -323,24 +327,27 @@ def main(): print("=" * 60) flat_result = flat_results["result"] - flat_bytes = ( - flat_result.boards.nbytes - + flat_result.castling.nbytes - + flat_result.en_passant.nbytes - + flat_result.halfmove_clock.nbytes - + flat_result.turn.nbytes - + flat_result.from_squares.nbytes - + flat_result.to_squares.nbytes - + flat_result.promotions.nbytes - + flat_result.clocks.nbytes - + flat_result.evals.nbytes - + flat_result.move_offsets.nbytes - + flat_result.position_offsets.nbytes - ) + flat_bytes = 0 + for chunk in flat_result.chunks: + flat_bytes += ( + chunk.boards.nbytes + + chunk.castling.nbytes + + chunk.en_passant.nbytes + + chunk.halfmove_clock.nbytes + + chunk.turn.nbytes + + chunk.from_squares.nbytes + + chunk.to_squares.nbytes + + chunk.promotions.nbytes + + chunk.clocks.nbytes + + chunk.evals.nbytes + + chunk.move_offsets.nbytes + + chunk.position_offsets.nbytes + ) print(f"\nFlat arrays total: {flat_bytes / 1024 / 1024:.2f} MB") print(f"Bytes per position: {flat_bytes / flat_result.num_positions:.1f}") print(f"Bytes per move: {flat_bytes / flat_result.num_moves:.1f}") + print(f"Number of chunks: {flat_result.num_chunks}") print("\n" + "=" * 60) print("Benchmark complete!") diff --git a/src/example_parse_games_flat.py b/src/example_parse_games_flat.py index 61c5b36..f0848b1 100644 --- a/src/example_parse_games_flat.py +++ b/src/example_parse_games_flat.py @@ -66,22 +66,20 @@ def main(): print(f"Number of games: {result.num_games}") print(f"Total moves: {result.num_moves}") print(f"Total positions: {result.num_positions}") - print(f"Valid games: {result.valid.sum()} / {result.num_games}") + print(f"Number of chunks: {result.num_chunks}") - # === Array Shapes === - print(f"\n--- Array Shapes ---") - print(f"boards: {result.boards.shape} ({result.boards.dtype})") - print(f"castling: {result.castling.shape} ({result.castling.dtype})") - print(f"en_passant: {result.en_passant.shape} ({result.en_passant.dtype})") - print( - f"from_squares: {result.from_squares.shape} ({result.from_squares.dtype})" - ) - print(f"to_squares: {result.to_squares.shape} ({result.to_squares.dtype})") - print(f"promotions: {result.promotions.shape} ({result.promotions.dtype})") - print(f"clocks: {result.clocks.shape} ({result.clocks.dtype})") - print(f"evals: {result.evals.shape} ({result.evals.dtype})") - print(f"move_offsets: {result.move_offsets.shape}") - print(f"position_offsets: {result.position_offsets.shape}") + # === Chunk Details (escape hatch for raw array access) === + print(f"\n--- Chunk Details ---") + for i, chunk in enumerate(result.chunks): + print( + f" Chunk {i}: {chunk.num_games} games, " + f"{chunk.num_moves} moves, " + f"{chunk.num_positions} positions" + ) + print(f" boards: {chunk.boards.shape} ({chunk.boards.dtype})") + print( + f" from_squares: {chunk.from_squares.shape} ({chunk.from_squares.dtype})" + ) # === Iterate Over Games === print(f"\n--- Game Details ---") @@ -106,9 +104,10 @@ def main(): # === Direct Array Access for ML === print(f"\n--- ML-Ready Data Access ---") - # Get all board positions as a single tensor - all_boards = result.boards # Shape: (N_positions, 8, 8) - print(f"All boards tensor: {all_boards.shape}") + # Access boards via chunks (no single merged array) + # To concatenate all boards: np.concatenate([c.boards for c in result.chunks]) + chunk0_boards = result.chunks[0].boards + print(f"Chunk 0 boards: {chunk0_boards.shape}") # Get initial position of first game game0 = result[0] diff --git a/src/flat_visitor.rs b/src/flat_visitor.rs index 0d14027..00f7026 100644 --- a/src/flat_visitor.rs +++ b/src/flat_visitor.rs @@ -83,108 +83,6 @@ impl FlatBuffers { } } - /// Merge another FlatBuffers into this one. - /// Used to combine thread-local buffers after parallel parsing. - /// - /// Note: For merging multiple buffers, prefer `merge_all` which pre-allocates - /// to avoid repeated reallocations. - pub fn merge(&mut self, other: FlatBuffers) { - // Board state arrays - self.boards.extend(other.boards); - self.castling.extend(other.castling); - self.en_passant.extend(other.en_passant); - self.halfmove_clock.extend(other.halfmove_clock); - self.turn.extend(other.turn); - - // Move arrays - self.from_squares.extend(other.from_squares); - self.to_squares.extend(other.to_squares); - self.promotions.extend(other.promotions); - self.clocks.extend(other.clocks); - self.evals.extend(other.evals); - - // Per-game data - self.move_counts.extend(other.move_counts); - self.position_counts.extend(other.position_counts); - self.is_checkmate.extend(other.is_checkmate); - self.is_stalemate.extend(other.is_stalemate); - self.is_insufficient.extend(other.is_insufficient); - self.legal_move_count.extend(other.legal_move_count); - self.valid.extend(other.valid); - self.headers.extend(other.headers); - } - - /// Merge multiple FlatBuffers efficiently by pre-allocating total capacity. - /// - /// This avoids the repeated reallocations that occur when calling `merge` - /// in a loop starting from an empty buffer. - pub fn merge_all(buffers: Vec) -> FlatBuffers { - if buffers.is_empty() { - return FlatBuffers::default(); - } - if buffers.len() == 1 { - return buffers.into_iter().next().unwrap(); - } - - // Calculate total sizes - let total_games: usize = buffers.iter().map(|b| b.headers.len()).sum(); - let total_positions: usize = buffers.iter().map(|b| b.boards.len() / 64).sum(); - let total_moves: usize = buffers.iter().map(|b| b.from_squares.len()).sum(); - - // Pre-allocate with exact capacity - let mut combined = FlatBuffers { - // Board state arrays - boards: Vec::with_capacity(total_positions * 64), - castling: Vec::with_capacity(total_positions * 4), - en_passant: Vec::with_capacity(total_positions), - halfmove_clock: Vec::with_capacity(total_positions), - turn: Vec::with_capacity(total_positions), - - // Move arrays - from_squares: Vec::with_capacity(total_moves), - to_squares: Vec::with_capacity(total_moves), - promotions: Vec::with_capacity(total_moves), - clocks: Vec::with_capacity(total_moves), - evals: Vec::with_capacity(total_moves), - - // Per-game data - move_counts: Vec::with_capacity(total_games), - position_counts: Vec::with_capacity(total_games), - is_checkmate: Vec::with_capacity(total_games), - is_stalemate: Vec::with_capacity(total_games), - is_insufficient: Vec::with_capacity(total_games * 2), - legal_move_count: Vec::with_capacity(total_games), - valid: Vec::with_capacity(total_games), - headers: Vec::with_capacity(total_games), - }; - - // Now merge - no reallocations will occur - for buf in buffers { - combined.boards.extend(buf.boards); - combined.castling.extend(buf.castling); - combined.en_passant.extend(buf.en_passant); - combined.halfmove_clock.extend(buf.halfmove_clock); - combined.turn.extend(buf.turn); - - combined.from_squares.extend(buf.from_squares); - combined.to_squares.extend(buf.to_squares); - combined.promotions.extend(buf.promotions); - combined.clocks.extend(buf.clocks); - combined.evals.extend(buf.evals); - - combined.move_counts.extend(buf.move_counts); - combined.position_counts.extend(buf.position_counts); - combined.is_checkmate.extend(buf.is_checkmate); - combined.is_stalemate.extend(buf.is_stalemate); - combined.is_insufficient.extend(buf.is_insufficient); - combined.legal_move_count.extend(buf.legal_move_count); - combined.valid.extend(buf.valid); - combined.headers.extend(buf.headers); - } - - combined - } - /// Number of games in this buffer. pub fn num_games(&self) -> usize { self.headers.len() @@ -550,7 +448,7 @@ mod tests { } #[test] - fn test_merge_buffers() { + fn test_multiple_games_in_one_buffer() { let pgn1 = r#"[Event "Game1"] [Result "1-0"] @@ -561,20 +459,14 @@ mod tests { 1. d4 d5 2. c4 0-1"#; - let mut buffers1 = FlatBuffers::with_capacity(1, 70); - let mut buffers2 = FlatBuffers::with_capacity(1, 70); - - parse_game_to_flat(pgn1, &mut buffers1).unwrap(); - parse_game_to_flat(pgn2, &mut buffers2).unwrap(); - - assert_eq!(buffers1.num_games(), 1); - assert_eq!(buffers2.num_games(), 1); + let mut buffers = FlatBuffers::with_capacity(2, 70); - buffers1.merge(buffers2); + parse_game_to_flat(pgn1, &mut buffers).unwrap(); + parse_game_to_flat(pgn2, &mut buffers).unwrap(); - assert_eq!(buffers1.num_games(), 2); - assert_eq!(buffers1.total_moves(), 5); // 2 + 3 moves - assert_eq!(buffers1.move_counts, vec![2, 3]); + assert_eq!(buffers.num_games(), 2); + assert_eq!(buffers.total_moves(), 5); // 2 + 3 moves + assert_eq!(buffers.move_counts, vec![2, 3]); } #[test] diff --git a/src/lib.rs b/src/lib.rs index 05f7dd6..48961a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,17 +15,20 @@ mod python_bindings; mod visitor; pub use flat_visitor::{parse_game_to_flat, FlatBuffers}; -use python_bindings::{ParsedGames, ParsedGamesIter, PositionStatus, PyGameView, PyUciMove}; +use python_bindings::{ + ChunkData, ParsedGames, ParsedGamesIter, PositionStatus, PyChunkView, PyGameView, PyUciMove, +}; pub use visitor::MoveExtractor; -/// Parse games from Arrow chunked array into flat NumPy arrays. +/// Parse games from Arrow chunked array into a chunked ParsedGames container. /// /// This implementation uses explicit chunking with a fixed number of chunks /// (num_chunks = num_threads * chunk_multiplier) to avoid the allocation storm /// caused by Rayon's dynamic work-stealing with fold_with. /// -/// Each chunk gets exactly one FlatBuffers instance, drastically reducing -/// the number of allocations and making merge_all much faster. +/// Each chunk gets exactly one FlatBuffers instance. Instead of merging all +/// chunks into a single buffer (which was memory-bandwidth-bound), we keep +/// the per-thread buffers and provide virtual indexing across them. #[pyfunction] #[pyo3(signature = (pgn_chunked_array, num_threads=None, chunk_multiplier=None))] fn parse_games_flat( @@ -37,7 +40,7 @@ fn parse_games_flat( let num_threads = num_threads.unwrap_or_else(num_cpus::get); // Default multiplier of 1 means exactly num_threads chunks (one per thread). // Higher values (e.g., 4) create more chunks for better load balancing - // at the cost of more buffers to merge. + // at the cost of slightly more complex indexing. let chunk_multiplier = chunk_multiplier.unwrap_or(1); // Extract PGN strings from Arrow chunks @@ -70,7 +73,8 @@ fn parse_games_flat( let n_games = pgn_str_slices.len(); if n_games == 0 { - return flat_buffers_to_parsed_games(py, FlatBuffers::default()); + let empty_chunk = flat_buffers_to_chunk_data(py, FlatBuffers::default())?; + return build_parsed_games(py, vec![empty_chunk]); } // Build thread pool @@ -110,23 +114,25 @@ fn parse_games_flat( .collect() }); - // Merge all chunk buffers with pre-allocation (avoids repeated reallocations) - let combined_buffers = FlatBuffers::merge_all(chunk_results); + // Convert each FlatBuffers to ChunkData (numpy arrays) — no merge needed + let chunk_data_vec: Vec = chunk_results + .into_iter() + .map(|buf| flat_buffers_to_chunk_data(py, buf)) + .collect::>>()?; - // Convert FlatBuffers to ParsedGames with NumPy arrays - flat_buffers_to_parsed_games(py, combined_buffers) + build_parsed_games(py, chunk_data_vec) } -/// Convert FlatBuffers to ParsedGames with NumPy arrays. -fn flat_buffers_to_parsed_games(py: Python<'_>, buffers: FlatBuffers) -> PyResult { +/// Convert a single FlatBuffers into a ChunkData with NumPy arrays. +fn flat_buffers_to_chunk_data(py: Python<'_>, buffers: FlatBuffers) -> PyResult { let n_games = buffers.num_games(); let total_positions = buffers.total_positions(); + let total_moves = buffers.total_moves(); - // Compute offsets + // Compute local CSR offsets let move_offsets_vec = buffers.compute_move_offsets(); let position_offsets_vec = buffers.compute_position_offsets(); - // Convert to NumPy arrays // Boards: reshape from flat to (N_positions, 8, 8) let boards_array = PyArray1::from_vec(py, buffers.boards); let boards_reshaped = boards_array @@ -156,7 +162,6 @@ fn flat_buffers_to_parsed_games(py: Python<'_>, buffers: FlatBuffers) -> PyResul let is_checkmate_array = PyArray1::from_vec(py, buffers.is_checkmate); let is_stalemate_array = PyArray1::from_vec(py, buffers.is_stalemate); - // is_insufficient: reshape to (N_games, 2) let is_insufficient_array = PyArray1::from_vec(py, buffers.is_insufficient); let is_insufficient_reshaped = if n_games > 0 { is_insufficient_array @@ -171,7 +176,7 @@ fn flat_buffers_to_parsed_games(py: Python<'_>, buffers: FlatBuffers) -> PyResul let legal_move_count_array = PyArray1::from_vec(py, buffers.legal_move_count); let valid_array = PyArray1::from_vec(py, buffers.valid); - Ok(ParsedGames { + Ok(ChunkData { boards: boards_reshaped.unbind().into_any(), castling: castling_reshaped.unbind().into_any(), en_passant: en_passant_array.unbind().into_any(), @@ -190,6 +195,87 @@ fn flat_buffers_to_parsed_games(py: Python<'_>, buffers: FlatBuffers) -> PyResul legal_move_count: legal_move_count_array.unbind().into_any(), valid: valid_array.unbind().into_any(), headers: buffers.headers, + num_games: n_games, + num_moves: total_moves, + num_positions: total_positions, + }) +} + +/// Build a ParsedGames from a Vec of ChunkData. +/// +/// Computes prefix-sum boundary arrays and global CSR offsets for +/// position_to_game / move_to_game. +fn build_parsed_games(py: Python<'_>, chunks: Vec) -> PyResult { + let mut total_games: usize = 0; + let mut total_moves: usize = 0; + let mut total_positions: usize = 0; + + // Build boundary arrays (prefix sums) + let mut game_boundaries = Vec::with_capacity(chunks.len() + 1); + let mut move_boundaries = Vec::with_capacity(chunks.len() + 1); + let mut position_boundaries = Vec::with_capacity(chunks.len() + 1); + + game_boundaries.push(0); + move_boundaries.push(0); + position_boundaries.push(0); + + for chunk in &chunks { + total_games += chunk.num_games; + total_moves += chunk.num_moves; + total_positions += chunk.num_positions; + game_boundaries.push(total_games); + move_boundaries.push(total_moves); + position_boundaries.push(total_positions); + } + + // Build global CSR offsets for position_to_game / move_to_game. + // These are the per-chunk local offsets shifted by the chunk's base offset, + // concatenated into a single array. Length = total_games + 1. + let mut global_move_offsets_vec: Vec = Vec::with_capacity(total_games + 1); + let mut global_position_offsets_vec: Vec = Vec::with_capacity(total_games + 1); + + for (chunk_idx, chunk) in chunks.iter().enumerate() { + let move_base = move_boundaries[chunk_idx] as u32; + let pos_base = position_boundaries[chunk_idx] as u32; + + // Read the chunk's local offsets + let local_move_offsets = chunk.move_offsets.bind(py); + let local_move_offsets: &Bound<'_, PyArray1> = local_move_offsets.cast()?; + let local_move_ro = local_move_offsets.readonly(); + let local_move_slice = local_move_ro.as_slice()?; + + let local_pos_offsets = chunk.position_offsets.bind(py); + let local_pos_offsets: &Bound<'_, PyArray1> = local_pos_offsets.cast()?; + let local_pos_ro = local_pos_offsets.readonly(); + let local_pos_slice = local_pos_ro.as_slice()?; + + // Append all but the last offset (which is the total for this chunk) + // The last chunk's final offset will be added after the loop. + for &offset in &local_move_slice[..local_move_slice.len() - 1] { + global_move_offsets_vec.push(move_base + offset); + } + for &offset in &local_pos_slice[..local_pos_slice.len() - 1] { + global_position_offsets_vec.push(pos_base + offset); + } + } + + // Final sentinel value + global_move_offsets_vec.push(total_moves as u32); + global_position_offsets_vec.push(total_positions as u32); + + let global_move_offsets = PyArray1::from_vec(py, global_move_offsets_vec); + let global_position_offsets = PyArray1::from_vec(py, global_position_offsets_vec); + + Ok(ParsedGames { + chunks, + game_boundaries, + move_boundaries, + position_boundaries, + total_games, + total_moves, + total_positions, + global_move_offsets: global_move_offsets.unbind().into_any(), + global_position_offsets: global_position_offsets.unbind().into_any(), }) } @@ -332,6 +418,7 @@ fn rust_pgn_reader_python_binding(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; Ok(()) } diff --git a/src/python_bindings.rs b/src/python_bindings.rs index 82feb14..9840e68 100644 --- a/src/python_bindings.rs +++ b/src/python_bindings.rs @@ -1,6 +1,6 @@ use numpy::{PyArray1, PyArray2, PyArrayMethods}; use pyo3::prelude::*; -use pyo3::types::{IntoPyDict, PySlice}; +use pyo3::types::{IntoPyDict, PyList, PySlice}; use shakmaty::{Role, Square}; use std::collections::HashMap; @@ -97,88 +97,84 @@ pub struct PositionStatus { pub turn: bool, } -/// Flat array container for parsed chess games, optimized for ML training. -#[pyclass] -pub struct ParsedGames { - // === Board state arrays (N_positions) === - /// Board positions, shape (N_positions, 8, 8), dtype uint8 - #[pyo3(get)] - pub boards: Py, - - /// Castling rights [K,Q,k,q], shape (N_positions, 4), dtype bool - #[pyo3(get)] - pub castling: Py, - - /// En passant file (-1 if none), shape (N_positions,), dtype int8 - #[pyo3(get)] - pub en_passant: Py, - - /// Halfmove clock, shape (N_positions,), dtype uint8 - #[pyo3(get)] - pub halfmove_clock: Py, - - /// Side to move (true=white), shape (N_positions,), dtype bool - #[pyo3(get)] - pub turn: Py, - - // === Move arrays (N_moves) === - /// From squares, shape (N_moves,), dtype uint8 - #[pyo3(get)] - pub from_squares: Py, - - /// To squares, shape (N_moves,), dtype uint8 - #[pyo3(get)] - pub to_squares: Py, - - /// Promotions (-1=none, 2=N, 3=B, 4=R, 5=Q), shape (N_moves,), dtype int8 - #[pyo3(get)] - pub promotions: Py, - - /// Clock times in seconds (NaN if missing), shape (N_moves,), dtype float32 - #[pyo3(get)] - pub clocks: Py, - - /// Engine evals (NaN if missing), shape (N_moves,), dtype float32 - #[pyo3(get)] - pub evals: Py, - - // === Offsets === - /// Move offsets for CSR-style indexing, shape (N_games + 1,), dtype uint32 - /// Game i's moves: move_offsets[i]..move_offsets[i+1] - #[pyo3(get)] - pub move_offsets: Py, - - /// Position offsets for CSR-style indexing, shape (N_games + 1,), dtype uint32 - /// Game i's positions: position_offsets[i]..position_offsets[i+1] - #[pyo3(get)] - pub position_offsets: Py, - - // === Final position status (N_games) === - /// Final position is checkmate, shape (N_games,), dtype bool - #[pyo3(get)] - pub is_checkmate: Py, - - /// Final position is stalemate, shape (N_games,), dtype bool - #[pyo3(get)] - pub is_stalemate: Py, - - /// Insufficient material (white, black), shape (N_games, 2), dtype bool - #[pyo3(get)] - pub is_insufficient: Py, +/// Internal per-chunk data. Not exposed to Python directly. +/// Each chunk corresponds to one thread's output during parallel parsing. +pub struct ChunkData { + // Per-position numpy arrays + pub boards: Py, // (N_positions, 8, 8) u8 + pub castling: Py, // (N_positions, 4) bool + pub en_passant: Py, // (N_positions,) i8 + pub halfmove_clock: Py, // (N_positions,) u8 + pub turn: Py, // (N_positions,) bool + + // Per-move numpy arrays + pub from_squares: Py, // (N_moves,) u8 + pub to_squares: Py, // (N_moves,) u8 + pub promotions: Py, // (N_moves,) i8 + pub clocks: Py, // (N_moves,) f32 + pub evals: Py, // (N_moves,) f32 + + // Per-game arrays (CSR offsets are local to this chunk) + pub move_offsets: Py, // (N_games + 1,) u32 + pub position_offsets: Py, // (N_games + 1,) u32 + pub is_checkmate: Py, // (N_games,) bool + pub is_stalemate: Py, // (N_games,) bool + pub is_insufficient: Py, // (N_games, 2) bool + pub legal_move_count: Py, // (N_games,) u16 + pub valid: Py, // (N_games,) bool + pub headers: Vec>, - /// Legal move count in final position, shape (N_games,), dtype uint16 - #[pyo3(get)] - pub legal_move_count: Py, + // Metadata + pub num_games: usize, + pub num_moves: usize, + pub num_positions: usize, +} - // === Parse status (N_games) === - /// Whether game parsed successfully, shape (N_games,), dtype bool - #[pyo3(get)] - pub valid: Py, +/// Chunked container for parsed chess games, optimized for ML training. +/// +/// Internally stores data in multiple chunks (one per parsing thread) to +/// avoid the cost of merging. Per-game access is O(log(num_chunks)) via +/// binary search on precomputed boundaries. +#[pyclass] +pub struct ParsedGames { + pub chunks: Vec, + + // Global prefix sums for O(1) chunk lookup. + // game_boundaries[i] = total games in chunks 0..i + // So game_boundaries = [0, chunk0.num_games, chunk0+chunk1, ..., total_games] + pub game_boundaries: Vec, + pub move_boundaries: Vec, + pub position_boundaries: Vec, + + pub total_games: usize, + pub total_moves: usize, + pub total_positions: usize, + + // Global offsets for position_to_game / move_to_game (precomputed) + pub global_move_offsets: Py, + pub global_position_offsets: Py, +} - // === Raw headers (N_games) === - /// Raw PGN headers as list of dicts - #[pyo3(get)] - pub headers: Vec>, +impl ParsedGames { + /// Locate which chunk a global game index belongs to. + /// Returns (chunk_index, local_game_index). + fn locate_game(&self, global_idx: usize) -> (usize, usize) { + // Binary search: find the last boundary <= global_idx + // game_boundaries = [0, n0, n0+n1, ...], length = num_chunks + 1 + let chunk_idx = match self.game_boundaries.binary_search(&global_idx) { + Ok(i) => { + // Exact match. If it's the last boundary, back up one. + if i >= self.chunks.len() { + self.chunks.len() - 1 + } else { + i + } + } + Err(i) => i - 1, // insertion point - 1 + }; + let local_idx = global_idx - self.game_boundaries[chunk_idx]; + (chunk_idx, local_idx) + } } #[pymethods] @@ -186,39 +182,36 @@ impl ParsedGames { /// Number of games in the result. #[getter] fn num_games(&self) -> usize { - self.headers.len() + self.total_games } /// Total number of moves across all games. #[getter] - fn num_moves(&self, py: Python<'_>) -> PyResult { - let offsets = self.move_offsets.bind(py); - let offsets: &Bound<'_, PyArray1> = offsets.cast()?; - let readonly = offsets.readonly(); - let slice = readonly.as_slice()?; - Ok(slice.last().copied().unwrap_or(0) as usize) + fn num_moves(&self) -> usize { + self.total_moves } /// Total number of board positions recorded. #[getter] - fn num_positions(&self, py: Python<'_>) -> PyResult { - let offsets = self.position_offsets.bind(py); - let offsets: &Bound<'_, PyArray1> = offsets.cast()?; - let readonly = offsets.readonly(); - let slice = readonly.as_slice()?; - Ok(slice.last().copied().unwrap_or(0) as usize) + fn num_positions(&self) -> usize { + self.total_positions + } + + /// Number of internal chunks. + #[getter] + fn num_chunks(&self) -> usize { + self.chunks.len() } fn __len__(&self) -> usize { - self.headers.len() + self.total_games } fn __getitem__(slf: Py, py: Python<'_>, idx: &Bound<'_, PyAny>) -> PyResult> { - let n_games = slf.borrow(py).headers.len(); + let n_games = slf.borrow(py).total_games; // Handle integer index if let Ok(mut i) = idx.extract::() { - // Handle negative indexing if i < 0 { i += n_games as isize; } @@ -232,14 +225,13 @@ impl ParsedGames { return Ok(Py::new(py, game_view)?.into_any()); } - // Handle slice + // Handle slice — returns list of PyGameView if let Ok(slice) = idx.cast::() { let indices = slice.indices(n_games as isize)?; let start = indices.start as usize; let stop = indices.stop as usize; let step = indices.step as usize; - // For simplicity, we return a list of PyGameView objects let mut views: Vec> = Vec::new(); let mut i = start; while i < stop { @@ -247,7 +239,7 @@ impl ParsedGames { views.push(Py::new(py, game_view)?); i += step; } - return Ok(pyo3::types::PyList::new(py, views)?.into_any().unbind()); + return Ok(PyList::new(py, views)?.into_any().unbind()); } Err(pyo3::exceptions::PyTypeError::new_err(format!( @@ -257,20 +249,40 @@ impl ParsedGames { } fn __iter__(slf: Py, py: Python<'_>) -> PyResult { - let n_games = slf.borrow(py).headers.len(); + let total = slf.borrow(py).total_games; Ok(ParsedGamesIter { data: slf, index: 0, - length: n_games, + length: total, }) } + /// Escape hatch: access raw per-chunk data. + /// + /// Returns a list of chunk view objects, each exposing numpy arrays + /// for that chunk's data. Use this for advanced/custom access patterns. + #[getter] + fn chunks(slf: Py, py: Python<'_>) -> PyResult> { + let n_chunks = slf.borrow(py).chunks.len(); + let mut views: Vec> = Vec::with_capacity(n_chunks); + for i in 0..n_chunks { + views.push(Py::new( + py, + PyChunkView { + parent: slf.clone_ref(py), + chunk_idx: i, + }, + )?); + } + Ok(PyList::new(py, views)?.into_any().unbind()) + } + /// Map position indices to game indices. /// /// Useful after shuffling/sampling positions to look up game metadata. /// /// Args: - /// position_indices: Array of indices into boards array. + /// position_indices: Array of indices into the global position space. /// Accepts any integer dtype; int64 is optimal (avoids conversion). /// /// Returns: @@ -280,13 +292,11 @@ impl ParsedGames { py: Python<'py>, position_indices: &Bound<'py, PyAny>, ) -> PyResult>> { - let offsets = self.position_offsets.bind(py); + let offsets = self.global_position_offsets.bind(py); let offsets: &Bound<'_, PyArray1> = offsets.cast()?; - // Get numpy module for searchsorted let numpy = py.import("numpy")?; - // Convert input to int64 array (no-op if already int64) let int64_dtype = numpy.getattr("int64")?; let position_indices = numpy .call_method1("asarray", (position_indices,))? @@ -296,12 +306,10 @@ impl ParsedGames { Some(&[("copy", false)].into_py_dict(py)?), )?; - // offsets[:-1] - all but last element let len = offsets.len()?; let slice_obj = PySlice::new(py, 0, (len - 1) as isize, 1); let offsets_slice = offsets.call_method1("__getitem__", (slice_obj,))?; - // searchsorted(offsets[:-1], position_indices, side='right') - 1 let result = numpy.call_method1( "searchsorted", ( @@ -311,7 +319,6 @@ impl ParsedGames { ), )?; - // Subtract 1 let one = 1i64.into_pyobject(py)?; let result = result.call_method1("__sub__", (one,))?; @@ -321,7 +328,7 @@ impl ParsedGames { /// Map move indices to game indices. /// /// Args: - /// move_indices: Array of indices into from_squares, to_squares, etc. + /// move_indices: Array of indices into the global move space. /// Accepts any integer dtype; int64 is optimal (avoids conversion). /// /// Returns: @@ -331,12 +338,11 @@ impl ParsedGames { py: Python<'py>, move_indices: &Bound<'py, PyAny>, ) -> PyResult>> { - let offsets = self.move_offsets.bind(py); + let offsets = self.global_move_offsets.bind(py); let offsets: &Bound<'_, PyArray1> = offsets.cast()?; let numpy = py.import("numpy")?; - // Convert input to int64 array (no-op if already int64) let int64_dtype = numpy.getattr("int64")?; let move_indices = numpy .call_method1("asarray", (move_indices,))? @@ -366,6 +372,170 @@ impl ParsedGames { } } +/// Escape hatch: view into a single chunk's raw numpy arrays. +/// +/// Access via `parsed_games.chunks[i]`. Each chunk corresponds to one +/// parsing thread's output. Use this for advanced access patterns like +/// manual concatenation or custom batching. +#[pyclass] +pub struct PyChunkView { + parent: Py, + chunk_idx: usize, +} + +#[pymethods] +impl PyChunkView { + #[getter] + fn num_games(&self, py: Python<'_>) -> usize { + self.parent.borrow(py).chunks[self.chunk_idx].num_games + } + + #[getter] + fn num_moves(&self, py: Python<'_>) -> usize { + self.parent.borrow(py).chunks[self.chunk_idx].num_moves + } + + #[getter] + fn num_positions(&self, py: Python<'_>) -> usize { + self.parent.borrow(py).chunks[self.chunk_idx].num_positions + } + + #[getter] + fn boards(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .boards + .clone_ref(py) + } + + #[getter] + fn castling(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .castling + .clone_ref(py) + } + + #[getter] + fn en_passant(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .en_passant + .clone_ref(py) + } + + #[getter] + fn halfmove_clock(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .halfmove_clock + .clone_ref(py) + } + + #[getter] + fn turn(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .turn + .clone_ref(py) + } + + #[getter] + fn from_squares(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .from_squares + .clone_ref(py) + } + + #[getter] + fn to_squares(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .to_squares + .clone_ref(py) + } + + #[getter] + fn promotions(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .promotions + .clone_ref(py) + } + + #[getter] + fn clocks(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .clocks + .clone_ref(py) + } + + #[getter] + fn evals(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .evals + .clone_ref(py) + } + + #[getter] + fn move_offsets(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .move_offsets + .clone_ref(py) + } + + #[getter] + fn position_offsets(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .position_offsets + .clone_ref(py) + } + + #[getter] + fn is_checkmate(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .is_checkmate + .clone_ref(py) + } + + #[getter] + fn is_stalemate(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .is_stalemate + .clone_ref(py) + } + + #[getter] + fn is_insufficient(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .is_insufficient + .clone_ref(py) + } + + #[getter] + fn legal_move_count(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .legal_move_count + .clone_ref(py) + } + + #[getter] + fn valid(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .valid + .clone_ref(py) + } + + #[getter] + fn headers(&self, py: Python<'_>) -> Vec> { + self.parent.borrow(py).chunks[self.chunk_idx] + .headers + .clone() + } + + fn __repr__(&self, py: Python<'_>) -> String { + let borrowed = self.parent.borrow(py); + let chunk = &borrowed.chunks[self.chunk_idx]; + format!( + "", + self.chunk_idx, chunk.num_games, chunk.num_moves, chunk.num_positions + ) + } +} + /// Iterator over games in a ParsedGames result. #[pyclass] pub struct ParsedGamesIter { @@ -399,20 +569,30 @@ impl ParsedGamesIter { #[pyclass] pub struct PyGameView { data: Py, - idx: usize, + /// Index of the chunk this game lives in. + chunk_idx: usize, + /// Local game index within the chunk. + local_idx: usize, + /// Move range within the chunk's move arrays. move_start: usize, move_end: usize, + /// Position range within the chunk's position arrays. pos_start: usize, pos_end: usize, } impl PyGameView { - pub fn new(py: Python<'_>, data: Py, idx: usize) -> PyResult { + /// Create a new game view for global game index `global_idx`. + pub fn new(py: Python<'_>, data: Py, global_idx: usize) -> PyResult { let borrowed = data.borrow(py); - let move_offsets = borrowed.move_offsets.bind(py); + let (chunk_idx, local_idx) = borrowed.locate_game(global_idx); + let chunk = &borrowed.chunks[chunk_idx]; + + // Read this chunk's local CSR offsets + let move_offsets = chunk.move_offsets.bind(py); let move_offsets: &Bound<'_, PyArray1> = move_offsets.cast()?; - let pos_offsets = borrowed.position_offsets.bind(py); + let pos_offsets = chunk.position_offsets.bind(py); let pos_offsets: &Bound<'_, PyArray1> = pos_offsets.cast()?; let move_offsets_ro = move_offsets.readonly(); @@ -421,23 +601,24 @@ impl PyGameView { let pos_offsets_slice = pos_offsets_ro.as_slice()?; let move_start = move_offsets_slice - .get(idx) + .get(local_idx) .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; let move_end = move_offsets_slice - .get(idx + 1) + .get(local_idx + 1) .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; let pos_start = pos_offsets_slice - .get(idx) + .get(local_idx) .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; let pos_end = pos_offsets_slice - .get(idx + 1) + .get(local_idx + 1) .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index"))?; drop(borrowed); Ok(Self { data, - idx, + chunk_idx, + local_idx, move_start: *move_start as usize, move_end: *move_end as usize, pos_start: *pos_start as usize, @@ -465,7 +646,7 @@ impl PyGameView { #[getter] fn boards<'py>(&self, py: Python<'py>) -> PyResult> { let borrowed = self.data.borrow(py); - let boards = borrowed.boards.bind(py); + let boards = borrowed.chunks[self.chunk_idx].boards.bind(py); let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); let slice = boards.call_method1("__getitem__", (slice_obj,))?; Ok(slice.unbind()) @@ -475,7 +656,7 @@ impl PyGameView { #[getter] fn initial_board<'py>(&self, py: Python<'py>) -> PyResult> { let borrowed = self.data.borrow(py); - let boards = borrowed.boards.bind(py); + let boards = borrowed.chunks[self.chunk_idx].boards.bind(py); let slice = boards.call_method1("__getitem__", (self.pos_start,))?; Ok(slice.unbind()) } @@ -484,7 +665,7 @@ impl PyGameView { #[getter] fn final_board<'py>(&self, py: Python<'py>) -> PyResult> { let borrowed = self.data.borrow(py); - let boards = borrowed.boards.bind(py); + let boards = borrowed.chunks[self.chunk_idx].boards.bind(py); let slice = boards.call_method1("__getitem__", (self.pos_end - 1,))?; Ok(slice.unbind()) } @@ -493,7 +674,7 @@ impl PyGameView { #[getter] fn castling<'py>(&self, py: Python<'py>) -> PyResult> { let borrowed = self.data.borrow(py); - let arr = borrowed.castling.bind(py); + let arr = borrowed.chunks[self.chunk_idx].castling.bind(py); let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); let slice = arr.call_method1("__getitem__", (slice_obj,))?; Ok(slice.unbind()) @@ -503,7 +684,7 @@ impl PyGameView { #[getter] fn en_passant<'py>(&self, py: Python<'py>) -> PyResult> { let borrowed = self.data.borrow(py); - let arr = borrowed.en_passant.bind(py); + let arr = borrowed.chunks[self.chunk_idx].en_passant.bind(py); let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); let slice = arr.call_method1("__getitem__", (slice_obj,))?; Ok(slice.unbind()) @@ -513,7 +694,7 @@ impl PyGameView { #[getter] fn halfmove_clock<'py>(&self, py: Python<'py>) -> PyResult> { let borrowed = self.data.borrow(py); - let arr = borrowed.halfmove_clock.bind(py); + let arr = borrowed.chunks[self.chunk_idx].halfmove_clock.bind(py); let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); let slice = arr.call_method1("__getitem__", (slice_obj,))?; Ok(slice.unbind()) @@ -523,7 +704,7 @@ impl PyGameView { #[getter] fn turn<'py>(&self, py: Python<'py>) -> PyResult> { let borrowed = self.data.borrow(py); - let arr = borrowed.turn.bind(py); + let arr = borrowed.chunks[self.chunk_idx].turn.bind(py); let slice_obj = PySlice::new(py, self.pos_start as isize, self.pos_end as isize, 1); let slice = arr.call_method1("__getitem__", (slice_obj,))?; Ok(slice.unbind()) @@ -535,7 +716,7 @@ impl PyGameView { #[getter] fn from_squares<'py>(&self, py: Python<'py>) -> PyResult> { let borrowed = self.data.borrow(py); - let arr = borrowed.from_squares.bind(py); + let arr = borrowed.chunks[self.chunk_idx].from_squares.bind(py); let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); let slice = arr.call_method1("__getitem__", (slice_obj,))?; Ok(slice.unbind()) @@ -545,7 +726,7 @@ impl PyGameView { #[getter] fn to_squares<'py>(&self, py: Python<'py>) -> PyResult> { let borrowed = self.data.borrow(py); - let arr = borrowed.to_squares.bind(py); + let arr = borrowed.chunks[self.chunk_idx].to_squares.bind(py); let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); let slice = arr.call_method1("__getitem__", (slice_obj,))?; Ok(slice.unbind()) @@ -555,7 +736,7 @@ impl PyGameView { #[getter] fn promotions<'py>(&self, py: Python<'py>) -> PyResult> { let borrowed = self.data.borrow(py); - let arr = borrowed.promotions.bind(py); + let arr = borrowed.chunks[self.chunk_idx].promotions.bind(py); let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); let slice = arr.call_method1("__getitem__", (slice_obj,))?; Ok(slice.unbind()) @@ -565,7 +746,7 @@ impl PyGameView { #[getter] fn clocks<'py>(&self, py: Python<'py>) -> PyResult> { let borrowed = self.data.borrow(py); - let arr = borrowed.clocks.bind(py); + let arr = borrowed.chunks[self.chunk_idx].clocks.bind(py); let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); let slice = arr.call_method1("__getitem__", (slice_obj,))?; Ok(slice.unbind()) @@ -575,7 +756,7 @@ impl PyGameView { #[getter] fn evals<'py>(&self, py: Python<'py>) -> PyResult> { let borrowed = self.data.borrow(py); - let arr = borrowed.evals.bind(py); + let arr = borrowed.chunks[self.chunk_idx].evals.bind(py); let slice_obj = PySlice::new(py, self.move_start as isize, self.move_end as isize, 1); let slice = arr.call_method1("__getitem__", (slice_obj,))?; Ok(slice.unbind()) @@ -587,19 +768,19 @@ impl PyGameView { #[getter] fn headers(&self, py: Python<'_>) -> PyResult> { let borrowed = self.data.borrow(py); - Ok(borrowed.headers[self.idx].clone()) + Ok(borrowed.chunks[self.chunk_idx].headers[self.local_idx].clone()) } /// Final position is checkmate. #[getter] fn is_checkmate(&self, py: Python<'_>) -> PyResult { let borrowed = self.data.borrow(py); - let arr = borrowed.is_checkmate.bind(py); + let arr = borrowed.chunks[self.chunk_idx].is_checkmate.bind(py); let arr: &Bound<'_, PyArray1> = arr.cast()?; let readonly = arr.readonly(); let slice = readonly.as_slice()?; slice - .get(self.idx) + .get(self.local_idx) .copied() .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) } @@ -608,12 +789,12 @@ impl PyGameView { #[getter] fn is_stalemate(&self, py: Python<'_>) -> PyResult { let borrowed = self.data.borrow(py); - let arr = borrowed.is_stalemate.bind(py); + let arr = borrowed.chunks[self.chunk_idx].is_stalemate.bind(py); let arr: &Bound<'_, PyArray1> = arr.cast()?; let readonly = arr.readonly(); let slice = readonly.as_slice()?; slice - .get(self.idx) + .get(self.local_idx) .copied() .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) } @@ -622,12 +803,11 @@ impl PyGameView { #[getter] fn is_insufficient(&self, py: Python<'_>) -> PyResult<(bool, bool)> { let borrowed = self.data.borrow(py); - let arr = borrowed.is_insufficient.bind(py); + let arr = borrowed.chunks[self.chunk_idx].is_insufficient.bind(py); let arr: &Bound<'_, PyArray2> = arr.cast()?; let readonly = arr.readonly(); let slice = readonly.as_slice()?; - // Array is shape (n_games, 2), so index is idx * 2 for white, idx * 2 + 1 for black - let base = self.idx * 2; + let base = self.local_idx * 2; let white = slice .get(base) .copied() @@ -643,12 +823,12 @@ impl PyGameView { #[getter] fn legal_move_count(&self, py: Python<'_>) -> PyResult { let borrowed = self.data.borrow(py); - let arr = borrowed.legal_move_count.bind(py); + let arr = borrowed.chunks[self.chunk_idx].legal_move_count.bind(py); let arr: &Bound<'_, PyArray1> = arr.cast()?; let readonly = arr.readonly(); let slice = readonly.as_slice()?; slice - .get(self.idx) + .get(self.local_idx) .copied() .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) } @@ -657,12 +837,12 @@ impl PyGameView { #[getter] fn is_valid(&self, py: Python<'_>) -> PyResult { let borrowed = self.data.borrow(py); - let arr = borrowed.valid.bind(py); + let arr = borrowed.chunks[self.chunk_idx].valid.bind(py); let arr: &Bound<'_, PyArray1> = arr.cast()?; let readonly = arr.readonly(); let slice = readonly.as_slice()?; slice - .get(self.idx) + .get(self.local_idx) .copied() .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) } @@ -680,11 +860,12 @@ impl PyGameView { } let borrowed = self.data.borrow(py); - let from_arr = borrowed.from_squares.bind(py); + let chunk = &borrowed.chunks[self.chunk_idx]; + let from_arr = chunk.from_squares.bind(py); let from_arr: &Bound<'_, PyArray1> = from_arr.cast()?; - let to_arr = borrowed.to_squares.bind(py); + let to_arr = chunk.to_squares.bind(py); let to_arr: &Bound<'_, PyArray1> = to_arr.cast()?; - let promo_arr = borrowed.promotions.bind(py); + let promo_arr = chunk.promotions.bind(py); let promo_arr: &Bound<'_, PyArray1> = promo_arr.cast()?; let from_ro = from_arr.readonly(); diff --git a/src/test.py b/src/test.py index 9aa41ad..cec9733 100644 --- a/src/test.py +++ b/src/test.py @@ -692,26 +692,34 @@ def test_basic_structure(self): "1. d4 d5 2. c4 e6 0-1", ] chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + # Use 1 thread to get a single chunk for predictable array shapes + result = rust_pgn_reader_python_binding.parse_games_flat(chunked, num_threads=1) # Check game count self.assertEqual(len(result), 2) + self.assertEqual(result.num_games, 2) + self.assertEqual(result.num_moves, 9) + self.assertEqual(result.num_positions, 11) # 9 moves + 2 initial positions - # Check move offsets - self.assertEqual(len(result.move_offsets), 3) - self.assertEqual(result.move_offsets[0], 0) - self.assertEqual(result.move_offsets[1], 5) # Game 1: 5 half-moves - self.assertEqual(result.move_offsets[2], 9) # Game 2: 4 half-moves + # Check per-game structure via game views + game0 = result[0] + self.assertEqual(len(game0), 5) # Game 1: 5 half-moves + self.assertEqual(game0.num_positions, 6) - # Check shapes - total_moves = 9 - total_positions = 9 + 2 # moves + initial positions + game1 = result[1] + self.assertEqual(len(game1), 4) # Game 2: 4 half-moves + self.assertEqual(game1.num_positions, 5) - self.assertEqual(result.boards.shape, (total_positions, 8, 8)) - self.assertEqual(result.castling.shape, (total_positions, 4)) - self.assertEqual(result.en_passant.shape, (total_positions,)) - self.assertEqual(result.from_squares.shape, (total_moves,)) - self.assertEqual(result.valid.shape, (2,)) + # With 1 thread, all games are in a single chunk + self.assertEqual(result.num_chunks, 1) + chunk = result.chunks[0] + total_moves = 9 + total_positions = 9 + 2 + self.assertEqual(chunk.boards.shape, (total_positions, 8, 8)) + self.assertEqual(chunk.castling.shape, (total_positions, 4)) + self.assertEqual(chunk.en_passant.shape, (total_positions,)) + self.assertEqual(chunk.from_squares.shape, (total_moves,)) + self.assertEqual(chunk.valid.shape, (2,)) def test_initial_board_encoding(self): """Test initial board state encoding.""" @@ -719,7 +727,7 @@ def test_initial_board_encoding(self): chunked = pa.chunked_array([pa.array(pgns)]) result = rust_pgn_reader_python_binding.parse_games_flat(chunked) - initial = result.boards[0] # First position + initial = result[0].initial_board # First position # Encoding: 0=empty, 1=P, 2=N, 3=B, 4=R, 5=Q, 6=K, +6 for black # Square indexing: a1=0, b1=1, ..., h1=7, a2=8, ... @@ -746,7 +754,7 @@ def test_board_after_move(self): result = rust_pgn_reader_python_binding.parse_games_flat(chunked) # Position 0: initial, Position 1: after e4 - after_e4 = result.boards[1] + after_e4 = result[0].boards[1] # e2 (index 12) should be empty self.assertEqual(after_e4.flat[12], 0) @@ -759,10 +767,11 @@ def test_en_passant_tracking(self): chunked = pa.chunked_array([pa.array(pgns)]) result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + game = result[0] # Initial: no en passant - self.assertEqual(result.en_passant[0], -1) + self.assertEqual(game.en_passant[0], -1) # After e4: en passant on e-file (file index 4) - self.assertEqual(result.en_passant[1], 4) + self.assertEqual(game.en_passant[1], 4) def test_castling_rights(self): """Test castling rights tracking.""" @@ -771,15 +780,16 @@ def test_castling_rights(self): chunked = pa.chunked_array([pa.array(pgns)]) result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + game = result[0] # Initial: all castling [K, Q, k, q] = [True, True, True, True] - self.assertTrue(all(result.castling[0])) + self.assertTrue(all(game.castling[0])) # After Rg1 (position 5): white kingside lost # Castling order: [K, Q, k, q] - self.assertFalse(result.castling[5, 0]) # White K - self.assertTrue(result.castling[5, 1]) # White Q - self.assertTrue(result.castling[5, 2]) # Black k - self.assertTrue(result.castling[5, 3]) # Black q + self.assertFalse(game.castling[5, 0]) # White K + self.assertTrue(game.castling[5, 1]) # White Q + self.assertTrue(game.castling[5, 2]) # Black k + self.assertTrue(game.castling[5, 3]) # Black q def test_turn_tracking(self): """Test side-to-move tracking.""" @@ -787,12 +797,13 @@ def test_turn_tracking(self): chunked = pa.chunked_array([pa.array(pgns)]) result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + game = result[0] # Initial: white to move - self.assertTrue(result.turn[0]) + self.assertTrue(game.turn[0]) # After e4: black to move - self.assertFalse(result.turn[1]) + self.assertFalse(game.turn[1]) # After e5: white to move - self.assertTrue(result.turn[2]) + self.assertTrue(game.turn[2]) def test_game_view_access(self): """Test GameView provides correct slices.""" @@ -897,10 +908,11 @@ def test_clocks_and_evals(self): chunked = pa.chunked_array([pa.array([pgn])]) result = rust_pgn_reader_python_binding.parse_games_flat(chunked) - self.assertAlmostEqual(result.evals[0], 0.17, places=2) - self.assertAlmostEqual(result.evals[1], 0.19, places=2) - self.assertAlmostEqual(result.clocks[0], 30.0, places=1) - self.assertAlmostEqual(result.clocks[1], 29.0, places=1) + game = result[0] + self.assertAlmostEqual(game.evals[0], 0.17, places=2) + self.assertAlmostEqual(game.evals[1], 0.19, places=2) + self.assertAlmostEqual(game.clocks[0], 30.0, places=1) + self.assertAlmostEqual(game.clocks[1], 29.0, places=1) def test_missing_clocks_evals_are_nan(self): """Test missing clocks/evals are NaN.""" @@ -908,8 +920,9 @@ def test_missing_clocks_evals_are_nan(self): chunked = pa.chunked_array([pa.array(pgns)]) result = rust_pgn_reader_python_binding.parse_games_flat(chunked) - self.assertTrue(np.isnan(result.clocks[0])) - self.assertTrue(np.isnan(result.evals[0])) + game = result[0] + self.assertTrue(np.isnan(game.clocks[0])) + self.assertTrue(np.isnan(game.evals[0])) def test_headers_preserved(self): """Test headers are preserved as dicts.""" @@ -937,9 +950,9 @@ def test_invalid_game_flagged(self): result = rust_pgn_reader_python_binding.parse_games_flat(chunked) self.assertEqual(len(result), 3) - self.assertTrue(result.valid[0]) - self.assertFalse(result.valid[1]) - self.assertTrue(result.valid[2]) + self.assertTrue(result[0].is_valid) + self.assertFalse(result[1].is_valid) + self.assertTrue(result[2].is_valid) def test_checkmate_detection(self): """Test checkmate is detected.""" @@ -962,10 +975,9 @@ def test_promotion(self): chunked = pa.chunked_array([pa.array([pgn])]) result = rust_pgn_reader_python_binding.parse_games_flat(chunked) - # Promotion to queen = 5 - self.assertEqual(result.promotions[0], 5) - game = result[0] + # Promotion to queen = 5 + self.assertEqual(game.promotions[0], 5) self.assertEqual(game.move_uci(0), "a7a8q") def test_num_properties(self): From e082bca98a9331373835f1ec09a3740637175462 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Fri, 6 Feb 2026 12:27:11 -0500 Subject: [PATCH 25/31] Tweak bench to be more apples-apples --- src/bench_parse_games_flat.py | 61 +++++++++++++++++++++++++++++------ 1 file changed, 52 insertions(+), 9 deletions(-) diff --git a/src/bench_parse_games_flat.py b/src/bench_parse_games_flat.py index cf72ac7..65bf7ab 100644 --- a/src/bench_parse_games_flat.py +++ b/src/bench_parse_games_flat.py @@ -168,15 +168,6 @@ def benchmark_data_access_flat(result) -> dict: _ = chunk.from_squares.sum() _ = chunk.to_squares.sum() - # Per-game access pattern (iterate and access boards) - for i in range(min(1000, result.num_games)): - game = result[i] - _ = game.boards - - # Position-to-game mapping (still works globally) - indices = np.random.randint(0, result.num_positions, size=1000, dtype=np.int64) - _ = result.position_to_game(indices) - elapsed = time.perf_counter() - start return {"access_time": elapsed} @@ -195,6 +186,37 @@ def benchmark_data_access_extractors(extractors: list) -> dict: return {"access_time": elapsed} +def benchmark_per_game_access_flat(result) -> dict: + """Benchmark per-game access pattern for parse_games_flat result.""" + start = time.perf_counter() + + # Per-game access pattern (iterate and access boards) + for i in range(min(1000, result.num_games)): + game = result[i] + _ = game.boards + + elapsed = time.perf_counter() - start + return { + "access_time": elapsed, + "games_accessed": min(1000, result.num_games), + } + + +def benchmark_position_to_game_mapping(result) -> dict: + """Benchmark position-to-game mapping for parse_games_flat result.""" + start = time.perf_counter() + + # Position-to-game mapping (still works globally) + indices = np.random.randint(0, result.num_positions, size=1000, dtype=np.int64) + _ = result.position_to_game(indices) + + elapsed = time.perf_counter() - start + return { + "access_time": elapsed, + "positions_accessed": 1000, + } + + def format_number(n: float) -> str: """Format large numbers with K/M suffix.""" if n >= 1_000_000: @@ -321,6 +343,27 @@ def main(): access_speedup = extractor_access["access_time"] / flat_access["access_time"] print(f"\nFlat arrays are {access_speedup:.2f}x faster for data access") + # Per-game access benchmarks + print("\n" + "=" * 60) + print("PER-GAME ACCESS BENCHMARKS") + print("=" * 60) + + per_game_access = benchmark_per_game_access_flat(flat_results["result"]) + print(f"\nPer-game access time: {per_game_access['access_time']:.3f}s") + print(f"Games accessed: {per_game_access['games_accessed']:,}") + print( + f"Time per game: {per_game_access['access_time'] / per_game_access['games_accessed'] * 1000:.3f}ms" + ) + + position_mapping = benchmark_position_to_game_mapping(flat_results["result"]) + print(f"\nPosition-to-game mapping: {position_mapping['access_time']:.3f}s") + print( + f"Positions mapped: {position_mapping['positions_accessed']:,}" + ) + print( + f"Time per lookup: {position_mapping['access_time'] / position_mapping['positions_accessed'] * 1000000:.3f}μs" + ) + # Memory usage (approximate) print("\n" + "=" * 60) print("MEMORY USAGE (approximate)") From 4f68999dc11479e1aaea0807611a3d8601af2391 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Sun, 8 Feb 2026 10:36:37 -0500 Subject: [PATCH 26/31] swapping the main API to use the "flat" Structure of Arrays implementation. --- README.md | 34 +- benches/parquet_bench.rs | 112 +- rust_pgn_reader_python_binding.pyi | 130 +- src/bench_parquet.py | 8 +- src/bench_parquet_arrow.py | 8 +- src/bench_parquet_parallel.py | 4 +- ...rse_games_flat.py => bench_parse_games.py} | 145 +- src/{bench.py => bench_pgn_file.py} | 13 +- src/example_chess960.py | 18 +- src/example_complex.py | 41 + ...e_games_flat.py => example_parse_games.py} | 8 +- src/example_rowgroup.py | 7 +- src/example_str.py | 25 +- src/flat_visitor.rs | 311 ++++- src/lib.rs | 384 ++---- src/python_bindings.rs | 297 ++--- src/test.py | 1169 +++++++---------- src/visitor.rs | 447 ------- 18 files changed, 1215 insertions(+), 1946 deletions(-) rename src/{bench_parse_games_flat.py => bench_parse_games.py} (64%) rename src/{bench.py => bench_pgn_file.py} (61%) create mode 100644 src/example_complex.py rename src/{example_parse_games_flat.py => example_parse_games.py} (95%) delete mode 100644 src/visitor.rs diff --git a/README.md b/README.md index b66c9a6..aa9fe12 100644 --- a/README.md +++ b/README.md @@ -5,26 +5,34 @@ This project adds Python bindings to [rust-pgn-reader](https://github.com/niklas ## Installing `pip install rust_pgn_reader_python_binding` +## API + +Three entry points are available: + +- `parse_game(pgn)` - Parse a single PGN string +- `parse_games(chunked_array)` - Parse games from a PyArrow ChunkedArray (multithreaded) +- `parse_games_from_strings(pgns)` - Parse a list of PGN strings (multithreaded) + +All return a `ParsedGames` container with flat NumPy arrays, supporting: +- Indexing (`result[i]`), slicing (`result[1:3]`), and iteration (`for game in result`) +- Per-game views (`PyGameView`) with zero-copy array slices +- Position-to-game and move-to-game mapping for ML workflows +- Optional comment storage (`store_comments=True`) +- Optional legal move storage (`store_legal_moves=True`) + ## Benchmarks -Below are some benchmarks on Lichess's 2013-07 chess games (293,459 games) on an 7800X3D. +Below are some benchmarks on Lichess's 2013-07 chess games (293,459 games) on a 7800X3D. | Parser | File format | Time | |----------------------------------------------------------------------------|-------------|--------| | [rust-pgn-reader](https://github.com/niklasf/rust-pgn-reader/tree/master) | PGN | 1s | -| rust_pgn_reader_python_binding | PGN | 4.7s | -| rust_pgn_reader_python_binding, parse_game (single_threaded) | parquet | 3.3s | -| rust_pgn_reader_python_binding, parse_games (multithreaded) | parquet | 0.5s | -| rust_pgn_reader_python_binding, parse_game_moves_arrow_chunked_array (multithreaded) | parquet | 0.35s | +| rust_pgn_reader_python_binding, parse_games (multithreaded) | parquet | 0.35s | | [chess-library](https://github.com/Disservin/chess-library) | PGN | 2s | | [python-chess](https://github.com/niklasf/python-chess) | PGN | 3+ min | To replicate, download `2013-07-train-00000-of-00001.parquet` and then run: -`python bench_parquet.py` (single-threaded parse_game) - -`python bench_parquet_parallel.py` (multithreaded parse_games) - -`python bench_parquet_arrow.py` (multithreaded parse_game_moves_arrow_chunked_array) +`python src/bench_parse_games.py 2013-07-train-00000-of-00001.parquet` ## Building `maturin develop` @@ -34,12 +42,12 @@ To replicate, download `2013-07-train-00000-of-00001.parquet` and then run: For a more thorough tutorial, follow https://lukesalamone.github.io/posts/how-to-create-rust-python-bindings/ ## Profiling -`py-spy record -s -F -f speedscope --output profile.speedscope -- python ./src/bench_parquet.py` +`py-spy record -s -F -f speedscope --output profile.speedscope -- python ./src/bench_parse_games.py` Linux/WSL-only: -`py-spy record -s -F -n -f speedscope --output profile.speedscope -- python ./src/bench_parquet.py` +`py-spy record -s -F -n -f speedscope --output profile.speedscope -- python ./src/bench_parse_games.py` ## Testing `cargo test` -`python -m unittest src/test.py` \ No newline at end of file +`python -m unittest src/test.py` diff --git a/benches/parquet_bench.rs b/benches/parquet_bench.rs index 3e5347e..4af6b63 100644 --- a/benches/parquet_bench.rs +++ b/benches/parquet_bench.rs @@ -1,4 +1,4 @@ -//! Benchmark for PGN parsing APIs, designed to mirror the Python workflow. +//! Benchmark for PGN parsing API, designed to mirror the Python workflow. //! //! `cargo bench --bench parquet_bench` //! `samply record --rate 10000 cargo bench --bench parquet_bench` @@ -11,13 +11,13 @@ use std::fs::File; use std::path::Path; use std::time::Instant; -use rust_pgn_reader_python_binding::{parse_game_to_flat, parse_single_game_native, FlatBuffers}; +use rust_pgn_reader_python_binding::{parse_game_to_buffers, Buffers, ParseConfig}; const FILE_PATH: &str = "2013-07-train-00000-of-00001.parquet"; -/// Chunk multiplier for explicit chunking in Flat API. -/// 1 = exactly num_threads chunks (minimal merge overhead) -/// Higher values provide better load balancing at cost of more buffers to merge. +/// Chunk multiplier for explicit chunking. +/// 1 = exactly num_threads chunks (minimal overhead) +/// Higher values provide better load balancing at cost of more buffers. const CHUNK_MULTIPLIER: usize = 1; /// Read parquet file and return the raw Arrow StringArrays. @@ -40,7 +40,6 @@ fn read_parquet_to_string_arrays(file_path: &str) -> Vec { .column_by_name("movetext") .and_then(|col| col.as_any().downcast_ref::()) { - // Clone the StringArray to own it (Arrow uses Arc internally, so this is cheap) arrays.push(array.clone()); } else { panic!("movetext column not found or not a StringArray"); @@ -50,8 +49,6 @@ fn read_parquet_to_string_arrays(file_path: &str) -> Vec { } /// Extract &str slices from Arrow StringArrays (zero-copy). -/// This mirrors the extraction logic in `_parse_game_moves_from_arrow_chunks_native` -/// and `parse_games_flat` in lib.rs. fn extract_str_slices<'a>(arrays: &'a [StringArray]) -> Vec<&'a str> { let total_len: usize = arrays.iter().map(|a| a.len()).sum(); let mut slices = Vec::with_capacity(total_len); @@ -66,72 +63,19 @@ fn extract_str_slices<'a>(arrays: &'a [StringArray]) -> Vec<&'a str> { slices } -/// Benchmark the Arrow API workflow. +/// Benchmark the parsing API workflow. /// -/// This mirrors `parse_game_moves_arrow_chunked_array()` from Python: -/// 1. Read parquet to Arrow arrays -/// 2. Extract &str slices from StringArray (like the Python-bound function does) -/// 3. Parse each game in parallel → Vec -/// -/// Args: -/// - store_board_states: Whether to populate board state vectors (for benchmarking overhead) -pub fn bench_arrow_api(store_board_states: bool) { - // Step 1: Read parquet to Arrow StringArrays - let arrays = read_parquet_to_string_arrays(FILE_PATH); - - // Step 2: Extract &str slices (zero-copy, mirrors Arrow chunk iteration) - let pgn_slices = extract_str_slices(&arrays); - println!("Read {} games from parquet.", pgn_slices.len()); - println!("store_board_states: {}", store_board_states); - - // Step 3: Build thread pool (same pattern as lib.rs) - let num_threads = num_cpus::get(); - let thread_pool = ThreadPoolBuilder::new() - .num_threads(num_threads) - .build() - .expect("Failed to build Rayon thread pool"); - - // Step 4: Parse in parallel → Vec - // This mirrors _parse_game_moves_from_arrow_chunks_native - let start = Instant::now(); - - let extractors: Vec<_> = thread_pool - .install(|| { - pgn_slices - .par_iter() - .map(|&pgn| parse_single_game_native(pgn, false, store_board_states)) - .collect::, _>>() - }) - .expect("Parsing failed"); - - let parsing_duration = start.elapsed(); - - // Report results - let total_moves: usize = extractors.iter().map(|e| e.moves.len()).sum(); - let num_games = extractors.len(); - println!("Parsing time: {:?}", parsing_duration); - println!("Parsed {} games, {} total moves.", num_games, total_moves); - - // Explicitly drop to measure cleanup time - // This is the cost Python will pay when the list goes out of scope - let drop_start = Instant::now(); - drop(extractors); - let drop_duration = drop_start.elapsed(); - - let total_duration = start.elapsed(); - println!("Cleanup time (drop): {:?}", drop_duration); - println!("Total time (parsing + cleanup): {:?}", total_duration); -} - -/// Benchmark the Flat API workflow. -/// -/// This mirrors `parse_games_flat()` from Python: /// 1. Read parquet to Arrow arrays /// 2. Extract &str slices from StringArray -/// 3. Parse in parallel with explicit chunking (par_chunks) → fixed number of FlatBuffers +/// 3. Parse in parallel with explicit chunking (par_chunks) -> fixed number of Buffers /// -/// No merge step — the chunked architecture keeps per-thread buffers as-is. -pub fn bench_flat_api() { +/// No merge step - the chunked architecture keeps per-thread buffers as-is. +pub fn bench_parse_api() { + let config = ParseConfig { + store_comments: false, + store_legal_moves: false, + }; + // Step 1: Read parquet to Arrow StringArrays let arrays = read_parquet_to_string_arrays(FILE_PATH); @@ -142,11 +86,11 @@ pub fn bench_flat_api() { // Step 3: Build thread pool and compute capacity estimates let num_threads = num_cpus::get(); let n_games = pgn_slices.len(); - let moves_per_game = 70; // Estimate ~70 moves per game + let moves_per_game = 70; // Calculate chunk size for explicit chunking let num_chunks = num_threads * CHUNK_MULTIPLIER; - let chunk_size = (n_games + num_chunks - 1) / num_chunks; // ceiling division + let chunk_size = (n_games + num_chunks - 1) / num_chunks; let chunk_size = chunk_size.max(1); let games_per_chunk = chunk_size; @@ -160,18 +104,16 @@ pub fn bench_flat_api() { .build() .expect("Failed to build Rayon thread pool"); - // Step 4: Parse in parallel using par_chunks for explicit, fixed-size chunking - // This creates exactly ceil(n_games / chunk_size) FlatBuffers instances, - // avoiding the allocation storm from Rayon's dynamic work-stealing. + // Step 4: Parse in parallel using par_chunks let start = Instant::now(); - let chunk_results: Vec = thread_pool.install(|| { + let chunk_results: Vec = thread_pool.install(|| { pgn_slices .par_chunks(chunk_size) .map(|chunk| { - let mut buffers = FlatBuffers::with_capacity(games_per_chunk, moves_per_game); + let mut buffers = Buffers::with_capacity(games_per_chunk, moves_per_game, &config); for &pgn in chunk { - let _ = parse_game_to_flat(pgn, &mut buffers); + let _ = parse_game_to_buffers(pgn, &mut buffers, &config); } buffers }) @@ -181,7 +123,7 @@ pub fn bench_flat_api() { let duration_parallel = start.elapsed(); println!("Parallel parsing time: {:?}", duration_parallel); println!( - "Created {} FlatBuffers chunks (no merge needed)", + "Created {} Buffers chunks (no merge needed)", chunk_results.len() ); @@ -196,7 +138,7 @@ pub fn bench_flat_api() { total_games, total_positions ); - // Measure cleanup time for fair comparison with Arrow API + // Measure cleanup time let drop_start = Instant::now(); drop(chunk_results); let drop_duration = drop_start.elapsed(); @@ -207,12 +149,6 @@ pub fn bench_flat_api() { } fn main() { - println!("=== Arrow API (MoveExtractor, store_board_states=false) ===\n"); - bench_arrow_api(false); - - println!("\n=== Arrow API (MoveExtractor, store_board_states=true) ===\n"); - bench_arrow_api(true); - - println!("\n=== Flat API (FlatBuffers) ===\n"); - bench_flat_api(); + println!("=== Parse API (Buffers) ===\n"); + bench_parse_api(); } diff --git a/rust_pgn_reader_python_binding.pyi b/rust_pgn_reader_python_binding.pyi index 0441a0d..b45543b 100644 --- a/rust_pgn_reader_python_binding.pyi +++ b/rust_pgn_reader_python_binding.pyi @@ -4,48 +4,6 @@ import numpy as np import numpy.typing as npt from numpy.typing import NDArray -class PyUciMove: - from_square: int - to_square: int - promotion: Optional[int] - - def __init__( - self, from_square: int, to_square: int, promotion: Optional[int] - ) -> None: ... - @property - def get_from_square_name(self) -> str: ... - @property - def get_to_square_name(self) -> str: ... - @property - def get_promotion_name(self) -> Optional[str]: ... - def __str__(self) -> str: ... - def __repr__(self) -> str: ... - -class PositionStatus: - is_checkmate: bool - is_stalemate: bool - legal_move_count: int - is_game_over: bool - insufficient_material: Tuple[bool, bool] - turn: bool - -class MoveExtractor: - moves: List[PyUciMove] - valid_moves: bool - comments: List[Optional[str]] - evals: List[Optional[float]] - clock_times: List[Optional[Tuple[int, int, float]]] - outcome: Optional[str] - headers: List[Tuple[str, str]] - castling_rights: List[Optional[Tuple[bool, bool, bool, bool]]] - position_status: Optional[PositionStatus] - - def __init__(self, store_legal_moves: bool = False) -> None: ... - def turn(self) -> bool: ... - def update_position_status(self) -> None: ... - @property - def legal_moves(self) -> List[List[PyUciMove]]: ... - class PyGameView: """Zero-copy view into a single game's data within a ParsedGames result. @@ -135,6 +93,11 @@ class PyGameView: """Raw PGN headers as dict.""" ... + @property + def outcome(self) -> Optional[str]: + """Game outcome from movetext: 'White', 'Black', 'Draw', 'Unknown', or None.""" + ... + @property def is_checkmate(self) -> bool: """Final position is checkmate.""" @@ -160,6 +123,22 @@ class PyGameView: """Whether game parsed successfully.""" ... + @property + def is_game_over(self) -> bool: + """Whether the game is over (checkmate, stalemate, or both sides insufficient).""" + ... + + @property + def comments(self) -> List[Optional[str]]: + """Raw text comments per move (only populated when store_comments=True).""" + ... + + @property + def legal_moves(self) -> List[List[Tuple[int, int, int]]]: + """Legal moves at each position (only populated when store_legal_moves=True). + Each entry is a list of (from_square, to_square, promotion) tuples.""" + ... + # === Convenience methods === def move_uci(self, move_idx: int) -> str: @@ -232,6 +211,18 @@ class PyChunkView: def valid(self) -> NDArray[np.bool_]: ... @property def headers(self) -> List[Dict[str, str]]: ... + @property + def outcome(self) -> List[Optional[str]]: ... + @property + def comments(self) -> List[Optional[str]]: ... + @property + def legal_move_from_squares(self) -> NDArray[np.uint8]: ... + @property + def legal_move_to_squares(self) -> NDArray[np.uint8]: ... + @property + def legal_move_promotions(self) -> NDArray[np.int8]: ... + @property + def legal_move_offsets(self) -> NDArray[np.uint32]: ... def __repr__(self) -> str: ... class ParsedGames: @@ -327,19 +318,32 @@ class ParsedGames: """ ... -def parse_game(pgn: str, store_legal_moves: bool = False) -> MoveExtractor: ... -def parse_games( - pgns: List[str], num_threads: Optional[int] = None, store_legal_moves: bool = False -) -> List[MoveExtractor]: ... -def parse_game_moves_arrow_chunked_array( - pgn_chunked_array: pyarrow.ChunkedArray, - num_threads: Optional[int] = None, +def parse_game( + pgn: str, + store_comments: bool = False, store_legal_moves: bool = False, -) -> List[MoveExtractor]: ... -def parse_games_flat( +) -> ParsedGames: + """Parse a single PGN game string. + + Convenience wrapper for parsing a single game. Returns a ParsedGames + container with one game. + + Args: + pgn: PGN game string + store_comments: Whether to store raw text comments (default: False) + store_legal_moves: Whether to store legal moves at each position (default: False) + + Returns: + ParsedGames object containing the parsed game + """ + ... + +def parse_games( pgn_chunked_array: pyarrow.ChunkedArray, num_threads: Optional[int] = None, chunk_multiplier: Optional[int] = None, + store_comments: bool = False, + store_legal_moves: bool = False, ) -> ParsedGames: """Parse chess games from a PyArrow ChunkedArray into flat NumPy arrays. @@ -349,6 +353,30 @@ def parse_games_flat( Args: pgn_chunked_array: PyArrow ChunkedArray containing PGN strings num_threads: Number of threads for parallel parsing (default: all CPUs) + chunk_multiplier: Multiplier for number of chunks (default: 1) + store_comments: Whether to store raw text comments (default: False) + store_legal_moves: Whether to store legal moves at each position (default: False) + + Returns: + ParsedGames object containing flat arrays and iteration support + """ + ... + +def parse_games_from_strings( + pgns: List[str], + num_threads: Optional[int] = None, + store_comments: bool = False, + store_legal_moves: bool = False, +) -> ParsedGames: + """Parse multiple PGN game strings in parallel. + + Convenience wrapper for when you have a list of strings rather than an Arrow array. + + Args: + pgns: List of PGN game strings + num_threads: Number of threads for parallel parsing (default: all CPUs) + store_comments: Whether to store raw text comments (default: False) + store_legal_moves: Whether to store legal moves at each position (default: False) Returns: ParsedGames object containing flat arrays and iteration support diff --git a/src/bench_parquet.py b/src/bench_parquet.py index f1b01d1..408ebba 100644 --- a/src/bench_parquet.py +++ b/src/bench_parquet.py @@ -1,21 +1,19 @@ import rust_pgn_reader_python_binding import pyarrow.parquet as pq - from datetime import datetime file_path = "2013-07-train-00000-of-00001.parquet" - pf = pq.ParquetFile(file_path) pylist = pf.read(columns=["movetext"]).column("movetext").to_pylist() a = datetime.now() for row in pylist: - extractor = rust_pgn_reader_python_binding.parse_game(row) - moves = extractor.moves - comments = extractor.comments + result = rust_pgn_reader_python_binding.parse_game(row) + game = result[0] + moves = game.moves_uci() b = datetime.now() print(b - a) diff --git a/src/bench_parquet_arrow.py b/src/bench_parquet_arrow.py index 59a8caf..dff4c94 100644 --- a/src/bench_parquet_arrow.py +++ b/src/bench_parquet_arrow.py @@ -1,20 +1,16 @@ import rust_pgn_reader_python_binding import pyarrow.parquet as pq - from datetime import datetime -file_path = "2013-07-train-00000-of-00001.parquet" +file_path = "2013-07-train-00000-of-00001.parquet" pf = pq.ParquetFile(file_path) - movetext_arrow_array = pf.read(columns=["movetext"]).column("movetext") a = datetime.now() -extractors = rust_pgn_reader_python_binding.parse_game_moves_arrow_chunked_array( - movetext_arrow_array -) +result = rust_pgn_reader_python_binding.parse_games(movetext_arrow_array) b = datetime.now() print(b - a) diff --git a/src/bench_parquet_parallel.py b/src/bench_parquet_parallel.py index 18e504a..f4b8949 100644 --- a/src/bench_parquet_parallel.py +++ b/src/bench_parquet_parallel.py @@ -1,18 +1,16 @@ import rust_pgn_reader_python_binding import pyarrow.parquet as pq - from datetime import datetime file_path = "2013-07-train-00000-of-00001.parquet" - pf = pq.ParquetFile(file_path) pylist = pf.read(columns=["movetext"]).column("movetext").to_pylist() a = datetime.now() -extractors = rust_pgn_reader_python_binding.parse_games(pylist) +result = rust_pgn_reader_python_binding.parse_games_from_strings(pylist) b = datetime.now() print(b - a) diff --git a/src/bench_parse_games_flat.py b/src/bench_parse_games.py similarity index 64% rename from src/bench_parse_games_flat.py rename to src/bench_parse_games.py index 65bf7ab..eecaa5e 100644 --- a/src/bench_parse_games_flat.py +++ b/src/bench_parse_games.py @@ -1,5 +1,5 @@ """ -Benchmark comparing parse_games_flat() vs parse_game_moves_arrow_chunked_array(). +Benchmark for parse_games() PGN parsing. This benchmark measures: 1. Parsing speed (games/second) @@ -7,7 +7,7 @@ 3. Data access patterns for ML workloads Usage: - python bench_parse_games_flat.py [parquet_file] + python bench_parse_games.py [parquet_file] If no parquet file is provided, synthetic PGN data will be generated. """ @@ -25,7 +25,6 @@ def generate_synthetic_pgns(num_games: int, moves_per_game: int = 40) -> list[str]: """Generate synthetic PGN games for benchmarking.""" - # A realistic game template move_pairs = [ ("e4", "e5"), ("Nf3", "Nc6"), @@ -51,7 +50,6 @@ def generate_synthetic_pgns(num_games: int, moves_per_game: int = 40) -> list[st pgns = [] for i in range(num_games): - # Build movetext moves = [] num_pairs = min(moves_per_game // 2, len(move_pairs)) for j in range(num_pairs): @@ -78,7 +76,6 @@ def load_parquet_pgns(file_path: str, limit: Optional[int] = None) -> pa.Chunked pf = pq.ParquetFile(file_path) - # Try common column names for PGN/movetext table = pf.read() for col_name in ["movetext", "pgn", "moves", "game"]: if col_name in table.column_names: @@ -92,21 +89,21 @@ def load_parquet_pgns(file_path: str, limit: Optional[int] = None) -> pa.Chunked ) -def benchmark_parse_games_flat( +def benchmark_parse_games( chunked_array: pa.ChunkedArray, num_threads: Optional[int] = None, warmup: int = 1 ) -> dict: - """Benchmark parse_games_flat().""" + """Benchmark parse_games().""" # Warmup for _ in range(warmup): - _ = pgn.parse_games_flat(chunked_array, num_threads=num_threads) + _ = pgn.parse_games(chunked_array, num_threads=num_threads) # Timed run start = time.perf_counter() - result = pgn.parse_games_flat(chunked_array, num_threads=num_threads) + result = pgn.parse_games(chunked_array, num_threads=num_threads) elapsed = time.perf_counter() - start return { - "method": "parse_games_flat", + "method": "parse_games", "elapsed_seconds": elapsed, "num_games": result.num_games, "num_moves": result.num_moves, @@ -119,44 +116,8 @@ def benchmark_parse_games_flat( } -def benchmark_parse_arrow_chunked( - chunked_array: pa.ChunkedArray, num_threads: Optional[int] = None, warmup: int = 1 -) -> dict: - """Benchmark parse_game_moves_arrow_chunked_array().""" - # Warmup - for _ in range(warmup): - _ = pgn.parse_game_moves_arrow_chunked_array( - chunked_array, num_threads=num_threads - ) - - # Timed run - start = time.perf_counter() - extractors = pgn.parse_game_moves_arrow_chunked_array( - chunked_array, num_threads=num_threads - ) - elapsed = time.perf_counter() - start - - num_games = len(extractors) - num_moves = sum(len(e.moves) for e in extractors) - num_positions = num_moves + num_games # Approximate - valid_games = sum(1 for e in extractors if e.valid_moves) - - return { - "method": "parse_arrow_chunked", - "elapsed_seconds": elapsed, - "num_games": num_games, - "num_moves": num_moves, - "num_positions": num_positions, - "games_per_second": num_games / elapsed, - "moves_per_second": num_moves / elapsed, - "positions_per_second": num_positions / elapsed, - "valid_games": valid_games, - "result": extractors, - } - - -def benchmark_data_access_flat(result) -> dict: - """Benchmark data access patterns for parse_games_flat result.""" +def benchmark_data_access(result) -> dict: + """Benchmark data access patterns for parse_games result.""" start = time.perf_counter() # Simulate ML data loading: access all boards via chunks @@ -172,25 +133,10 @@ def benchmark_data_access_flat(result) -> dict: return {"access_time": elapsed} -def benchmark_data_access_extractors(extractors: list) -> dict: - """Benchmark data access patterns for list of MoveExtractors.""" - start = time.perf_counter() - - # Simulate accessing all moves (requires iteration) - total = 0 - for e in extractors: - for m in e.moves: - total += m.from_square + m.to_square - - elapsed = time.perf_counter() - start - return {"access_time": elapsed} - - -def benchmark_per_game_access_flat(result) -> dict: - """Benchmark per-game access pattern for parse_games_flat result.""" +def benchmark_per_game_access(result) -> dict: + """Benchmark per-game access pattern.""" start = time.perf_counter() - # Per-game access pattern (iterate and access boards) for i in range(min(1000, result.num_games)): game = result[i] _ = game.boards @@ -203,10 +149,9 @@ def benchmark_per_game_access_flat(result) -> dict: def benchmark_position_to_game_mapping(result) -> dict: - """Benchmark position-to-game mapping for parse_games_flat result.""" + """Benchmark position-to-game mapping.""" start = time.perf_counter() - # Position-to-game mapping (still works globally) indices = np.random.randint(0, result.num_positions, size=1000, dtype=np.int64) _ = result.position_to_game(indices) @@ -242,9 +187,7 @@ def print_results(results: dict, label: str): def main(): - parser = argparse.ArgumentParser( - description="Benchmark parse_games_flat() vs parse_game_moves_arrow_chunked_array()" - ) + parser = argparse.ArgumentParser(description="Benchmark parse_games()") parser.add_argument( "parquet_file", nargs="?", @@ -282,7 +225,7 @@ def main(): args = parser.parse_args() print("=" * 60) - print("parse_games_flat() Benchmark") + print("parse_games() Benchmark") print("=" * 60) # Load or generate data @@ -303,30 +246,15 @@ def main(): print(f"Threads: {args.threads or 'all cores'}") print(f"Warmup iterations: {args.warmup}") - # Benchmark parse_games_flat + # Benchmark print("\n" + "=" * 60) - print("PARSING BENCHMARKS") + print("PARSING BENCHMARK") print("=" * 60) - flat_results = benchmark_parse_games_flat( + results = benchmark_parse_games( chunked_array, num_threads=args.threads, warmup=args.warmup ) - print_results(flat_results, "parse_games_flat()") - - # Benchmark parse_game_moves_arrow_chunked_array - arrow_results = benchmark_parse_arrow_chunked( - chunked_array, num_threads=args.threads, warmup=args.warmup - ) - print_results(arrow_results, "parse_game_moves_arrow_chunked_array()") - - # Comparison - print("\n" + "=" * 60) - print("COMPARISON") - print("=" * 60) - speedup = arrow_results["elapsed_seconds"] / flat_results["elapsed_seconds"] - print( - f"\nparse_games_flat() is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'}" - ) + print_results(results, "parse_games()") # Data access benchmarks if not args.skip_access: @@ -334,34 +262,23 @@ def main(): print("DATA ACCESS BENCHMARKS") print("=" * 60) - flat_access = benchmark_data_access_flat(flat_results["result"]) - print(f"\nFlat arrays access time: {flat_access['access_time']:.3f}s") - - extractor_access = benchmark_data_access_extractors(arrow_results["result"]) - print(f"Extractor list access time: {extractor_access['access_time']:.3f}s") - - access_speedup = extractor_access["access_time"] / flat_access["access_time"] - print(f"\nFlat arrays are {access_speedup:.2f}x faster for data access") + data_access = benchmark_data_access(results["result"]) + print(f"\nArray access time: {data_access['access_time']:.3f}s") - # Per-game access benchmarks - print("\n" + "=" * 60) - print("PER-GAME ACCESS BENCHMARKS") - print("=" * 60) - - per_game_access = benchmark_per_game_access_flat(flat_results["result"]) + per_game_access = benchmark_per_game_access(results["result"]) print(f"\nPer-game access time: {per_game_access['access_time']:.3f}s") print(f"Games accessed: {per_game_access['games_accessed']:,}") print( f"Time per game: {per_game_access['access_time'] / per_game_access['games_accessed'] * 1000:.3f}ms" ) - position_mapping = benchmark_position_to_game_mapping(flat_results["result"]) + position_mapping = benchmark_position_to_game_mapping(results["result"]) print(f"\nPosition-to-game mapping: {position_mapping['access_time']:.3f}s") print( f"Positions mapped: {position_mapping['positions_accessed']:,}" ) print( - f"Time per lookup: {position_mapping['access_time'] / position_mapping['positions_accessed'] * 1000000:.3f}μs" + f"Time per lookup: {position_mapping['access_time'] / position_mapping['positions_accessed'] * 1000000:.3f}us" ) # Memory usage (approximate) @@ -369,10 +286,10 @@ def main(): print("MEMORY USAGE (approximate)") print("=" * 60) - flat_result = flat_results["result"] - flat_bytes = 0 - for chunk in flat_result.chunks: - flat_bytes += ( + result = results["result"] + total_bytes = 0 + for chunk in result.chunks: + total_bytes += ( chunk.boards.nbytes + chunk.castling.nbytes + chunk.en_passant.nbytes @@ -387,10 +304,10 @@ def main(): + chunk.position_offsets.nbytes ) - print(f"\nFlat arrays total: {flat_bytes / 1024 / 1024:.2f} MB") - print(f"Bytes per position: {flat_bytes / flat_result.num_positions:.1f}") - print(f"Bytes per move: {flat_bytes / flat_result.num_moves:.1f}") - print(f"Number of chunks: {flat_result.num_chunks}") + print(f"\nArrays total: {total_bytes / 1024 / 1024:.2f} MB") + print(f"Bytes per position: {total_bytes / result.num_positions:.1f}") + print(f"Bytes per move: {total_bytes / result.num_moves:.1f}") + print(f"Number of chunks: {result.num_chunks}") print("\n" + "=" * 60) print("Benchmark complete!") diff --git a/src/bench.py b/src/bench_pgn_file.py similarity index 61% rename from src/bench.py rename to src/bench_pgn_file.py index 28978b3..e5b91f7 100644 --- a/src/bench.py +++ b/src/bench_pgn_file.py @@ -1,5 +1,4 @@ import rust_pgn_reader_python_binding - from datetime import datetime @@ -8,12 +7,12 @@ def split_pgn(file_path): with open(file_path, "r", encoding="utf-8") as file: game_lines = [] for line in file: - if line.strip() == "" and game_lines: # End of a game + if line.strip() == "" and game_lines: yield "".join(game_lines) - game_lines = [] # Reset for the next game + game_lines = [] else: game_lines.append(line) - if game_lines: # Yield the last game if the file doesn't end with a blank line + if game_lines: yield "".join(game_lines) @@ -21,9 +20,9 @@ def split_pgn(file_path): a = datetime.now() for game_pgn in split_pgn(file_path): - extractor = rust_pgn_reader_python_binding.parse_game(game_pgn) - moves = extractor.moves - comments = extractor.comments + result = rust_pgn_reader_python_binding.parse_game(game_pgn) + game = result[0] + moves = game.moves_uci() b = datetime.now() print(b - a) diff --git a/src/example_chess960.py b/src/example_chess960.py index 5105556..c7c0290 100644 --- a/src/example_chess960.py +++ b/src/example_chess960.py @@ -14,15 +14,11 @@ 1.g3 d5 2.d4 g6 3.b3 Nf6 4.Ne3 b6 5.Nh3 Ne6 6.f4 Ng7 7.g4 h6 8.Nf2 Ne6 9.f5 Nf4 10.Nf1 gxf5 11.Qd2 Ne6 12.gxf5 Ng5 13.Ng3 Qd7 14.Rg1 Rg8 15.O-O-O O-O-O 16.Kb1 Kb8 17.Bb2 h5 18.Qf4 Bb7 19.h4 Nge4 20.Nfxe4 dxe4 21.Nxe4 Nxe4 22.Bxe4 Bxe4 23.Qxe4 e6 24.Rxg8 Rxg8 25.fxe6 fxe6 26.Rf1 Qe7 27.e3 Bf6 28.Ba3 Qd8 29.Qxe6 Bxh4 30.Rf5 Bg5 31.d5 Qc8 32.e4 h4 33.Qf7 Be3 34.d6 Rg1+ 35.Rf1 Rxf1+ 36.Qxf1 h3 37.dxc7+ Kxc7 38.Qc4+ Kb7 39.Qf7+ Qc7 40.Qd5+ Qc6 41.Qf7+ Qc7 42.Qd5+ Qc6 43.Qf7+ Qc7 1/2-1/2 {OL: 0} """ -extractor = rust_pgn_reader_python_binding.parse_game(pgn_moves) +result = rust_pgn_reader_python_binding.parse_game(pgn_moves) +game = result[0] -print("moves", extractor.moves) -print("comments", extractor.comments) -print("valid", extractor.valid_moves) -# print(extractor.evals) -print("clock", extractor.clock_times) -# print(extractor.outcome) -# print(extractor.position_status.is_checkmate) -# print(extractor.position_status.is_stalemate) -# print(extractor.position_status.is_game_over) -# print(extractor.position_status.legal_move_count) +print("Moves (UCI):", game.moves_uci()) +print("Valid:", game.is_valid) +print("Outcome:", game.outcome) +print("Num positions:", game.num_positions) +print("Clocks (first 5):", game.clocks[:5].tolist()) diff --git a/src/example_complex.py b/src/example_complex.py new file mode 100644 index 0000000..f56fb2a --- /dev/null +++ b/src/example_complex.py @@ -0,0 +1,41 @@ +import rust_pgn_reader_python_binding +import numpy as np + +pgn_moves = """ +[Event "Casual Correspondence game"] +[Site "https:///gdEj47Dv"] +[Date ""] +[White "lichess AI level 8"] +[Black "TherealARB"] +[Result "0-1"] +[UTCDate ""] +[UTCTime ":15"] +[WhiteElo "?"] +[BlackElo "1500"] +[Variant "Standard"] +[TimeControl "-"] +[ECO "C00"] +[Opening "Rat Defense: Small Center Defense"] +[Termination "Normal"] +[Annotator ""] + +1. e4 { [%eval ] } 1... e6 { [%eval ] } +2. d4 { [%eval ] } 2... d6?! { (0 → ) Inaccuracy. d5 was best. } { [%eval ] } { C00 Rat Defense: Small Center Defense } (2... d5 3. Nc3 Nf6 4. e5 Nfd7 5. f4 c5 6. Nf3 Nc6 7. Be3) +3. c4?! { ( → ) Inaccuracy. Bd3 was best. } { [%eval ] } (3. Bd3 c5 4. dxc5 dxc5 5. Nf3 Ne7 6. Qe2 Bd7 7. Nc3 Nec6) 3... h6? { ( → ) Mistake. d5 was best. } { [%eval ] } (3... d5 4. cxd5 exd5 5. exd5 Nf6 6. Nc3 Be7 7. Nf3 O-O 8. Bc4) +4. Nf3 { [%eval ] } 4... a6 { [%eval ] } +5. Nc3 { [%eval 1] } 5... g6 { [%eval ] } 6. Be3 { [%eval ] } 6... b6 { [%eval ] } 7. Bd3 { [%eval ] } 7... Bg7 { [%eval ] } 8. Qd2 { [%eval ] } 8... Bb7 { [%eval 7] } 9. O-O { [%eval ] } 9... Ne7 { [%eval ] } 10. d5 { [%eval ] } 10... e5 { [%eval ] } 11. g3 { [%eval ] } 11... Nd7 { [%eval 4] } 12. a3 { [%eval ] } 12... g5 { [%eval ] } 13. h4 { [%eval 3] } 13... f6 { [%eval ] } 14. Rfc1 { [%eval ] } 14... Ng6 { [%eval ] } 15. h5 { [%eval ] } 15... Nf4 { [%eval ] } 16. gxf4 { [%eval ] } 16... gxf4 { [%eval ] } 17. Bxf4?! { ( → ) Inaccuracy. Kh1 was best. } { [%eval ] } (17. Kh1 fxe3 18. fxe3 f5 19. Rg1 Bf6 20. exf5 Bg5 21. Ne4 Nf6 22. Qg2 c6 23. Nxf6+ Qxf6) 17... exf4 { [%eval 1] } 18. Qxf4 { [%eval ] } 18... Ne5 { [%eval ] } 19. Be2?! { (0 → ) Inaccuracy. Rd1 was best. } { [%eval ] } (19. Rd1 Nxf3+ 20. Qxf3 O-O 21. Kf1 Bc8 22. Re1 Qe7 23. Re3 Qf7 24. Rae1 Bd7 25. Ne2 f5) 19... Nxf3+ { [%eval 1] } 20. Bxf3 { [%eval 5] } 20... Bc8 { [%eval ] } 21. Re1 { [%eval ] } 21... O-O { [%eval ] } 22. Kh2 { [%eval ] } 22... Rf7 { [%eval ] } 23. Rg1 { [%eval ] } 23... f5 { [%eval ] } 24. Rg6 { [%eval -3] } 24... Kh8 { [%eval ] } 25. Rxh6+ { [%eval ] } 25... Bxh6 { [%eval ] } 26. Qxh6+ { [%eval ] } 26... Rh7 { [%eval ] } 27. Qf4 { [%eval ] } 27... Qf6 { [%eval ] } 28. Re1 { [%eval -7] } 28... Bd7 { [%eval -] } 29. e5 { [%eval 8] } 29... Qg7 { [%eval -5] } 30. exd6 { [%eval ] } 30... Rg8 { [%eval ] } 31. Qg3 { [%eval -4] } 31... Qf6 { [%eval -] } 32. Qf4 { [%eval ] } 32... cxd6 { [%eval ] } 33. Ne2 { [%eval ] } 33... Qe5 { [%eval -5] } 34. b4?! { (-5 → ) Inaccuracy. Rg1 was best. } { [%eval ] } (34. Rg1 Qxf4+ 35. Nxf4 Rxg1 36. Kxg1 Kg7 37. Kf1 b5 38. cxb5 Bxb5+ 39. Ke1 Kf6 40. Kd2 a5) 34... Be8 { [%eval ] } 35. Qxe5+ { [%eval ] } 35... dxe5 { [%eval ] } 36. Ng3 { [%eval ] } 36... f4 { [%eval ] } 37. Rxe5 { [%eval ] } 37... fxg3+ { [%eval ] } 38. fxg3 { [%eval -] } 38... Rhg7 { [%eval ] } 39. g4 { [%eval ] } 39... Bd7 { [%eval ] } 40. h6 { [%eval ] } 40... Rg6 { [%eval ] } 41. g5?! ... +66. Kc2 { [%eval #-25] } 66... b4 { [%eval #-6] } +67. Kd2 { [%eval #-9] } 67... b3 { [%eval #-8] } 68. Bd3 { [%eval #-7] } 68... b2 { [%eval #-7] } 69. Ke3 { [%eval #-6] } 69... Ba2 { [%eval #-11] } 70. Kf2 { [%eval #-8] } 70... Bb1 { [%eval #-6] } 71. Bc4 { [%eval #-6] } 71... Bg6 { [%eval #-4] } 72. Bd3 { [%eval #-4] } 72... Bxd3 { [%eval #-3] } 73. Kf3 { [%eval #-3] } 73... b1=Q { [%eval #-2] } 74. Ke3 { [%eval #-2] } 74... Qf1 { [%eval #-1] } 75. Kd2 { [%eval #-1] } 75... Qae1# { Black wins by checkmate. } 0-1 +""" + +result = rust_pgn_reader_python_binding.parse_game(pgn_moves, store_comments=True) +game = result[0] + +print("Moves (UCI):", game.moves_uci()) +print("Comments (first 5):", game.comments[:5]) +print("Valid:", game.is_valid) +print("Evals (first 10):", game.evals[:10].tolist()) +print("Clocks (first 10):", game.clocks[:10].tolist()) +print("Outcome:", game.outcome) +print("Is checkmate:", game.is_checkmate) +print("Is game over:", game.is_game_over) diff --git a/src/example_parse_games_flat.py b/src/example_parse_games.py similarity index 95% rename from src/example_parse_games_flat.py rename to src/example_parse_games.py index f0848b1..7981306 100644 --- a/src/example_parse_games_flat.py +++ b/src/example_parse_games.py @@ -1,5 +1,5 @@ """ -Example demonstrating the parse_games_flat() API for ML-optimized PGN parsing. +Example demonstrating the parse_games() API for ML-optimized PGN parsing. This API returns flat NumPy arrays suitable for efficient batch processing in machine learning pipelines. @@ -52,14 +52,14 @@ def main(): print("=" * 60) - print("parse_games_flat() API Example") + print("parse_games() API Example") print("=" * 60) # Create PyArrow chunked array from PGN strings chunked_array = pa.chunked_array([pa.array(sample_pgns)]) # Parse all games at once - returns flat NumPy arrays - result = pgn.parse_games_flat(chunked_array) + result = pgn.parse_games(chunked_array) # === Basic Statistics === print(f"\n--- Basic Statistics ---") @@ -92,6 +92,7 @@ def main(): print(f" Positions: {game.num_positions}") print(f" Valid: {game.is_valid}") print(f" Checkmate: {game.is_checkmate}") + print(f" Outcome: {game.outcome}") print( f" UCI moves: {' '.join(game.moves_uci()[:10])}{'...' if len(game) > 10 else ''}" ) @@ -157,6 +158,7 @@ def main(): if game.is_checkmate: print(f"Game {i + 1} ended in checkmate!") print(f" Final position legal moves: {game.legal_move_count}") + print(f" Is game over: {game.is_game_over}") print("\n" + "=" * 60) print("Example complete!") diff --git a/src/example_rowgroup.py b/src/example_rowgroup.py index 9dddc4b..6acddaf 100644 --- a/src/example_rowgroup.py +++ b/src/example_rowgroup.py @@ -12,10 +12,11 @@ pf = pq.ParquetFile(file_path) for i in range(pf.num_row_groups): - table = pf.read_row_group(0, columns=["movetext"]) - extractors = rust_pgn_reader_python_binding.parse_games( - table.column("movetext").to_pylist(), num_threads=4 + table = pf.read_row_group(i, columns=["movetext"]) + result = rust_pgn_reader_python_binding.parse_games( + table.column("movetext"), num_threads=4 ) + print(f"Row group {i}: {result.num_games} games, {result.num_moves} moves") b = datetime.now() print(b - a) diff --git a/src/example_str.py b/src/example_str.py index 58f4ffd..a41bbd9 100644 --- a/src/example_str.py +++ b/src/example_str.py @@ -1,4 +1,5 @@ import rust_pgn_reader_python_binding +import numpy as np pgn_moves = """ 1. e4 { [%eval 0.17] [%clk 0:00:30] } 1... c5 { [%eval 0.19] [%clk 0:00:30] } @@ -16,16 +17,16 @@ 13. b3?? { [%eval -4.14] [%clk 0:00:02] } 13... Nf4? { [%eval -2.73] [%clk 0:00:21] } 0-1 """ -extractor = rust_pgn_reader_python_binding.parse_game(pgn_moves) +result = rust_pgn_reader_python_binding.parse_game(pgn_moves, store_comments=True) +game = result[0] -print(extractor.moves) -print(extractor.comments) -print(extractor.valid_moves) -print(extractor.evals) -print(extractor.clock_times) -print(extractor.outcome) -assert extractor.position_status is not None -print(extractor.position_status.is_checkmate) -print(extractor.position_status.is_stalemate) -print(extractor.position_status.is_game_over) -print(extractor.position_status.legal_move_count) +print("Moves (UCI):", game.moves_uci()) +print("Comments:", game.comments) +print("Valid:", game.is_valid) +print("Evals:", game.evals.tolist()) +print("Clocks:", game.clocks.tolist()) +print("Outcome:", game.outcome) +print("Is checkmate:", game.is_checkmate) +print("Is stalemate:", game.is_stalemate) +print("Is game over:", game.is_game_over) +print("Legal move count:", game.legal_move_count) diff --git a/src/flat_visitor.rs b/src/flat_visitor.rs index 00f7026..823b352 100644 --- a/src/flat_visitor.rs +++ b/src/flat_visitor.rs @@ -1,26 +1,33 @@ -//! Flat buffer visitor for direct SoA output. +//! SoA (Struct-of-Arrays) visitor for PGN parsing. //! //! This module provides a memory-efficient parsing approach that writes -//! directly to flat buffers instead of allocating per-game Vec structures. -//! Used by `parse_games_flat` for optimal performance. +//! directly to shared buffers instead of allocating per-game Vec structures. +//! Used by `parse_games` for optimal performance. use crate::board_serialization::{ get_castling_rights, get_en_passant_file, get_halfmove_clock, get_turn, serialize_board, }; use crate::comment_parsing::{parse_comments, CommentContent, ParsedTag}; -use pgn_reader::{Outcome, RawComment, RawTag, SanPlus, Skip, Visitor}; +use pgn_reader::{KnownOutcome, Outcome, RawComment, RawTag, SanPlus, Skip, Visitor}; use shakmaty::{fen::Fen, uci::UciMove, CastlingMode, Chess, Color, Position}; use std::collections::HashMap; use std::ops::ControlFlow; -/// Accumulated flat buffers for multiple parsed games. +/// Configuration for what optional data to store during parsing. +#[derive(Clone, Debug)] +pub struct ParseConfig { + pub store_comments: bool, + pub store_legal_moves: bool, +} + +/// Accumulated buffers for multiple parsed games. /// /// This struct holds all data in a struct-of-arrays layout, optimized for: /// - Efficient thread-local accumulation during parallel parsing /// - Fast merging of thread-local buffers via `extend_from_slice` /// - Direct conversion to NumPy arrays without intermediate allocations #[derive(Default, Clone)] -pub struct FlatBuffers { +pub struct Buffers { // Board state arrays (one entry per position) pub boards: Vec, // Flattened: 64 bytes per position pub castling: Vec, // Flattened: 4 bools per position [K,Q,k,q] @@ -44,19 +51,35 @@ pub struct FlatBuffers { pub legal_move_count: Vec, pub valid: Vec, pub headers: Vec>, + pub outcome: Vec>, // "White", "Black", "Draw", "Unknown", or None + + // Optional: raw text comments (per-move), only populated when store_comments=true + pub comments: Vec>, + + // Optional: legal moves at each position, only populated when store_legal_moves=true + // Stored as flat arrays with CSR-style offsets + pub legal_move_from_squares: Vec, + pub legal_move_to_squares: Vec, + pub legal_move_promotions: Vec, + pub legal_move_counts: Vec, // Number of legal moves per position } -impl FlatBuffers { - /// Create a new FlatBuffers with pre-allocated capacity. +impl Buffers { + /// Create a new Buffers with pre-allocated capacity. /// /// # Arguments /// * `estimated_games` - Expected number of games /// * `moves_per_game` - Expected average moves per game (default: 70) - pub fn with_capacity(estimated_games: usize, moves_per_game: usize) -> Self { + /// * `config` - Configuration for optional features + pub fn with_capacity( + estimated_games: usize, + moves_per_game: usize, + config: &ParseConfig, + ) -> Self { let estimated_moves = estimated_games * moves_per_game; let estimated_positions = estimated_moves + estimated_games; // +1 initial position per game - FlatBuffers { + Buffers { // Board state arrays boards: Vec::with_capacity(estimated_positions * 64), castling: Vec::with_capacity(estimated_positions * 4), @@ -80,6 +103,36 @@ impl FlatBuffers { legal_move_count: Vec::with_capacity(estimated_games), valid: Vec::with_capacity(estimated_games), headers: Vec::with_capacity(estimated_games), + outcome: Vec::with_capacity(estimated_games), + + // Optional comments + comments: if config.store_comments { + Vec::with_capacity(estimated_moves) + } else { + Vec::new() + }, + + // Optional legal moves + legal_move_from_squares: if config.store_legal_moves { + Vec::with_capacity(estimated_positions * 30) + } else { + Vec::new() + }, + legal_move_to_squares: if config.store_legal_moves { + Vec::with_capacity(estimated_positions * 30) + } else { + Vec::new() + }, + legal_move_promotions: if config.store_legal_moves { + Vec::with_capacity(estimated_positions * 30) + } else { + Vec::new() + }, + legal_move_counts: if config.store_legal_moves { + Vec::with_capacity(estimated_positions) + } else { + Vec::new() + }, } } @@ -118,35 +171,54 @@ impl FlatBuffers { } offsets } + + /// Compute CSR-style offsets from legal move counts (per position). + pub fn compute_legal_move_offsets(&self) -> Vec { + let mut offsets = Vec::with_capacity(self.legal_move_counts.len() + 1); + offsets.push(0); + for &count in &self.legal_move_counts { + offsets.push(offsets.last().unwrap() + count); + } + offsets + } + + /// Total number of legal moves stored across all positions. + pub fn total_legal_moves(&self) -> usize { + self.legal_move_from_squares.len() + } } -/// Visitor that writes directly to FlatBuffers. +/// Visitor that writes directly to shared Buffers. /// /// This visitor does not allocate any per-game Vec structures. -/// All data is appended directly to the shared FlatBuffers. -pub struct FlatVisitor<'a> { - buffers: &'a mut FlatBuffers, +/// All data is appended directly to the shared Buffers. +pub struct GameVisitor<'a> { + buffers: &'a mut Buffers, + config: ParseConfig, pos: Chess, valid_moves: bool, current_headers: Vec<(String, String)>, + current_outcome: Option, // Track counts for current game current_move_count: u32, current_position_count: u32, } -impl<'a> FlatVisitor<'a> { - pub fn new(buffers: &'a mut FlatBuffers) -> Self { - FlatVisitor { +impl<'a> GameVisitor<'a> { + pub fn new(buffers: &'a mut Buffers, config: &ParseConfig) -> Self { + GameVisitor { buffers, + config: config.clone(), pos: Chess::default(), valid_moves: true, current_headers: Vec::with_capacity(10), + current_outcome: None, current_move_count: 0, current_position_count: 0, } } - /// Record current board state to flat buffers. + /// Record current board state to buffers. fn push_board_state(&mut self) { self.buffers .boards @@ -159,9 +231,37 @@ impl<'a> FlatVisitor<'a> { .push(get_halfmove_clock(&self.pos)); self.buffers.turn.push(get_turn(&self.pos)); self.current_position_count += 1; + + // Store legal moves if enabled + if self.config.store_legal_moves { + self.push_legal_moves(); + } } - /// Record move data to flat buffers. + /// Record legal moves at current position to buffers. + fn push_legal_moves(&mut self) { + let legal_moves = self.pos.legal_moves(); + let mut count: u32 = 0; + for m in legal_moves { + let uci_move_obj = UciMove::from_standard(m); + if let UciMove::Normal { + from, + to, + promotion, + } = uci_move_obj + { + self.buffers.legal_move_from_squares.push(from as u8); + self.buffers.legal_move_to_squares.push(to as u8); + self.buffers + .legal_move_promotions + .push(promotion.map(|p| p as i8).unwrap_or(-1)); + count += 1; + } + } + self.buffers.legal_move_counts.push(count); + } + + /// Record move data to buffers. fn push_move(&mut self, from: u8, to: u8, promotion: Option) { self.buffers.from_squares.push(from); self.buffers.to_squares.push(to); @@ -171,6 +271,10 @@ impl<'a> FlatVisitor<'a> { // Push placeholders for clock and eval (will be overwritten by comment()) self.buffers.clocks.push(f32::NAN); self.buffers.evals.push(f32::NAN); + // Push comment placeholder if enabled (will be overwritten by comment()) + if self.config.store_comments { + self.buffers.comments.push(None); + } self.current_move_count += 1; } @@ -196,6 +300,7 @@ impl<'a> FlatVisitor<'a> { .position_counts .push(self.current_position_count); self.buffers.valid.push(self.valid_moves); + self.buffers.outcome.push(self.current_outcome.take()); // Convert headers to HashMap let header_map: HashMap = self.current_headers.drain(..).collect(); @@ -203,7 +308,7 @@ impl<'a> FlatVisitor<'a> { } } -impl Visitor for FlatVisitor<'_> { +impl Visitor for GameVisitor<'_> { type Tags = Vec<(String, String)>; type Movetext = (); type Output = bool; @@ -228,6 +333,7 @@ impl Visitor for FlatVisitor<'_> { fn begin_movetext(&mut self, tags: Self::Tags) -> ControlFlow { self.current_headers = tags; self.valid_moves = true; + self.current_outcome = None; self.current_move_count = 0; self.current_position_count = 0; @@ -324,6 +430,12 @@ impl Visitor for FlatVisitor<'_> { comment: RawComment<'_>, ) -> ControlFlow { if let Ok((_, parsed_comments)) = parse_comments(comment.as_bytes()) { + let mut move_comments = if self.config.store_comments { + Some(String::new()) + } else { + None + }; + for content in parsed_comments { match content { CommentContent::Tag(tag_content) => match tag_content { @@ -344,15 +456,35 @@ impl Visitor for FlatVisitor<'_> { hours as f32 * 3600.0 + minutes as f32 * 60.0 + seconds as f32; } } - ParsedTag::Mate(_) => { - // Mate scores are handled as comments, not numeric evals + ParsedTag::Mate(mate_value) => { + // Mate scores stored as text in comments (matching old API behavior) + if let Some(ref mut comments) = move_comments { + if !comments.is_empty() && !comments.ends_with(' ') { + comments.push(' '); + } + comments.push_str(&format!("[Mate {}]", mate_value)); + } } }, - CommentContent::Text(_) => { - // Text comments are not stored in flat output + CommentContent::Text(text) => { + if let Some(ref mut comments) = move_comments { + if !text.trim().is_empty() { + if !comments.is_empty() { + comments.push(' '); + } + comments.push_str(&text); + } + } } } } + + // Update the last comment entry if comments are enabled + if let Some(comment_text) = move_comments { + if let Some(last_comment) = self.buffers.comments.last_mut() { + *last_comment = Some(comment_text); + } + } } ControlFlow::Continue(()) } @@ -369,6 +501,13 @@ impl Visitor for FlatVisitor<'_> { _movetext: &mut Self::Movetext, _outcome: Outcome, ) -> ControlFlow { + self.current_outcome = Some(match _outcome { + Outcome::Known(known) => match known { + KnownOutcome::Decisive { winner } => format!("{:?}", winner), + KnownOutcome::Draw => "Draw".to_string(), + }, + Outcome::Unknown => "Unknown".to_string(), + }); self.update_position_status(); ControlFlow::Continue(()) } @@ -383,13 +522,17 @@ impl Visitor for FlatVisitor<'_> { } } -/// Parse a single game directly into FlatBuffers. -pub fn parse_game_to_flat(pgn: &str, buffers: &mut FlatBuffers) -> Result { +/// Parse a single game directly into Buffers. +pub fn parse_game_to_buffers( + pgn: &str, + buffers: &mut Buffers, + config: &ParseConfig, +) -> Result { use pgn_reader::Reader; use std::io::Cursor; let mut reader = Reader::new(Cursor::new(pgn)); - let mut visitor = FlatVisitor::new(buffers); + let mut visitor = GameVisitor::new(buffers, config); match reader.read_game(&mut visitor) { Ok(Some(valid)) => Ok(valid), @@ -402,6 +545,13 @@ pub fn parse_game_to_flat(pgn: &str, buffers: &mut FlatBuffers) -> Result ParseConfig { + ParseConfig { + store_comments: false, + store_legal_moves: false, + } + } + #[test] fn test_parse_simple_game() { let pgn = r#"[Event "Test"] @@ -411,8 +561,9 @@ mod tests { 1. e4 e5 2. Nf3 Nc6 1-0"#; - let mut buffers = FlatBuffers::with_capacity(1, 70); - let result = parse_game_to_flat(pgn, &mut buffers); + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); assert!(result.is_ok()); assert!(result.unwrap()); // valid game @@ -421,6 +572,7 @@ mod tests { assert_eq!(buffers.position_counts[0], 5); // 5 positions (initial + 4 moves) assert_eq!(buffers.total_moves(), 4); assert_eq!(buffers.total_positions(), 5); + assert_eq!(buffers.outcome[0], Some("White".to_string())); } #[test] @@ -430,8 +582,9 @@ mod tests { 1. e4 { [%eval 0.17] [%clk 0:03:00] } 1... e5 { [%eval 0.19] [%clk 0:02:58] } 1-0"#; - let mut buffers = FlatBuffers::with_capacity(1, 70); - let result = parse_game_to_flat(pgn, &mut buffers); + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); assert!(result.is_ok()); assert_eq!(buffers.total_moves(), 2); @@ -459,19 +612,22 @@ mod tests { 1. d4 d5 2. c4 0-1"#; - let mut buffers = FlatBuffers::with_capacity(2, 70); + let config = default_config(); + let mut buffers = Buffers::with_capacity(2, 70, &config); - parse_game_to_flat(pgn1, &mut buffers).unwrap(); - parse_game_to_flat(pgn2, &mut buffers).unwrap(); + parse_game_to_buffers(pgn1, &mut buffers, &config).unwrap(); + parse_game_to_buffers(pgn2, &mut buffers, &config).unwrap(); assert_eq!(buffers.num_games(), 2); assert_eq!(buffers.total_moves(), 5); // 2 + 3 moves assert_eq!(buffers.move_counts, vec![2, 3]); + assert_eq!(buffers.outcome[0], Some("White".to_string())); + assert_eq!(buffers.outcome[1], Some("Black".to_string())); } #[test] fn test_compute_offsets() { - let mut buffers = FlatBuffers::default(); + let mut buffers = Buffers::default(); buffers.move_counts = vec![4, 6, 3]; buffers.position_counts = vec![5, 7, 4]; @@ -481,4 +637,89 @@ mod tests { let pos_offsets = buffers.compute_position_offsets(); assert_eq!(pos_offsets, vec![0, 5, 12, 16]); } + + #[test] + fn test_outcome_without_headers() { + // PGN without Result header - outcome comes from movetext + let pgn = "1. e4 e5 2. Nf3 Nc6 0-1"; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + parse_game_to_buffers(pgn, &mut buffers, &config).unwrap(); + + assert_eq!(buffers.outcome[0], Some("Black".to_string())); + } + + #[test] + fn test_outcome_draw() { + let pgn = "1. e4 e5 1/2-1/2"; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + parse_game_to_buffers(pgn, &mut buffers, &config).unwrap(); + + assert_eq!(buffers.outcome[0], Some("Draw".to_string())); + } + + #[test] + fn test_comments_disabled() { + let pgn = r#"1. e4 { a comment } e5 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + parse_game_to_buffers(pgn, &mut buffers, &config).unwrap(); + + assert!(buffers.comments.is_empty()); + } + + #[test] + fn test_comments_enabled() { + let pgn = r#"1. e4 { a comment } 1... e5 { [%eval 0.19] } 1-0"#; + + let config = ParseConfig { + store_comments: true, + store_legal_moves: false, + }; + let mut buffers = Buffers::with_capacity(1, 70, &config); + parse_game_to_buffers(pgn, &mut buffers, &config).unwrap(); + + assert_eq!(buffers.comments.len(), 2); + // Raw text from PGN includes surrounding spaces from the parser + assert_eq!(buffers.comments[0], Some(" a comment ".to_string())); + // The second comment only has an eval tag, so text portion is empty + assert_eq!(buffers.comments[1], Some("".to_string())); + } + + #[test] + fn test_legal_moves_disabled() { + let pgn = "1. e4 e5 1-0"; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + parse_game_to_buffers(pgn, &mut buffers, &config).unwrap(); + + assert!(buffers.legal_move_from_squares.is_empty()); + assert!(buffers.legal_move_counts.is_empty()); + } + + #[test] + fn test_legal_moves_enabled() { + let pgn = "1. e4 1-0"; + + let config = ParseConfig { + store_comments: false, + store_legal_moves: true, + }; + let mut buffers = Buffers::with_capacity(1, 70, &config); + parse_game_to_buffers(pgn, &mut buffers, &config).unwrap(); + + // 2 positions: initial + after e4 + assert_eq!(buffers.legal_move_counts.len(), 2); + // Initial position has 20 legal moves + assert_eq!(buffers.legal_move_counts[0], 20); + // After e4, black has 20 legal moves + assert_eq!(buffers.legal_move_counts[1], 20); + // Total legal moves stored + assert_eq!(buffers.legal_move_from_squares.len(), 40); + } } diff --git a/src/lib.rs b/src/lib.rs index 48961a9..97d328d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,24 +1,17 @@ use arrow_array::{Array, LargeStringArray, StringArray}; use numpy::{PyArray1, PyArrayMethods}; -use pgn_reader::Reader; use pyo3::prelude::*; use pyo3_arrow::PyChunkedArray; -use rayon::prelude::*; use rayon::ThreadPoolBuilder; - -use std::io::Cursor; +use rayon::prelude::*; mod board_serialization; mod comment_parsing; mod flat_visitor; mod python_bindings; -mod visitor; -pub use flat_visitor::{parse_game_to_flat, FlatBuffers}; -use python_bindings::{ - ChunkData, ParsedGames, ParsedGamesIter, PositionStatus, PyChunkView, PyGameView, PyUciMove, -}; -pub use visitor::MoveExtractor; +pub use flat_visitor::{Buffers, ParseConfig, parse_game_to_buffers}; +use python_bindings::{ChunkData, ParsedGames, ParsedGamesIter, PyChunkView, PyGameView}; /// Parse games from Arrow chunked array into a chunked ParsedGames container. /// @@ -26,17 +19,23 @@ pub use visitor::MoveExtractor; /// (num_chunks = num_threads * chunk_multiplier) to avoid the allocation storm /// caused by Rayon's dynamic work-stealing with fold_with. /// -/// Each chunk gets exactly one FlatBuffers instance. Instead of merging all +/// Each chunk gets exactly one Buffers instance. Instead of merging all /// chunks into a single buffer (which was memory-bandwidth-bound), we keep /// the per-thread buffers and provide virtual indexing across them. #[pyfunction] -#[pyo3(signature = (pgn_chunked_array, num_threads=None, chunk_multiplier=None))] -fn parse_games_flat( +#[pyo3(signature = (pgn_chunked_array, num_threads=None, chunk_multiplier=None, store_comments=false, store_legal_moves=false))] +fn parse_games( py: Python<'_>, pgn_chunked_array: PyChunkedArray, num_threads: Option, chunk_multiplier: Option, + store_comments: bool, + store_legal_moves: bool, ) -> PyResult { + let config = ParseConfig { + store_comments, + store_legal_moves, + }; let num_threads = num_threads.unwrap_or_else(num_cpus::get); // Default multiplier of 1 means exactly num_threads chunks (one per thread). // Higher values (e.g., 4) create more chunks for better load balancing @@ -73,7 +72,7 @@ fn parse_games_flat( let n_games = pgn_str_slices.len(); if n_games == 0 { - let empty_chunk = flat_buffers_to_chunk_data(py, FlatBuffers::default())?; + let empty_chunk = buffers_to_chunk_data(py, Buffers::default())?; return build_parsed_games(py, vec![empty_chunk]); } @@ -99,39 +98,46 @@ fn parse_games_flat( let moves_per_game = 70; // Parse in parallel using par_chunks for explicit, fixed-size chunking. - // This creates exactly ceil(n_games / chunk_size) FlatBuffers instances, + // This creates exactly ceil(n_games / chunk_size) Buffers instances, // avoiding the allocation storm from Rayon's dynamic work-stealing. - let chunk_results: Vec = thread_pool.install(|| { + let chunk_results: Vec = thread_pool.install(|| { pgn_str_slices .par_chunks(chunk_size) .map(|chunk| { - let mut buffers = FlatBuffers::with_capacity(games_per_chunk, moves_per_game); + let mut buffers = Buffers::with_capacity(games_per_chunk, moves_per_game, &config); for &pgn in chunk { - let _ = parse_game_to_flat(pgn, &mut buffers); + let _ = parse_game_to_buffers(pgn, &mut buffers, &config); } buffers }) .collect() }); - // Convert each FlatBuffers to ChunkData (numpy arrays) — no merge needed + // Convert each Buffers to ChunkData (numpy arrays) — no merge needed let chunk_data_vec: Vec = chunk_results .into_iter() - .map(|buf| flat_buffers_to_chunk_data(py, buf)) + .map(|buf| buffers_to_chunk_data(py, buf)) .collect::>>()?; build_parsed_games(py, chunk_data_vec) } -/// Convert a single FlatBuffers into a ChunkData with NumPy arrays. -fn flat_buffers_to_chunk_data(py: Python<'_>, buffers: FlatBuffers) -> PyResult { +/// Convert a single Buffers into a ChunkData with NumPy arrays. +fn buffers_to_chunk_data(py: Python<'_>, buffers: Buffers) -> PyResult { let n_games = buffers.num_games(); let total_positions = buffers.total_positions(); let total_moves = buffers.total_moves(); - // Compute local CSR offsets + // Compute all CSR offsets BEFORE any from_vec calls consume buffer fields let move_offsets_vec = buffers.compute_move_offsets(); let position_offsets_vec = buffers.compute_position_offsets(); + let has_legal_moves = !buffers.legal_move_counts.is_empty(); + let legal_move_offsets_vec = if has_legal_moves { + buffers.compute_legal_move_offsets() + } else { + Vec::new() + }; + let total_legal_moves_count = buffers.total_legal_moves(); // Boards: reshape from flat to (N_positions, 8, 8) let boards_array = PyArray1::from_vec(py, buffers.boards); @@ -176,6 +182,11 @@ fn flat_buffers_to_chunk_data(py: Python<'_>, buffers: FlatBuffers) -> PyResult< let legal_move_count_array = PyArray1::from_vec(py, buffers.legal_move_count); let valid_array = PyArray1::from_vec(py, buffers.valid); + let legal_move_from_squares_array = PyArray1::from_vec(py, buffers.legal_move_from_squares); + let legal_move_to_squares_array = PyArray1::from_vec(py, buffers.legal_move_to_squares); + let legal_move_promotions_array = PyArray1::from_vec(py, buffers.legal_move_promotions); + let legal_move_offsets_array = PyArray1::from_vec(py, legal_move_offsets_vec); + Ok(ChunkData { boards: boards_reshaped.unbind().into_any(), castling: castling_reshaped.unbind().into_any(), @@ -195,9 +206,16 @@ fn flat_buffers_to_chunk_data(py: Python<'_>, buffers: FlatBuffers) -> PyResult< legal_move_count: legal_move_count_array.unbind().into_any(), valid: valid_array.unbind().into_any(), headers: buffers.headers, + outcome: buffers.outcome, + comments: buffers.comments, + legal_move_from_squares: legal_move_from_squares_array.unbind().into_any(), + legal_move_to_squares: legal_move_to_squares_array.unbind().into_any(), + legal_move_promotions: legal_move_promotions_array.unbind().into_any(), + legal_move_offsets: legal_move_offsets_array.unbind().into_any(), num_games: n_games, num_moves: total_moves, num_positions: total_positions, + num_legal_moves: total_legal_moves_count, }) } @@ -279,131 +297,86 @@ fn build_parsed_games(py: Python<'_>, chunks: Vec) -> PyResult, pgn: &str, + store_comments: bool, store_legal_moves: bool, - store_board_states: bool, -) -> Result { - let mut reader = Reader::new(Cursor::new(pgn)); - let mut extractor = MoveExtractor::new(store_legal_moves, store_board_states); - match reader.read_game(&mut extractor) { - Ok(Some(_)) => Ok(extractor), - Ok(None) => Err("No game found in PGN".to_string()), - Err(err) => Err(format!("Parsing error: {}", err)), - } +) -> PyResult { + let config = ParseConfig { + store_comments, + store_legal_moves, + }; + let mut buffers = Buffers::with_capacity(1, 70, &config); + parse_game_to_buffers(pgn, &mut buffers, &config) + .map_err(pyo3::exceptions::PyValueError::new_err)?; + let chunk = buffers_to_chunk_data(py, buffers)?; + build_parsed_games(py, vec![chunk]) } -pub fn parse_multiple_games_native( - pgns: &Vec, +/// Parse multiple PGN game strings in parallel. +/// +/// Convenience wrapper for when you have a list of strings rather than an Arrow array. +#[pyfunction] +#[pyo3(signature = (pgns, num_threads=None, store_comments=false, store_legal_moves=false))] +fn parse_games_from_strings( + py: Python<'_>, + pgns: Vec, num_threads: Option, + store_comments: bool, store_legal_moves: bool, - store_board_states: bool, -) -> Result, String> { +) -> PyResult { + let config = ParseConfig { + store_comments, + store_legal_moves, + }; let num_threads = num_threads.unwrap_or_else(num_cpus::get); - // Build a custom Rayon thread pool with the desired number of threads - let thread_pool = ThreadPoolBuilder::new() - .num_threads(num_threads) - .build() - .expect("Failed to build Rayon thread pool"); - - thread_pool.install(|| { - pgns.par_iter() - .map(|pgn| parse_single_game_native(pgn, store_legal_moves, store_board_states)) - .collect() - }) -} + let n_games = pgns.len(); + if n_games == 0 { + let empty_chunk = buffers_to_chunk_data(py, Buffers::default())?; + return build_parsed_games(py, vec![empty_chunk]); + } -fn _parse_game_moves_from_arrow_chunks_native( - pgn_chunked_array: &PyChunkedArray, - num_threads: Option, - store_legal_moves: bool, - store_board_states: bool, -) -> Result, String> { - let num_threads = num_threads.unwrap_or_else(num_cpus::get); let thread_pool = ThreadPoolBuilder::new() .num_threads(num_threads) .build() - .map_err(|e| format!("Failed to build Rayon thread pool: {}", e))?; - - let mut num_elements = 0; - for chunk in pgn_chunked_array.chunks() { - num_elements += chunk.len(); - } - let mut pgn_str_slices: Vec<&str> = Vec::with_capacity(num_elements); - for chunk in pgn_chunked_array.chunks() { - if let Some(string_array) = chunk.as_any().downcast_ref::() { - for i in 0..string_array.len() { - if string_array.is_valid(i) { - pgn_str_slices.push(string_array.value(i)); - } - } - } else if let Some(large_string_array) = chunk.as_any().downcast_ref::() { - for i in 0..large_string_array.len() { - if large_string_array.is_valid(i) { - pgn_str_slices.push(large_string_array.value(i)); - } - } - } else { - return Err(format!( - "Unsupported array type in ChunkedArray: {:?}", - chunk.data_type() - )); - } - } + .map_err(|e| { + PyErr::new::(format!( + "Failed to build thread pool: {}", + e + )) + })?; - thread_pool.install(|| { - pgn_str_slices - .par_iter() - .map(|&pgn_s| parse_single_game_native(pgn_s, store_legal_moves, store_board_states)) - .collect::, String>>() - }) -} + let num_chunks = num_threads; + let chunk_size = (n_games + num_chunks - 1) / num_chunks; + let chunk_size = chunk_size.max(1); + let games_per_chunk = chunk_size; + let moves_per_game = 70; -// --- Python-facing wrappers (PyResult) --- -// TODO check if I can call py.allow_threads and release GIL -// see https://docs.rs/pyo3-arrow/0.10.1/pyo3_arrow/ -#[pyfunction] -#[pyo3(signature = (pgn, store_legal_moves = false, store_board_states = false))] -/// Parses a single PGN game string. -fn parse_game( - pgn: &str, - store_legal_moves: bool, - store_board_states: bool, -) -> PyResult { - parse_single_game_native(pgn, store_legal_moves, store_board_states) - .map_err(pyo3::exceptions::PyValueError::new_err) -} + let chunk_results: Vec = thread_pool.install(|| { + pgns.par_chunks(chunk_size) + .map(|chunk| { + let mut buffers = Buffers::with_capacity(games_per_chunk, moves_per_game, &config); + for pgn in chunk { + let _ = parse_game_to_buffers(pgn, &mut buffers, &config); + } + buffers + }) + .collect() + }); -/// In parallel, parse a set of games -#[pyfunction] -#[pyo3(signature = (pgns, num_threads=None, store_legal_moves=false, store_board_states=false))] -fn parse_games( - pgns: Vec, - num_threads: Option, - store_legal_moves: bool, - store_board_states: bool, -) -> PyResult> { - parse_multiple_games_native(&pgns, num_threads, store_legal_moves, store_board_states) - .map_err(pyo3::exceptions::PyValueError::new_err) -} + let chunk_data_vec: Vec = chunk_results + .into_iter() + .map(|buf| buffers_to_chunk_data(py, buf)) + .collect::>>()?; -#[pyfunction] -#[pyo3(signature = (pgn_chunked_array, num_threads=None, store_legal_moves=false, store_board_states=false))] -fn parse_game_moves_arrow_chunked_array( - pgn_chunked_array: PyChunkedArray, - num_threads: Option, - store_legal_moves: bool, - store_board_states: bool, -) -> PyResult> { - _parse_game_moves_from_arrow_chunks_native( - &pgn_chunked_array, - num_threads, - store_legal_moves, - store_board_states, - ) - .map_err(pyo3::exceptions::PyValueError::new_err) + build_parsed_games(py, chunk_data_vec) } /// Parser for chess PGN notation @@ -411,155 +384,10 @@ fn parse_game_moves_arrow_chunked_array( fn rust_pgn_reader_python_binding(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(parse_game, m)?)?; m.add_function(wrap_pyfunction!(parse_games, m)?)?; - m.add_function(wrap_pyfunction!(parse_game_moves_arrow_chunked_array, m)?)?; - m.add_function(wrap_pyfunction!(parse_games_flat, m)?)?; - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; + m.add_function(wrap_pyfunction!(parse_games_from_strings, m)?)?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; Ok(()) } - -#[cfg(test)] -mod pyucimove_tests { - use super::*; - - #[test] - fn test_parse_game_without_headers() { - let pgn = "1. Nf3 d5 2. e4 c5 3. exd5 e5 4. dxe6 0-1"; - let result = parse_single_game_native(pgn, false, false); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert_eq!(extractor.moves.len(), 7); - assert_eq!(extractor.outcome, Some("Black".to_string())); - } - - #[test] - fn test_parse_game_with_standard_fen() { - // A game starting from a mid-game position - let pgn = r#"[FEN "r1bqkbnr/pppp1ppp/2n5/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3"] - -3. Bb5 a6 4. Ba4 Nf6 1-0"#; - let result = parse_single_game_native(pgn, false, false); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert!(extractor.valid_moves, "Moves should be valid"); - assert_eq!(extractor.moves.len(), 4); - } - - #[test] - fn test_parse_chess960_game() { - // Chess960 game with custom starting position - let pgn = r#"[Variant "chess960"] -[FEN "brkrqnnb/pppppppp/8/8/8/8/PPPPPPPP/BRKRQNNB w KQkq - 0 1"] - -1. g3 d5 2. d4 g6 3. b3 Nf6 1-0"#; - let result = parse_single_game_native(pgn, false, false); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert!( - extractor.valid_moves, - "Chess960 moves should be valid with proper FEN" - ); - assert_eq!(extractor.moves.len(), 6); - } - - #[test] - fn test_parse_chess960_variant_case_insensitive() { - // Test that variant detection is case-insensitive - let pgn = r#"[Variant "Chess960"] -[FEN "brkrqnnb/pppppppp/8/8/8/8/PPPPPPPP/BRKRQNNB w KQkq - 0 1"] - -1. g3 d5 1-0"#; - let result = parse_single_game_native(pgn, false, false); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert!( - extractor.valid_moves, - "Should handle Chess960 case variations" - ); - } - - #[test] - fn test_parse_invalid_fen_falls_back() { - // Invalid FEN should fall back to default and mark invalid - let pgn = r#"[FEN "invalid fen string"] - -1. e4 e5 1-0"#; - let result = parse_single_game_native(pgn, false, false); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert!( - !extractor.valid_moves, - "Should mark as invalid when FEN parsing fails" - ); - } - - #[test] - fn test_fen_header_case_insensitive() { - // FEN header key should be case-insensitive - let pgn = r#"[fen "r1bqkbnr/pppp1ppp/2n5/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3"] - -3. Bb5 1-0"#; - let result = parse_single_game_native(pgn, false, false); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert!( - extractor.valid_moves, - "Should handle lowercase 'fen' header" - ); - } - - #[test] - fn test_parse_game_with_custom_fen_no_variant() { - // A standard chess game starting from a mid-game position (no Variant header) - // Position after 1.e4 e5 2.Nf3 Nc6 3.Bb5 (Ruy Lopez) - let pgn = r#"[Event "Test Game"] - [FEN "r1bqkbnr/pppp1ppp/2n5/1B2p3/4P3/5N2/PPPP1PPP/RNBQK2R b KQkq - 3 3"] - - 3... a6 4. Ba4 Nf6 5. O-O Be7 1-0"#; - let result = parse_single_game_native(pgn, false, false); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert!( - extractor.valid_moves, - "Standard game with custom FEN should be valid" - ); - assert_eq!(extractor.moves.len(), 5); // a6, Ba4, Nf6, O-O, Be7 - } - - #[test] - fn test_parse_game_with_board_states() { - // Test that board states are populated when enabled - let pgn = "1. e4 e5 2. Nf3 Nc6 1-0"; - let result = parse_single_game_native(pgn, false, true); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert_eq!(extractor.moves.len(), 4); - // 5 positions: initial + 4 moves - let data = extractor - .board_state_data - .as_ref() - .expect("board_state_data should be Some"); - assert_eq!(data.board_states.len(), 5 * 64); - assert_eq!(data.en_passant_states.len(), 5); - assert_eq!(data.halfmove_clocks.len(), 5); - assert_eq!(data.turn_states.len(), 5); - assert_eq!(data.castling_states.len(), 5 * 4); - } - - #[test] - fn test_parse_game_without_board_states() { - // Test that board states are NOT populated when disabled - let pgn = "1. e4 e5 2. Nf3 Nc6 1-0"; - let result = parse_single_game_native(pgn, false, false); - assert!(result.is_ok()); - let extractor = result.unwrap(); - assert_eq!(extractor.moves.len(), 4); - // Board state data should be None when disabled - assert!(extractor.board_state_data.is_none()); - } -} diff --git a/src/python_bindings.rs b/src/python_bindings.rs index 9840e68..60d1d93 100644 --- a/src/python_bindings.rs +++ b/src/python_bindings.rs @@ -1,102 +1,8 @@ use numpy::{PyArray1, PyArray2, PyArrayMethods}; use pyo3::prelude::*; use pyo3::types::{IntoPyDict, PyList, PySlice}; -use shakmaty::{Role, Square}; use std::collections::HashMap; -// Definition of PyUciMove -#[pyclass(get_all, set_all, module = "rust_pgn_reader_python_binding")] -#[derive(Clone, Debug)] -pub struct PyUciMove { - pub from_square: u8, - pub to_square: u8, - pub promotion: Option, -} - -#[pymethods] -impl PyUciMove { - #[new] - pub fn new(from_square: u8, to_square: u8, promotion: Option) -> Self { - PyUciMove { - from_square, - to_square, - promotion, - } - } - - #[getter] - fn get_from_square_name(&self) -> String { - Square::new(self.from_square as u32).to_string() - } - - #[getter] - fn get_to_square_name(&self) -> String { - Square::new(self.to_square as u32).to_string() - } - - #[getter] - fn get_promotion_name(&self) -> Option { - self.promotion.and_then(|p_u8| { - Role::try_from(p_u8) - .map(|role| format!("{:?}", role)) // Get the debug representation (e.g., "Queen") - .ok() - }) - } - - // __str__ method for Python representation - fn __str__(&self) -> String { - let promo_str = self.promotion.map_or("".to_string(), |p_u8| { - Role::try_from(p_u8) - .map(|role| role.char().to_string()) - .unwrap_or_else(|_| "".to_string()) // Handle potential error if u8 is not a valid Role - }); - format!( - "{}{}{}", - Square::new(self.from_square as u32), - Square::new(self.to_square as u32), - promo_str - ) - } - - // __repr__ for a more developer-friendly representation - fn __repr__(&self) -> String { - let promo_repr = self.promotion.map_or("None".to_string(), |p_u8| { - Role::try_from(p_u8) - .map(|role| format!("Some('{}')", role.char())) - .unwrap_or_else(|_| format!("Some(InvalidRole({}))", p_u8)) - }); - format!( - "PyUciMove(from_square={}, to_square={}, promotion={})", - Square::new(self.from_square as u32), - Square::new(self.to_square as u32), - promo_repr - ) - } -} - -#[pyclass] -/// Holds the status of a chess position. -#[derive(Clone)] -pub struct PositionStatus { - #[pyo3(get)] - pub is_checkmate: bool, - - #[pyo3(get)] - pub is_stalemate: bool, - - #[pyo3(get)] - pub legal_move_count: usize, - - #[pyo3(get)] - pub is_game_over: bool, - - #[pyo3(get)] - pub insufficient_material: (bool, bool), - - #[pyo3(get)] - pub turn: bool, -} - /// Internal per-chunk data. Not exposed to Python directly. /// Each chunk corresponds to one thread's output during parallel parsing. pub struct ChunkData { @@ -123,11 +29,22 @@ pub struct ChunkData { pub legal_move_count: Py, // (N_games,) u16 pub valid: Py, // (N_games,) bool pub headers: Vec>, + pub outcome: Vec>, // Per-game: "White", "Black", "Draw", "Unknown", or None + + // Optional: raw text comments (per-move), only populated when store_comments=true + pub comments: Vec>, + + // Optional: legal moves at each position (CSR arrays + offsets) + pub legal_move_from_squares: Py, // (N_legal_moves,) u8 + pub legal_move_to_squares: Py, // (N_legal_moves,) u8 + pub legal_move_promotions: Py, // (N_legal_moves,) i8 + pub legal_move_offsets: Py, // (N_positions + 1,) u32 // Metadata pub num_games: usize, pub num_moves: usize, pub num_positions: usize, + pub num_legal_moves: usize, } /// Chunked container for parsed chess games, optimized for ML training. @@ -526,6 +443,53 @@ impl PyChunkView { .clone() } + #[getter] + fn outcome(&self, py: Python<'_>) -> Vec> { + self.parent.borrow(py).chunks[self.chunk_idx] + .outcome + .clone() + } + + /// Raw text comments per move (only populated when store_comments=true). + #[getter] + fn comments(&self, py: Python<'_>) -> Vec> { + self.parent.borrow(py).chunks[self.chunk_idx] + .comments + .clone() + } + + /// Legal move from-squares for all positions in this chunk. + #[getter] + fn legal_move_from_squares(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .legal_move_from_squares + .clone_ref(py) + } + + /// Legal move to-squares for all positions in this chunk. + #[getter] + fn legal_move_to_squares(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .legal_move_to_squares + .clone_ref(py) + } + + /// Legal move promotions for all positions in this chunk. + #[getter] + fn legal_move_promotions(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .legal_move_promotions + .clone_ref(py) + } + + /// CSR offsets for legal moves (per-position). Length = num_positions + 1. + #[getter] + fn legal_move_offsets(&self, py: Python<'_>) -> Py { + self.parent.borrow(py).chunks[self.chunk_idx] + .legal_move_offsets + .clone_ref(py) + } + fn __repr__(&self, py: Python<'_>) -> String { let borrowed = self.parent.borrow(py); let chunk = &borrowed.chunks[self.chunk_idx]; @@ -847,6 +811,80 @@ impl PyGameView { .ok_or_else(|| pyo3::exceptions::PyIndexError::new_err("Invalid game index")) } + /// Whether the game is over (checkmate, stalemate, or both sides have insufficient material). + #[getter] + fn is_game_over(&self, py: Python<'_>) -> PyResult { + let checkmate = self.is_checkmate(py)?; + let stalemate = self.is_stalemate(py)?; + let (insuf_white, insuf_black) = self.is_insufficient(py)?; + Ok(checkmate || stalemate || (insuf_white && insuf_black)) + } + + /// Game outcome from movetext: "White", "Black", "Draw", "Unknown", or None. + #[getter] + fn outcome(&self, py: Python<'_>) -> PyResult> { + let borrowed = self.data.borrow(py); + Ok(borrowed.chunks[self.chunk_idx].outcome[self.local_idx].clone()) + } + + /// Raw text comments per move (only populated when store_comments=true). + /// Returns list[str | None] of length num_moves. + #[getter] + fn comments(&self, py: Python<'_>) -> PyResult>> { + let borrowed = self.data.borrow(py); + let chunk_comments = &borrowed.chunks[self.chunk_idx].comments; + if chunk_comments.is_empty() { + return Ok(Vec::new()); + } + Ok(chunk_comments[self.move_start..self.move_end].to_vec()) + } + + /// Legal moves at each position in this game. + /// Returns list of lists: [[from, to, promotion], ...] per position. + /// Only populated when store_legal_moves=true. + #[getter] + fn legal_moves(&self, py: Python<'_>) -> PyResult>> { + let borrowed = self.data.borrow(py); + let chunk = &borrowed.chunks[self.chunk_idx]; + + // Check if legal moves were stored + if chunk.num_legal_moves == 0 { + return Ok(Vec::new()); + } + + let offsets_arr = chunk.legal_move_offsets.bind(py); + let offsets_arr: &Bound<'_, PyArray1> = offsets_arr.cast()?; + let offsets_ro = offsets_arr.readonly(); + let offsets_slice = offsets_ro.as_slice()?; + + let from_arr = chunk.legal_move_from_squares.bind(py); + let from_arr: &Bound<'_, PyArray1> = from_arr.cast()?; + let from_ro = from_arr.readonly(); + let from_slice = from_ro.as_slice()?; + + let to_arr = chunk.legal_move_to_squares.bind(py); + let to_arr: &Bound<'_, PyArray1> = to_arr.cast()?; + let to_ro = to_arr.readonly(); + let to_slice = to_ro.as_slice()?; + + let promo_arr = chunk.legal_move_promotions.bind(py); + let promo_arr: &Bound<'_, PyArray1> = promo_arr.cast()?; + let promo_ro = promo_arr.readonly(); + let promo_slice = promo_ro.as_slice()?; + + let mut result = Vec::with_capacity(self.pos_end - self.pos_start); + for pos_idx in self.pos_start..self.pos_end { + let start = offsets_slice[pos_idx] as usize; + let end = offsets_slice[pos_idx + 1] as usize; + let mut moves = Vec::with_capacity(end - start); + for i in start..end { + moves.push((from_slice[i], to_slice[i], promo_slice[i])); + } + result.push(moves); + } + Ok(result) + } + // === Convenience methods === /// Get UCI string for move at index. @@ -933,74 +971,3 @@ impl PyGameView { )) } } - -#[cfg(test)] -mod tests { - use super::*; - use shakmaty::{Role, Square}; - - #[test] - fn test_py_uci_move_no_promotion() { - let uci_move = PyUciMove::new(Square::E2 as u8, Square::E4 as u8, None); - assert_eq!(uci_move.from_square, Square::E2 as u8); - assert_eq!(uci_move.to_square, Square::E4 as u8); - assert_eq!(uci_move.promotion, None); - assert_eq!(uci_move.get_from_square_name(), "e2"); - assert_eq!(uci_move.get_to_square_name(), "e4"); - assert_eq!(uci_move.get_promotion_name(), None); - assert_eq!(uci_move.__str__(), "e2e4"); - assert_eq!( - uci_move.__repr__(), - "PyUciMove(from_square=e2, to_square=e4, promotion=None)" - ); - } - - #[test] - fn test_py_uci_move_with_queen_promotion() { - let uci_move = PyUciMove::new(Square::E7 as u8, Square::E8 as u8, Some(Role::Queen as u8)); - assert_eq!(uci_move.from_square, Square::E7 as u8); - assert_eq!(uci_move.to_square, Square::E8 as u8); - assert_eq!(uci_move.promotion, Some(Role::Queen as u8)); - assert_eq!(uci_move.get_from_square_name(), "e7"); - assert_eq!(uci_move.get_to_square_name(), "e8"); - assert_eq!(uci_move.get_promotion_name(), Some("Queen".to_string())); - assert_eq!(uci_move.__str__(), "e7e8q"); - assert_eq!( - uci_move.__repr__(), - "PyUciMove(from_square=e7, to_square=e8, promotion=Some('q'))" - ); - } - - #[test] - fn test_py_uci_move_with_rook_promotion() { - let uci_move = PyUciMove::new(Square::A7 as u8, Square::A8 as u8, Some(Role::Rook as u8)); - assert_eq!(uci_move.from_square, Square::A7 as u8); - assert_eq!(uci_move.to_square, Square::A8 as u8); - assert_eq!(uci_move.promotion, Some(Role::Rook as u8)); - assert_eq!(uci_move.get_from_square_name(), "a7"); - assert_eq!(uci_move.get_to_square_name(), "a8"); - assert_eq!(uci_move.get_promotion_name(), Some("Rook".to_string())); - assert_eq!(uci_move.__str__(), "a7a8r"); - assert_eq!( - uci_move.__repr__(), - "PyUciMove(from_square=a7, to_square=a8, promotion=Some('r'))" - ); - } - - #[test] - fn test_py_uci_move_invalid_promotion_val() { - // Test with a u8 value that doesn't correspond to a valid Role - let uci_move = PyUciMove::new(Square::B7 as u8, Square::B8 as u8, Some(99)); // 99 is not a valid Role - assert_eq!(uci_move.from_square, Square::B7 as u8); - assert_eq!(uci_move.to_square, Square::B8 as u8); - assert_eq!(uci_move.promotion, Some(99)); - assert_eq!(uci_move.get_from_square_name(), "b7"); - assert_eq!(uci_move.get_to_square_name(), "b8"); - assert_eq!(uci_move.get_promotion_name(), None); // Should be None as 99 is invalid - assert_eq!(uci_move.__str__(), "b7b8"); // Should produce no promotion char - assert_eq!( - uci_move.__repr__(), - "PyUciMove(from_square=b7, to_square=b8, promotion=Some(InvalidRole(99)))" - ); - } -} diff --git a/src/test.py b/src/test.py index cec9733..59de284 100644 --- a/src/test.py +++ b/src/test.py @@ -7,693 +7,16 @@ import pyarrow as pa -class TestPgnExtraction(unittest.TestCase): - def run_extractor(self, pgn_string): - extractor = rust_pgn_reader_python_binding.parse_game(pgn_string) - return extractor - - def test_short_pgn(self): - pgn_moves = "1. e4 {asdf} e5 2. Nf3 Nc6 3. Bb5 Nf6 4. O-O {hello} Bc5 5. d3 d6 6. h3 h6 7. c3 O-O 8. Be3 a6 9. Ba4 Bd7 10. Bxc5 dxc5 11. Bxc6 Bxc6 12. Nxe5 Bb5 13. Re1 Re8 14. c4 Rxe5 15. cxb5 axb5 16. Nc3 b4 17. Nd5 Nxd5 18. exd5 Rxe1+ 19. Qxe1 Qxd5 20. Qe3 b6 21. b3 c6 22. Qe2 b5 23. Rd1 Qd4 24. Qe7 Rxa2 25. Qe8+ Kh7 26. Qxf7 c4 27. Qf5+ Kh8 28. Qf8+ Kh7 29. Qf5+ Kh8 30. Qf8+ Kh7 31. Qf5+ Kh8 32. Qf8+ Kh7 1/2-1/2" - extractor = self.run_extractor(pgn_moves) - - moves_reference = [ - "e2e4", - "e7e5", - "g1f3", - "b8c6", - "f1b5", - "g8f6", - "e1g1", - "f8c5", - "d2d3", - "d7d6", - "h2h3", - "h7h6", - "c2c3", - "e8g8", - "c1e3", - "a7a6", - "b5a4", - "c8d7", - "e3c5", - "d6c5", - "a4c6", - "d7c6", - "f3e5", - "c6b5", - "f1e1", - "f8e8", - "c3c4", - "e8e5", - "c4b5", - "a6b5", - "b1c3", - "b5b4", - "c3d5", - "f6d5", - "e4d5", - "e5e1", - "d1e1", - "d8d5", - "e1e3", - "b7b6", - "b2b3", - "c7c6", - "e3e2", - "b6b5", - "a1d1", - "d5d4", - "e2e7", - "a8a2", - "e7e8", - "g8h7", - "e8f7", - "c5c4", - "f7f5", - "h7h8", - "f5f8", - "h8h7", - "f8f5", - "h7h8", - "f5f8", - "h8h7", - "f8f5", - "h7h8", - "f5f8", - "h8h7", - ] - - comments_reference = [ - "asdf", - None, - None, - None, - None, - None, - "hello", - ] + [None for _ in range(57)] - - valid_reference = True - evals_reference = [None for _ in range(len(moves_reference))] - clock_times_reference = [None for _ in range(len(moves_reference))] - - self.assertTrue([str(move) for move in extractor.moves] == moves_reference) - self.assertTrue(extractor.comments == comments_reference) - self.assertTrue(extractor.valid_moves == valid_reference) - self.assertTrue(extractor.evals == evals_reference) - self.assertTrue(extractor.clock_times == clock_times_reference) - - assert extractor.position_status is not None # appease the type checker - self.assertFalse(extractor.position_status.is_checkmate) - self.assertFalse(extractor.position_status.is_stalemate) - self.assertFalse(extractor.position_status.is_game_over) - - def test_full_pgn(self): - pgn_moves = """ -[Event "Rated Classical game"] -[Site "https://lichess.org/lhy6ehiv"] -[White "goerch"] -[Black "niltonrosao001"] -[Result "0-1"] -[UTCDate "2013.06.30"] -[UTCTime "22:10:02"] -[WhiteElo "1702"] -[BlackElo "2011"] -[WhiteRatingDiff "-3"] -[BlackRatingDiff "+5"] -[ECO "A46"] -[Opening "Indian Game: Spielmann-Indian"] -[TimeControl "600+8"] -[Termination "Normal"] - -1. d4 Nf6 2. Nf3 c5 3. e3 b6 4. Nc3 e6 5. Bb5 a6 6. Bd3 Bb7 7. O-O b5 8. b3 d5 9. Bb2 Nbd7 10. a4 b4 11. Ne2 Bd6 12. c4 bxc3 13. Bxc3 O-O 14. Ng3 Rc8 15. dxc5 Nxc5 16. Nd4 Nxd3 17. Qxd3 Qb6 18. Rab1 Bb4 19. Nge2 Ne4 20. Rfc1 Nxc3 21. Nxc3 Rc7 22. Na2 Rfc8 23. Rc2 g6 24. Nxb4 Qxb4 25. Rbc1 Rxc2 26. Nxc2 Qc3 27. Qxc3 Rxc3 28. Kf1 d4 29. exd4 Be4 30. Ke1 Rxc2 31. Rxc2 Bxc2 32. Kd2 Bxb3 33. a5 Bd5 0-1 - -""" - extractor = self.run_extractor(pgn_moves) - - moves_reference = [ - "d2d4", - "g8f6", - "g1f3", - "c7c5", - "e2e3", - "b7b6", - "b1c3", - "e7e6", - "f1b5", - "a7a6", - "b5d3", - "c8b7", - "e1g1", - "b6b5", - "b2b3", - "d7d5", - "c1b2", - "b8d7", - "a2a4", - "b5b4", - "c3e2", - "f8d6", - "c2c4", - "b4c3", - "b2c3", - "e8g8", - "e2g3", - "a8c8", - "d4c5", - "d7c5", - "f3d4", - "c5d3", - "d1d3", - "d8b6", - "a1b1", - "d6b4", - "g3e2", - "f6e4", - "f1c1", - "e4c3", - "e2c3", - "c8c7", - "c3a2", - "f8c8", - "c1c2", - "g7g6", - "a2b4", - "b6b4", - "b1c1", - "c7c2", - "d4c2", - "b4c3", - "d3c3", - "c8c3", - "g1f1", - "d5d4", - "e3d4", - "b7e4", - "f1e1", - "c3c2", - "c1c2", - "e4c2", - "e1d2", - "c2b3", - "a4a5", - "b3d5", - ] - comments_reference = [None for _ in range(len(moves_reference))] - valid_reference = True - evals_reference = [None for _ in range(len(moves_reference))] - clock_times_reference = [None for _ in range(len(moves_reference))] - headers_reference = [ - ("Event", "Rated Classical game"), - ("Site", "https://lichess.org/lhy6ehiv"), - ("White", "goerch"), - ("Black", "niltonrosao001"), - ("Result", "0-1"), - ("UTCDate", "2013.06.30"), - ("UTCTime", "22:10:02"), - ("WhiteElo", "1702"), - ("BlackElo", "2011"), - ("WhiteRatingDiff", "-3"), - ("BlackRatingDiff", "+5"), - ("ECO", "A46"), - ("Opening", "Indian Game: Spielmann-Indian"), - ("TimeControl", "600+8"), - ("Termination", "Normal"), - ] - - self.assertTrue([str(move) for move in extractor.moves] == moves_reference) - self.assertTrue(extractor.comments == comments_reference) - self.assertTrue(extractor.valid_moves == valid_reference) - self.assertTrue(extractor.evals == evals_reference) - self.assertTrue(extractor.clock_times == clock_times_reference) - self.assertTrue(extractor.headers == headers_reference) - - assert extractor.position_status is not None # appease the type checker - self.assertFalse(extractor.position_status.is_checkmate) - self.assertFalse(extractor.position_status.is_stalemate) - self.assertFalse(extractor.position_status.is_game_over) - - def test_full_pgn_annotated(self): - pgn_moves = """ - 1. e4 { [%eval 0.17] [%clk 0:00:30] } 1... c5 { [%eval 0.19] [%clk 0:00:30] } - 2. Nf3 { [%eval 0.25] [%clk 0:00:29] } 2... Nc6 { [%eval 0.33] [%clk 0:00:30] } - 3. Bc4 { [%eval -0.13] [%clk 0:00:28] } 3... e6 { [%eval -0.04] [%clk 0:00:30] } - 4. c3 { [%eval -0.4] [%clk 0:00:27] } 4... b5? { [%eval 1.18] [%clk 0:00:30] } - 5. Bb3?! { [%eval 0.21] [%clk 0:00:26] } 5... c4 { [%eval 0.32] [%clk 0:00:29] } - 6. Bc2 { [%eval 0.2] [%clk 0:00:25] } 6... a5 { [%eval 0.6] [%clk 0:00:29] } - 7. d4 { [%eval 0.29] [%clk 0:00:23] } 7... cxd3 { [%eval 0.6] [%clk 0:00:27] } - 8. Qxd3 { [%eval 0.12] [%clk 0:00:22] } 8... Nf6 { [%eval 0.52] [%clk 0:00:26] } - 9. e5 { [%eval 0.39] [%clk 0:00:21] } 9... Nd5 { [%eval 0.45] [%clk 0:00:25] } - 10. Bg5?! { [%eval -0.44] [%clk 0:00:18] } 10... Qc7 { [%eval -0.12] [%clk 0:00:23] } - 11. Nbd2?? { [%eval -3.15] [%clk 0:00:14] } 11... h6 { [%eval -2.99] [%clk 0:00:23] } - 12. Bh4 { [%eval -3.0] [%clk 0:00:11] } 12... Ba6? { [%eval -0.12] [%clk 0:00:23] } - 13. b3?? { [%eval -4.14] [%clk 0:00:02] } 13... Nf4? { [%eval -2.73] [%clk 0:00:21] } 0-1 - """ - extractor = self.run_extractor(pgn_moves) - - moves_reference = [ - "e2e4", - "c7c5", - "g1f3", - "b8c6", - "f1c4", - "e7e6", - "c2c3", - "b7b5", - "c4b3", - "c5c4", - "b3c2", - "a7a5", - "d2d4", - "c4d3", - "d1d3", - "g8f6", - "e4e5", - "f6d5", - "c1g5", - "d8c7", - "b1d2", - "h7h6", - "g5h4", - "c8a6", - "b2b3", - "d5f4", - ] - comments_reference = [ - "", - ] * 26 - valid_reference = True - evals_reference = [ - 0.17, - 0.19, - 0.25, - 0.33, - -0.13, - -0.04, - -0.4, - 1.18, - 0.21, - 0.32, - 0.2, - 0.6, - 0.29, - 0.6, - 0.12, - 0.52, - 0.39, - 0.45, - -0.44, - -0.12, - -3.15, - -2.99, - -3.0, - -0.12, - -4.14, - -2.73, - ] - clock_times_reference = [ - (0, 0, 30), - (0, 0, 30), - (0, 0, 29), - (0, 0, 30), - (0, 0, 28), - (0, 0, 30), - (0, 0, 27), - (0, 0, 30), - (0, 0, 26), - (0, 0, 29), - (0, 0, 25), - (0, 0, 29), - (0, 0, 23), - (0, 0, 27), - (0, 0, 22), - (0, 0, 26), - (0, 0, 21), - (0, 0, 25), - (0, 0, 18), - (0, 0, 23), - (0, 0, 14), - (0, 0, 23), - (0, 0, 11), - (0, 0, 23), - (0, 0, 2), - (0, 0, 21), - ] - self.assertTrue([str(move) for move in extractor.moves] == moves_reference) - self.assertTrue(extractor.comments == comments_reference) - self.assertTrue(extractor.valid_moves == valid_reference) - self.assertTrue(extractor.evals == evals_reference) - self.assertTrue(extractor.clock_times == clock_times_reference) - - assert extractor.position_status is not None # appease the type checker - self.assertFalse(extractor.position_status.is_checkmate) - self.assertFalse(extractor.position_status.is_stalemate) - self.assertFalse(extractor.position_status.is_game_over) - - def test_multithreaded(self): - pgns = [ - "1. Nf3 g6 2. b3 Bg7 3. Nc3 e5 4. Bb2 e4 5. Ng1 d6 6. Rb1 a5 7. Nxe4 Bxb2 8. Rxb2 Nf6 9. Nxf6+ Qxf6 10. Rb1 Ra6 11. e3 Rb6 12. d4 a4 13. bxa4 Nc6 14. Rxb6 cxb6 15. Bb5 Bd7 16. Bxc6 Bxc6 17. Nf3 Bxa4 18. O-O O-O 19. Re1 Rc8 20. Re2 d5 21. Ne5 Qf5 22. Qd3 Bxc2 23. Qxf5 Bxf5 24. h3 b5 25. Rb2 f6 26. Ng4 Rc6 27. Nh6+ Kg7 28. Nxf5+ gxf5 29. Rxb5 Rc7 30. Rxd5 Kg6 31. f4 Kh5 32. Rxf5+ Kh4 33. Rxf6 Kg3 34. d5 Rc1# 0-1", - "1. e4 {asdf} e5 2. Nf3 Nc6 3. Bb5 Nf6 4. O-O {hello} Bc5 5. d3 d6 6. h3 h6 7. c3 O-O", - ] - extractor = rust_pgn_reader_python_binding.parse_games(pgns) - - moves_reference = [ - [ - "g1f3", - "g7g6", - "b2b3", - "f8g7", - "b1c3", - "e7e5", - "c1b2", - "e5e4", - "f3g1", - "d7d6", - "a1b1", - "a7a5", - "c3e4", - "g7b2", - "b1b2", - "g8f6", - "e4f6", - "d8f6", - "b2b1", - "a8a6", - "e2e3", - "a6b6", - "d2d4", - "a5a4", - "b3a4", - "b8c6", - "b1b6", - "c7b6", - "f1b5", - "c8d7", - "b5c6", - "d7c6", - "g1f3", - "c6a4", - "e1g1", - "e8g8", - "f1e1", - "f8c8", - "e1e2", - "d6d5", - "f3e5", - "f6f5", - "d1d3", - "a4c2", - "d3f5", - "c2f5", - "h2h3", - "b6b5", - "e2b2", - "f7f6", - "e5g4", - "c8c6", - "g4h6", - "g8g7", - "h6f5", - "g6f5", - "b2b5", - "c6c7", - "b5d5", - "g7g6", - "f2f4", - "g6h5", - "d5f5", - "h5h4", - "f5f6", - "h4g3", - "d4d5", - "c7c1", - ], - [ - "e2e4", - "e7e5", - "g1f3", - "b8c6", - "f1b5", - "g8f6", - "e1g1", - "f8c5", - "d2d3", - "d7d6", - "h2h3", - "h7h6", - "c2c3", - "e8g8", - ], - ] - - comments_reference = [ - [None for _ in range(len(moves_reference[0]))], - [ - "asdf", - None, - None, - None, - None, - None, - "hello", - None, - None, - None, - None, - None, - None, - None, - ], - ] - - self.assertTrue(extractor[0].comments == comments_reference[0]) - self.assertTrue(extractor[1].comments == comments_reference[1]) - - self.assertTrue( - [str(move) for move in extractor[0].moves] == moves_reference[0] - ) - self.assertTrue( - [str(move) for move in extractor[1].moves] == moves_reference[1] - ) - assert extractor[0].position_status is not None # appease the type checker - - self.assertTrue(extractor[0].position_status.is_checkmate) - self.assertFalse(extractor[0].position_status.is_stalemate) - self.assertTrue(extractor[0].position_status.is_game_over) - self.assertTrue(extractor[0].position_status.legal_move_count == 0) - self.assertTrue(extractor[0].position_status.turn == 1) - self.assertTrue(extractor[0].position_status.insufficient_material == (0, 0)) - - self.assertTrue(extractor[1].position_status is None) - extractor[1].update_position_status() - assert extractor[1].position_status is not None # appease the type checker - self.assertFalse(extractor[1].position_status.is_checkmate) - self.assertFalse(extractor[1].position_status.is_stalemate) - self.assertFalse(extractor[1].position_status.is_game_over) - self.assertTrue(extractor[1].position_status.legal_move_count == 36) - self.assertTrue(extractor[1].position_status.turn == 1) - self.assertTrue(extractor[1].position_status.insufficient_material == (0, 0)) - - def test_castling(self): - pgn_moves = """ - 1. e4 e5 2. Bc4 c6 3. Nf3 d6 4. Rg1 f6 5. Rh1 g6 6. Ke2 b6 7. Ke1 g5 - """ - - extractor = self.run_extractor(pgn_moves) - - castling_reference = [ - (True, True, True, True), - (True, True, True, True), - (True, True, True, True), - (True, True, True, True), - (True, True, True, True), - (True, True, True, True), - (True, True, True, True), - (True, False, True, True), - (True, False, True, True), - (True, False, True, True), - (True, False, True, True), - (False, False, True, True), - (False, False, True, True), - (False, False, True, True), - (False, False, True, True), - ] - - self.assertTrue(extractor.castling_rights == castling_reference) - - def test_parse_game_moves_arrow_chunked_array(self): - pgns = [ - "1. Nf3 g6 2. b3 Bg7 3. Nc3 e5 4. Bb2 e4 5. Ng1 d6 6. Rb1 a5 7. Nxe4 Bxb2 8. Rxb2 Nf6 9. Nxf6+ Qxf6 10. Rb1 Ra6 11. e3 Rb6 12. d4 a4 13. bxa4 Nc6 14. Rxb6 cxb6 15. Bb5 Bd7 16. Bxc6 Bxc6 17. Nf3 Bxa4 18. O-O O-O 19. Re1 Rc8 20. Re2 d5 21. Ne5 Qf5 22. Qd3 Bxc2 23. Qxf5 Bxf5 24. h3 b5 25. Rb2 f6 26. Ng4 Rc6 27. Nh6+ Kg7 28. Nxf5+ gxf5 29. Rxb5 Rc7 30. Rxd5 Kg6 31. f4 Kh5 32. Rxf5+ Kh4 33. Rxf6 Kg3 34. d5 Rc1# 0-1", - "1. e4 {asdf} e5 2. Nf3 Nc6 3. Bb5 Nf6 4. O-O {hello} Bc5 5. d3 d6 6. h3 h6 7. c3 O-O", - ] - - # Create a PyArrow ChunkedArray - arrow_array = pa.array(pgns, type=pa.string()) - chunked_array = pa.chunked_array([arrow_array]) - - extractors = ( - rust_pgn_reader_python_binding.parse_game_moves_arrow_chunked_array( - chunked_array - ) - ) - - moves_reference = [ - [ - "g1f3", - "g7g6", - "b2b3", - "f8g7", - "b1c3", - "e7e5", - "c1b2", - "e5e4", - "f3g1", - "d7d6", - "a1b1", - "a7a5", - "c3e4", - "g7b2", - "b1b2", - "g8f6", - "e4f6", - "d8f6", - "b2b1", - "a8a6", - "e2e3", - "a6b6", - "d2d4", - "a5a4", - "b3a4", - "b8c6", - "b1b6", - "c7b6", - "f1b5", - "c8d7", - "b5c6", - "d7c6", - "g1f3", - "c6a4", - "e1g1", - "e8g8", - "f1e1", - "f8c8", - "e1e2", - "d6d5", - "f3e5", - "f6f5", - "d1d3", - "a4c2", - "d3f5", - "c2f5", - "h2h3", - "b6b5", - "e2b2", - "f7f6", - "e5g4", - "c8c6", - "g4h6", - "g8g7", - "h6f5", - "g6f5", - "b2b5", - "c6c7", - "b5d5", - "g7g6", - "f2f4", - "g6h5", - "d5f5", - "h5h4", - "f5f6", - "h4g3", - "d4d5", - "c7c1", - ], - [ - "e2e4", - "e7e5", - "g1f3", - "b8c6", - "f1b5", - "g8f6", - "e1g1", - "f8c5", - "d2d3", - "d7d6", - "h2h3", - "h7h6", - "c2c3", - "e8g8", - ], - ] - - self.assertTrue( - [str(move) for move in extractors[0].moves] == moves_reference[0] - ) - self.assertTrue( - [str(move) for move in extractors[1].moves] == moves_reference[1] - ) - - comments_reference = [ - [None for _ in range(len(moves_reference[0]))], - [ - "asdf", - None, - None, - None, - None, - None, - "hello", - None, - None, - None, - None, - None, - None, - None, - ], - ] - - self.assertTrue(extractors[0].comments == comments_reference[0]) - self.assertTrue(extractors[1].comments == comments_reference[1]) - - extractors[0].update_position_status() # Ensure status is calculated - assert extractors[0].position_status is not None # appease the type checker - self.assertTrue(extractors[0].position_status.is_checkmate) - self.assertFalse(extractors[0].position_status.is_stalemate) - self.assertTrue(extractors[0].position_status.is_game_over) - self.assertTrue(extractors[0].position_status.legal_move_count == 0) - self.assertTrue( - extractors[0].position_status.turn == True - ) # White's turn, but Black delivered checkmate - self.assertTrue( - extractors[0].position_status.insufficient_material == (False, False) - ) - - self.assertTrue( - extractors[1].position_status is None - ) # Not set by default for parse_game_moves_arrow_chunked_array - extractors[1].update_position_status() - assert extractors[1].position_status is not None # appease the type checker - self.assertFalse(extractors[1].position_status.is_checkmate) - self.assertFalse(extractors[1].position_status.is_stalemate) - self.assertFalse(extractors[1].position_status.is_game_over) - self.assertTrue(extractors[1].position_status.legal_move_count == 36) - self.assertTrue(extractors[1].position_status.turn == True) # White's turn - self.assertTrue( - extractors[1].position_status.insufficient_material == (False, False) - ) - - -class TestParsedGamesFlat(unittest.TestCase): +class TestParsedGames(unittest.TestCase): def test_basic_structure(self): - """Test basic flat parsing returns correct structure.""" + """Test basic parsing returns correct structure.""" pgns = [ "1. e4 e5 2. Nf3 Nc6 3. Bb5 1-0", "1. d4 d5 2. c4 e6 0-1", ] chunked = pa.chunked_array([pa.array(pgns)]) # Use 1 thread to get a single chunk for predictable array shapes - result = rust_pgn_reader_python_binding.parse_games_flat(chunked, num_threads=1) + result = rust_pgn_reader_python_binding.parse_games(chunked, num_threads=1) # Check game count self.assertEqual(len(result), 2) @@ -725,7 +48,7 @@ def test_initial_board_encoding(self): """Test initial board state encoding.""" pgns = ["1. e4 1-0"] chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) initial = result[0].initial_board # First position @@ -751,7 +74,7 @@ def test_board_after_move(self): """Test board state updates correctly after move.""" pgns = ["1. e4 1-0"] chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) # Position 0: initial, Position 1: after e4 after_e4 = result[0].boards[1] @@ -765,7 +88,7 @@ def test_en_passant_tracking(self): """Test en passant square tracking.""" pgns = ["1. e4 1-0"] chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) game = result[0] # Initial: no en passant @@ -778,7 +101,7 @@ def test_castling_rights(self): # White moves rook, losing kingside castling pgns = ["1. e4 e5 2. Nf3 Nc6 3. Rg1 1-0"] chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) game = result[0] # Initial: all castling [K, Q, k, q] = [True, True, True, True] @@ -795,7 +118,7 @@ def test_turn_tracking(self): """Test side-to-move tracking.""" pgns = ["1. e4 e5 2. Nf3 1-0"] chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) game = result[0] # Initial: white to move @@ -809,7 +132,7 @@ def test_game_view_access(self): """Test GameView provides correct slices.""" pgns = ["1. e4 e5 2. Nf3 1-0", "1. d4 d5 0-1"] chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) game0 = result[0] self.assertEqual(len(game0), 3) @@ -825,7 +148,7 @@ def test_game_view_move_uci(self): """Test GameView UCI move conversion.""" pgns = ["1. e4 e5 2. Nf3 1-0"] chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) game = result[0] self.assertEqual(game.move_uci(0), "e2e4") @@ -838,17 +161,17 @@ def test_iteration(self): """Test iteration over games.""" pgns = ["1. e4 1-0", "1. d4 0-1", "1. c4 1/2-1/2"] chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) games = list(result) self.assertEqual(len(games), 3) self.assertIsInstance(games[0], PyGameView) def test_slicing(self): - """Test slicing returns BatchSlice.""" + """Test slicing returns list of game views.""" pgns = ["1. e4 1-0", "1. d4 0-1", "1. c4 1/2-1/2"] chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) sliced = result[1:3] self.assertEqual(len(sliced), 2) @@ -859,7 +182,7 @@ def test_position_to_game_mapping(self): """Test position to game index mapping.""" pgns = ["1. e4 e5 1-0", "1. d4 0-1"] # 2 moves (3 pos), 1 move (2 pos) chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) # Positions: 0,1,2 (game 0), 3,4 (game 1) pos_indices = np.array([0, 1, 2, 3, 4]) @@ -871,7 +194,7 @@ def test_move_to_game_mapping(self): """Test move to game index mapping.""" pgns = ["1. e4 e5 1-0", "1. d4 0-1"] # 2 moves, 1 move chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) move_indices = np.array([0, 1, 2]) game_indices = result.move_to_game(move_indices) @@ -882,9 +205,8 @@ def test_position_to_game_accepts_various_dtypes(self): """Test position_to_game accepts various integer dtypes.""" pgns = ["1. e4 e5 1-0", "1. d4 0-1"] chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) - # Test various integer dtypes (int64 is optimal, others are converted) for dtype in [np.int32, np.int64, np.uint32, np.uint64]: pos_indices = np.array([0, 1, 2, 3, 4], dtype=dtype) game_indices = result.position_to_game(pos_indices) @@ -894,9 +216,8 @@ def test_move_to_game_accepts_various_dtypes(self): """Test move_to_game accepts various integer dtypes.""" pgns = ["1. e4 e5 1-0", "1. d4 0-1"] chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) - # Test various integer dtypes (int64 is optimal, others are converted) for dtype in [np.int32, np.int64, np.uint32, np.uint64]: move_indices = np.array([0, 1, 2], dtype=dtype) game_indices = result.move_to_game(move_indices) @@ -906,7 +227,7 @@ def test_clocks_and_evals(self): """Test clock and eval parsing.""" pgn = """1. e4 { [%eval 0.17] [%clk 0:00:30] } 1... e5 { [%eval 0.19] [%clk 0:00:29] } 1-0""" chunked = pa.chunked_array([pa.array([pgn])]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) game = result[0] self.assertAlmostEqual(game.evals[0], 0.17, places=2) @@ -918,7 +239,7 @@ def test_missing_clocks_evals_are_nan(self): """Test missing clocks/evals are NaN.""" pgns = ["1. e4 e5 1-0"] chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) game = result[0] self.assertTrue(np.isnan(game.clocks[0])) @@ -932,7 +253,7 @@ def test_headers_preserved(self): 1. e4 1-0""" chunked = pa.chunked_array([pa.array([pgn])]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) game = result[0] self.assertEqual(game.headers["White"], "Player1") @@ -947,7 +268,7 @@ def test_invalid_game_flagged(self): "1. d4 d5 0-1", # Valid ] chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) self.assertEqual(len(result), 3) self.assertTrue(result[0].is_valid) @@ -959,7 +280,7 @@ def test_checkmate_detection(self): # Scholar's mate pgn = "1. e4 e5 2. Qh5 Nc6 3. Bc4 Nf6 4. Qxf7# 1-0" chunked = pa.chunked_array([pa.array([pgn])]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) game = result[0] self.assertTrue(game.is_checkmate) @@ -968,12 +289,11 @@ def test_checkmate_detection(self): def test_promotion(self): """Test promotion encoding.""" - # Simplified position reaching promotion pgn = """[FEN "8/P7/8/8/8/8/8/4K2k w - - 0 1"] 1. a8=Q 1-0""" chunked = pa.chunked_array([pa.array([pgn])]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) game = result[0] # Promotion to queen = 5 @@ -984,7 +304,7 @@ def test_num_properties(self): """Test num_games, num_moves, num_positions properties.""" pgns = ["1. e4 e5 1-0", "1. d4 0-1"] # 2 + 1 = 3 moves, 3 + 2 = 5 positions chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) self.assertEqual(result.num_games, 2) self.assertEqual(result.num_moves, 3) @@ -994,7 +314,7 @@ def test_negative_indexing(self): """Test negative index access.""" pgns = ["1. e4 1-0", "1. d4 0-1", "1. c4 1/2-1/2"] chunked = pa.chunked_array([pa.array(pgns)]) - result = rust_pgn_reader_python_binding.parse_games_flat(chunked) + result = rust_pgn_reader_python_binding.parse_games(chunked) # -1 should be the last game last_game = result[-1] @@ -1004,6 +324,445 @@ def test_negative_indexing(self): # c2 = file 2 (c) + rank 1 (2nd rank) * 8 = 2 + 8 = 10 self.assertEqual(last_game.from_squares[0], 10) # c2 + def test_outcome(self): + """Test outcome is parsed from movetext.""" + pgns = ["1. e4 e5 1-0", "1. d4 d5 0-1", "1. c4 1/2-1/2"] + chunked = pa.chunked_array([pa.array(pgns)]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + + self.assertEqual(result[0].outcome, "White") + self.assertEqual(result[1].outcome, "Black") + self.assertEqual(result[2].outcome, "Draw") + + def test_outcome_without_headers(self): + """Test outcome works for PGNs without Result header.""" + pgn = "1. e4 e5 0-1" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + self.assertEqual(result[0].outcome, "Black") + + def test_is_game_over(self): + """Test is_game_over derived property.""" + # Scholar's mate + pgn = "1. e4 e5 2. Qh5 Nc6 3. Bc4 Nf6 4. Qxf7# 1-0" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + self.assertTrue(result[0].is_game_over) + + # Not game over + pgn2 = "1. e4 e5 1-0" + chunked2 = pa.chunked_array([pa.array([pgn2])]) + result2 = rust_pgn_reader_python_binding.parse_games(chunked2) + self.assertFalse(result2[0].is_game_over) + + def test_comments_disabled_by_default(self): + """Test comments are empty by default.""" + pgn = "1. e4 { a comment } e5 1-0" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games(chunked) + self.assertEqual(result[0].comments, []) + + def test_comments_enabled(self): + """Test comments are stored when enabled.""" + pgn = "1. e4 {asdf} e5 { [%eval 0.19] } 1-0" + chunked = pa.chunked_array([pa.array([pgn])]) + result = rust_pgn_reader_python_binding.parse_games( + chunked, store_comments=True + ) + comments = result[0].comments + self.assertEqual(len(comments), 2) + self.assertEqual(comments[0], "asdf") + self.assertIsNotNone(comments[1]) # eval-only comment + + def test_parse_game_string(self): + """Test parse_game convenience function for single string.""" + pgn = "1. e4 e5 2. Nf3 Nc6 1-0" + result = rust_pgn_reader_python_binding.parse_game(pgn) + self.assertEqual(result.num_games, 1) + self.assertEqual(result[0].moves_uci(), ["e2e4", "e7e5", "g1f3", "b8c6"]) + self.assertEqual(result[0].outcome, "White") + + def test_parse_games_from_strings(self): + """Test parse_games_from_strings convenience function.""" + pgns = [ + "1. e4 e5 1-0", + "1. d4 d5 0-1", + ] + result = rust_pgn_reader_python_binding.parse_games_from_strings(pgns) + self.assertEqual(result.num_games, 2) + self.assertTrue(result[0].is_valid) + self.assertTrue(result[1].is_valid) + + def test_long_game_with_comments(self): + """Test a long game with inline comments, verifying all moves.""" + pgn = "1. e4 {asdf} e5 2. Nf3 Nc6 3. Bb5 Nf6 4. O-O {hello} Bc5 5. d3 d6 6. h3 h6 7. c3 O-O 8. Be3 a6 9. Ba4 Bd7 10. Bxc5 dxc5 11. Bxc6 Bxc6 12. Nxe5 Bb5 13. Re1 Re8 14. c4 Rxe5 15. cxb5 axb5 16. Nc3 b4 17. Nd5 Nxd5 18. exd5 Rxe1+ 19. Qxe1 Qxd5 20. Qe3 b6 21. b3 c6 22. Qe2 b5 23. Rd1 Qd4 24. Qe7 Rxa2 25. Qe8+ Kh7 26. Qxf7 c4 27. Qf5+ Kh8 28. Qf8+ Kh7 29. Qf5+ Kh8 30. Qf8+ Kh7 31. Qf5+ Kh8 32. Qf8+ Kh7 1/2-1/2" + + result = rust_pgn_reader_python_binding.parse_game(pgn, store_comments=True) + game = result[0] + + moves_reference = [ + "e2e4", + "e7e5", + "g1f3", + "b8c6", + "f1b5", + "g8f6", + "e1g1", + "f8c5", + "d2d3", + "d7d6", + "h2h3", + "h7h6", + "c2c3", + "e8g8", + "c1e3", + "a7a6", + "b5a4", + "c8d7", + "e3c5", + "d6c5", + "a4c6", + "d7c6", + "f3e5", + "c6b5", + "f1e1", + "f8e8", + "c3c4", + "e8e5", + "c4b5", + "a6b5", + "b1c3", + "b5b4", + "c3d5", + "f6d5", + "e4d5", + "e5e1", + "d1e1", + "d8d5", + "e1e3", + "b7b6", + "b2b3", + "c7c6", + "e3e2", + "b6b5", + "a1d1", + "d5d4", + "e2e7", + "a8a2", + "e7e8", + "g8h7", + "e8f7", + "c5c4", + "f7f5", + "h7h8", + "f5f8", + "h8h7", + "f8f5", + "h7h8", + "f5f8", + "h8h7", + "f8f5", + "h7h8", + "f5f8", + "h8h7", + ] + + self.assertEqual(game.moves_uci(), moves_reference) + self.assertTrue(game.is_valid) + self.assertEqual(game.outcome, "Draw") + self.assertFalse(game.is_checkmate) + self.assertFalse(game.is_stalemate) + self.assertFalse(game.is_game_over) + + # Comments: "asdf" on move 0, "hello" on move 6, rest should be non-None + # but empty (comments are enabled so placeholders exist) + comments = game.comments + self.assertEqual(len(comments), 64) + self.assertIn("asdf", comments[0]) + self.assertIn("hello", comments[6]) + + def test_full_game_with_headers(self): + """Test a full game with many headers, verifying headers and moves.""" + pgn = """[Event "Rated Classical game"] +[Site "https://lichess.org/lhy6ehiv"] +[White "goerch"] +[Black "niltonrosao001"] +[Result "0-1"] +[UTCDate "2013.06.30"] +[UTCTime "22:10:02"] +[WhiteElo "1702"] +[BlackElo "2011"] +[WhiteRatingDiff "-3"] +[BlackRatingDiff "+5"] +[ECO "A46"] +[Opening "Indian Game: Spielmann-Indian"] +[TimeControl "600+8"] +[Termination "Normal"] + +1. d4 Nf6 2. Nf3 c5 3. e3 b6 4. Nc3 e6 5. Bb5 a6 6. Bd3 Bb7 7. O-O b5 8. b3 d5 9. Bb2 Nbd7 10. a4 b4 11. Ne2 Bd6 12. c4 bxc3 13. Bxc3 O-O 14. Ng3 Rc8 15. dxc5 Nxc5 16. Nd4 Nxd3 17. Qxd3 Qb6 18. Rab1 Bb4 19. Nge2 Ne4 20. Rfc1 Nxc3 21. Nxc3 Rc7 22. Na2 Rfc8 23. Rc2 g6 24. Nxb4 Qxb4 25. Rbc1 Rxc2 26. Nxc2 Qc3 27. Qxc3 Rxc3 28. Kf1 d4 29. exd4 Be4 30. Ke1 Rxc2 31. Rxc2 Bxc2 32. Kd2 Bxb3 33. a5 Bd5 0-1 +""" + + result = rust_pgn_reader_python_binding.parse_game(pgn) + game = result[0] + + moves_reference = [ + "d2d4", + "g8f6", + "g1f3", + "c7c5", + "e2e3", + "b7b6", + "b1c3", + "e7e6", + "f1b5", + "a7a6", + "b5d3", + "c8b7", + "e1g1", + "b6b5", + "b2b3", + "d7d5", + "c1b2", + "b8d7", + "a2a4", + "b5b4", + "c3e2", + "f8d6", + "c2c4", + "b4c3", + "b2c3", + "e8g8", + "e2g3", + "a8c8", + "d4c5", + "d7c5", + "f3d4", + "c5d3", + "d1d3", + "d8b6", + "a1b1", + "d6b4", + "g3e2", + "f6e4", + "f1c1", + "e4c3", + "e2c3", + "c8c7", + "c3a2", + "f8c8", + "c1c2", + "g7g6", + "a2b4", + "b6b4", + "b1c1", + "c7c2", + "d4c2", + "b4c3", + "d3c3", + "c8c3", + "g1f1", + "d5d4", + "e3d4", + "b7e4", + "f1e1", + "c3c2", + "c1c2", + "e4c2", + "e1d2", + "c2b3", + "a4a5", + "b3d5", + ] + + self.assertEqual(game.moves_uci(), moves_reference) + self.assertTrue(game.is_valid) + self.assertEqual(game.outcome, "Black") + + # Verify headers + self.assertEqual(game.headers["Event"], "Rated Classical game") + self.assertEqual(game.headers["Site"], "https://lichess.org/lhy6ehiv") + self.assertEqual(game.headers["White"], "goerch") + self.assertEqual(game.headers["Black"], "niltonrosao001") + self.assertEqual(game.headers["Result"], "0-1") + self.assertEqual(game.headers["WhiteElo"], "1702") + self.assertEqual(game.headers["BlackElo"], "2011") + self.assertEqual(game.headers["ECO"], "A46") + self.assertEqual(game.headers["Opening"], "Indian Game: Spielmann-Indian") + self.assertEqual(game.headers["TimeControl"], "600+8") + self.assertEqual(game.headers["Termination"], "Normal") + + def test_annotated_game_all_evals_and_clocks(self): + """Test a fully annotated game with eval and clock on every move.""" + pgn = """1. e4 { [%eval 0.17] [%clk 0:00:30] } 1... c5 { [%eval 0.19] [%clk 0:00:30] } +2. Nf3 { [%eval 0.25] [%clk 0:00:29] } 2... Nc6 { [%eval 0.33] [%clk 0:00:30] } +3. Bc4 { [%eval -0.13] [%clk 0:00:28] } 3... e6 { [%eval -0.04] [%clk 0:00:30] } +4. c3 { [%eval -0.4] [%clk 0:00:27] } 4... b5 { [%eval 1.18] [%clk 0:00:30] } +5. Bb3 { [%eval 0.21] [%clk 0:00:26] } 5... c4 { [%eval 0.32] [%clk 0:00:29] } +6. Bc2 { [%eval 0.2] [%clk 0:00:25] } 6... a5 { [%eval 0.6] [%clk 0:00:29] } +7. d4 { [%eval 0.29] [%clk 0:00:23] } 7... cxd3 { [%eval 0.6] [%clk 0:00:27] } +8. Qxd3 { [%eval 0.12] [%clk 0:00:22] } 8... Nf6 { [%eval 0.52] [%clk 0:00:26] } +9. e5 { [%eval 0.39] [%clk 0:00:21] } 9... Nd5 { [%eval 0.45] [%clk 0:00:25] } +10. Bg5 { [%eval -0.44] [%clk 0:00:18] } 10... Qc7 { [%eval -0.12] [%clk 0:00:23] } +11. Nbd2 { [%eval -3.15] [%clk 0:00:14] } 11... h6 { [%eval -2.99] [%clk 0:00:23] } +12. Bh4 { [%eval -3.0] [%clk 0:00:11] } 12... Ba6 { [%eval -0.12] [%clk 0:00:23] } +13. b3 { [%eval -4.14] [%clk 0:00:02] } 13... Nf4 { [%eval -2.73] [%clk 0:00:21] } 0-1""" + + result = rust_pgn_reader_python_binding.parse_game(pgn) + game = result[0] + + # Verify move count + self.assertEqual(len(game), 26) + self.assertTrue(game.is_valid) + self.assertEqual(game.outcome, "Black") + + # Verify all 26 evals + evals_reference = [ + 0.17, + 0.19, + 0.25, + 0.33, + -0.13, + -0.04, + -0.4, + 1.18, + 0.21, + 0.32, + 0.2, + 0.6, + 0.29, + 0.6, + 0.12, + 0.52, + 0.39, + 0.45, + -0.44, + -0.12, + -3.15, + -2.99, + -3.0, + -0.12, + -4.14, + -2.73, + ] + for i, expected_eval in enumerate(evals_reference): + self.assertAlmostEqual( + game.evals[i], expected_eval, places=2, msg=f"Eval mismatch at move {i}" + ) + + # Verify all 26 clocks (stored as seconds) + clock_seconds_reference = [ + 30, + 30, + 29, + 30, + 28, + 30, + 27, + 30, + 26, + 29, + 25, + 29, + 23, + 27, + 22, + 26, + 21, + 25, + 18, + 23, + 14, + 23, + 11, + 23, + 2, + 21, + ] + for i, expected_seconds in enumerate(clock_seconds_reference): + self.assertAlmostEqual( + game.clocks[i], + float(expected_seconds), + places=1, + msg=f"Clock mismatch at move {i}", + ) + + def test_castling_rights_through_game(self): + """Test castling rights through multiple positions including king and rook moves.""" + pgn = "1. e4 e5 2. Bc4 c6 3. Nf3 d6 4. Rg1 f6 5. Rh1 g6 6. Ke2 b6 7. Ke1 g5 1-0" + + result = rust_pgn_reader_python_binding.parse_game(pgn) + game = result[0] + + # 14 moves + 1 initial = 15 positions + self.assertEqual(game.num_positions, 15) + + # Castling order: [K, Q, k, q] + # Positions 0-6 (initial through 3. Nf3): all castling rights intact + for pos in range(7): + self.assertTrue( + all(game.castling[pos]), + f"Position {pos}: all castling should be intact", + ) + + # Position 7 (after 4. Rg1): white kingside lost + self.assertFalse(game.castling[7, 0]) # White K gone + self.assertTrue(game.castling[7, 1]) # White Q intact + self.assertTrue(game.castling[7, 2]) # Black k intact + self.assertTrue(game.castling[7, 3]) # Black q intact + + # Positions 8-10: same (Rh1 restores rook but not castling rights) + for pos in [8, 9, 10]: + self.assertFalse(game.castling[pos, 0]) # White K still gone + self.assertTrue(game.castling[pos, 1]) # White Q intact + + # Position 11 (after 6. Ke2): white both castling lost + self.assertFalse(game.castling[11, 0]) # White K + self.assertFalse(game.castling[11, 1]) # White Q + self.assertTrue(game.castling[11, 2]) # Black k + self.assertTrue(game.castling[11, 3]) # Black q + + # Position 12-14: white castling remains lost (Ke1 doesn't restore it) + for pos in [12, 13, 14]: + self.assertFalse(game.castling[pos, 0]) + self.assertFalse(game.castling[pos, 1]) + + def test_parse_games_from_strings_with_status(self): + """Test parse_games_from_strings with checkmate, move verification, and status.""" + pgns = [ + # Game 0: ends in checkmate + "1. Nf3 g6 2. b3 Bg7 3. Nc3 e5 4. Bb2 e4 5. Ng1 d6 6. Rb1 a5 7. Nxe4 Bxb2 8. Rxb2 Nf6 9. Nxf6+ Qxf6 10. Rb1 Ra6 11. e3 Rb6 12. d4 a4 13. bxa4 Nc6 14. Rxb6 cxb6 15. Bb5 Bd7 16. Bxc6 Bxc6 17. Nf3 Bxa4 18. O-O O-O 19. Re1 Rc8 20. Re2 d5 21. Ne5 Qf5 22. Qd3 Bxc2 23. Qxf5 Bxf5 24. h3 b5 25. Rb2 f6 26. Ng4 Rc6 27. Nh6+ Kg7 28. Nxf5+ gxf5 29. Rxb5 Rc7 30. Rxd5 Kg6 31. f4 Kh5 32. Rxf5+ Kh4 33. Rxf6 Kg3 34. d5 Rc1# 0-1", + # Game 1: no checkmate, game abandoned mid-play + "1. e4 e5 2. Nf3 Nc6 3. Bb5 Nf6 4. O-O Bc5 5. d3 d6 6. h3 h6 7. c3 O-O 1-0", + ] + result = rust_pgn_reader_python_binding.parse_games_from_strings(pgns) + + self.assertEqual(result.num_games, 2) + + # Game 0: checkmate + game0 = result[0] + self.assertTrue(game0.is_valid) + self.assertTrue(game0.is_checkmate) + self.assertFalse(game0.is_stalemate) + self.assertTrue(game0.is_game_over) + self.assertEqual(game0.legal_move_count, 0) + self.assertEqual(game0.outcome, "Black") + # Verify some key moves + moves0 = game0.moves_uci() + self.assertEqual(len(moves0), 68) + self.assertEqual(moves0[0], "g1f3") + self.assertEqual(moves0[-1], "c7c1") + + # Game 1: no checkmate + game1 = result[1] + self.assertTrue(game1.is_valid) + self.assertFalse(game1.is_checkmate) + self.assertFalse(game1.is_stalemate) + self.assertFalse(game1.is_game_over) + self.assertEqual(game1.outcome, "White") + moves1 = game1.moves_uci() + self.assertEqual(len(moves1), 14) + self.assertEqual(moves1[0], "e2e4") + if __name__ == "__main__": unittest.main() diff --git a/src/visitor.rs b/src/visitor.rs deleted file mode 100644 index 353a1af..0000000 --- a/src/visitor.rs +++ /dev/null @@ -1,447 +0,0 @@ -use crate::board_serialization::{ - get_castling_rights, get_en_passant_file, get_halfmove_clock, get_turn, serialize_board, -}; -use crate::comment_parsing::{parse_comments, CommentContent, ParsedTag}; -use crate::python_bindings::{PositionStatus, PyUciMove}; -use pgn_reader::{KnownOutcome, Outcome, RawComment, RawTag, SanPlus, Skip, Visitor}; -use pyo3::prelude::*; -use shakmaty::{fen::Fen, uci::UciMove, CastlingMode, Chess, Color, Position}; -use std::ops::ControlFlow; - -/// Board state tracking for flat output. -/// Only allocated when store_board_states is true to avoid overhead in the common case. -#[derive(Default)] -pub struct BoardStateData { - pub board_states: Vec, // Flattened: 64 bytes per position - pub en_passant_states: Vec, // Per position: -1 or file 0-7 - pub halfmove_clocks: Vec, // Per position - pub turn_states: Vec, // Per position: true=white - pub castling_states: Vec, // Flattened: 4 bools per position [K,Q,k,q] -} - -impl BoardStateData { - pub fn with_capacity(estimated_positions: usize) -> Self { - BoardStateData { - board_states: Vec::with_capacity(estimated_positions * 64), - en_passant_states: Vec::with_capacity(estimated_positions), - halfmove_clocks: Vec::with_capacity(estimated_positions), - turn_states: Vec::with_capacity(estimated_positions), - castling_states: Vec::with_capacity(estimated_positions * 4), - } - } - - pub fn clear(&mut self) { - self.board_states.clear(); - self.en_passant_states.clear(); - self.halfmove_clocks.clear(); - self.turn_states.clear(); - self.castling_states.clear(); - } -} - -#[pyclass] -/// A Visitor to extract SAN moves and comments from PGN movetext -pub struct MoveExtractor { - #[pyo3(get)] - pub moves: Vec, - - pub store_legal_moves: bool, - pub flat_legal_moves: Vec, - pub legal_moves_offsets: Vec, - - #[pyo3(get)] - pub valid_moves: bool, - - #[pyo3(get)] - pub comments: Vec>, - - #[pyo3(get)] - pub evals: Vec>, - - #[pyo3(get)] - pub clock_times: Vec>, - - #[pyo3(get)] - pub outcome: Option, - - #[pyo3(get)] - pub headers: Vec<(String, String)>, - - #[pyo3(get)] - pub castling_rights: Vec>, - - #[pyo3(get)] - pub position_status: Option, - - pub pos: Chess, - - // Board state tracking for flat output (not directly exposed to Python) - // Only allocated if store_board_states is true to avoid overhead - pub board_state_data: Option>, -} - -#[pymethods] -impl MoveExtractor { - #[new] - #[pyo3(signature = (store_legal_moves = false, store_board_states = false))] - pub fn new(store_legal_moves: bool, store_board_states: bool) -> MoveExtractor { - MoveExtractor { - moves: Vec::with_capacity(100), - store_legal_moves, - flat_legal_moves: Vec::with_capacity(if store_legal_moves { 100 * 30 } else { 0 }), - legal_moves_offsets: Vec::with_capacity(if store_legal_moves { 100 } else { 0 }), - pos: Chess::default(), - valid_moves: true, - comments: Vec::with_capacity(100), - evals: Vec::with_capacity(100), - clock_times: Vec::with_capacity(100), - outcome: None, - headers: Vec::with_capacity(10), - castling_rights: Vec::with_capacity(100), - position_status: None, - // Only allocate board state data when needed to avoid overhead - board_state_data: if store_board_states { - Some(Box::new(BoardStateData::with_capacity(100))) - } else { - None - }, - } - } - - /// Check if board states are being stored. - pub fn stores_board_states(&self) -> bool { - self.board_state_data.is_some() - } - - fn turn(&self) -> bool { - match self.pos.turn() { - Color::White => true, - Color::Black => false, - } - } - - fn push_castling_bitboards(&mut self) { - let castling_bitboard = self.pos.castles().castling_rights(); - let castling_rights = ( - castling_bitboard.contains(shakmaty::Square::A1), - castling_bitboard.contains(shakmaty::Square::H1), - castling_bitboard.contains(shakmaty::Square::A8), - castling_bitboard.contains(shakmaty::Square::H8), - ); - - self.castling_rights.push(Some(castling_rights)); - } - - fn push_legal_moves(&mut self) { - // Record the starting offset for the current position's legal moves. - self.legal_moves_offsets.push(self.flat_legal_moves.len()); - - let legal_moves_for_pos = self.pos.legal_moves(); - self.flat_legal_moves.reserve(legal_moves_for_pos.len()); - - for m in legal_moves_for_pos { - let uci_move_obj = UciMove::from_standard(m); - if let UciMove::Normal { - from, - to, - promotion: promo_opt, - } = uci_move_obj - { - self.flat_legal_moves.push(PyUciMove { - from_square: from as u8, - to_square: to as u8, - promotion: promo_opt.map(|p_role| p_role as u8), - }); - } - } - } - - /// Record current board state to flat arrays for ParsedGames output. - fn push_board_state(&mut self) { - if let Some(ref mut data) = self.board_state_data { - data.board_states - .extend_from_slice(&serialize_board(&self.pos)); - data.en_passant_states.push(get_en_passant_file(&self.pos)); - data.halfmove_clocks.push(get_halfmove_clock(&self.pos)); - data.turn_states.push(get_turn(&self.pos)); - let castling = get_castling_rights(&self.pos); - data.castling_states.extend_from_slice(&castling); - } - } - - fn update_position_status(&mut self) { - // TODO this checks legal_moves() a bunch of times - self.position_status = Some(PositionStatus { - is_checkmate: self.pos.is_checkmate(), - is_stalemate: self.pos.is_stalemate(), - legal_move_count: self.pos.legal_moves().len(), - is_game_over: self.pos.is_game_over(), - insufficient_material: ( - self.pos.has_insufficient_material(Color::White), - self.pos.has_insufficient_material(Color::Black), - ), - turn: match self.pos.turn() { - Color::White => true, - Color::Black => false, - }, - }); - } - - #[getter] - fn legal_moves(&self) -> Vec> { - let mut result = Vec::with_capacity(self.legal_moves_offsets.len()); - if self.legal_moves_offsets.is_empty() { - return result; - } - - for i in 0..self.legal_moves_offsets.len() - 1 { - let start = self.legal_moves_offsets[i]; - let end = self.legal_moves_offsets[i + 1]; - result.push(self.flat_legal_moves[start..end].to_vec()); - } - - // Handle the last chunk - if let Some(&start) = self.legal_moves_offsets.last() { - result.push(self.flat_legal_moves[start..].to_vec()); - } - - result - } -} - -impl Visitor for MoveExtractor { - type Tags = Vec<(String, String)>; - type Movetext = (); - type Output = bool; - - fn begin_tags(&mut self) -> ControlFlow { - self.headers.clear(); - ControlFlow::Continue(Vec::with_capacity(10)) - } - - fn tag( - &mut self, - tags: &mut Self::Tags, - key: &[u8], - value: RawTag<'_>, - ) -> ControlFlow { - let key_str = String::from_utf8_lossy(key).into_owned(); - let value_str = String::from_utf8_lossy(value.as_bytes()).into_owned(); - tags.push((key_str, value_str)); - ControlFlow::Continue(()) - } - - fn begin_movetext(&mut self, tags: Self::Tags) -> ControlFlow { - self.headers = tags; - self.moves.clear(); - self.flat_legal_moves.clear(); - self.legal_moves_offsets.clear(); - self.valid_moves = true; - self.comments.clear(); - self.evals.clear(); - self.clock_times.clear(); - self.castling_rights.clear(); - if let Some(ref mut data) = self.board_state_data { - data.clear(); - } - - // Determine castling mode from Variant header (case-insensitive) - let castling_mode = self - .headers - .iter() - .find(|(k, _)| k.eq_ignore_ascii_case("Variant")) - .and_then(|(_, v)| { - let v_lower = v.to_lowercase(); - if v_lower == "chess960" { - Some(CastlingMode::Chess960) - } else { - None - } - }) - .unwrap_or(CastlingMode::Standard); - - // Try to parse FEN from headers, fall back to default position - let fen_header = self - .headers - .iter() - .find(|(k, _)| k.eq_ignore_ascii_case("FEN")) - .map(|(_, v)| v.as_str()); - - if let Some(fen_str) = fen_header { - match fen_str.parse::() { - Ok(fen) => match fen.into_position(castling_mode) { - Ok(pos) => self.pos = pos, - Err(e) => { - eprintln!("invalid FEN position: {}", e); - self.pos = Chess::default(); - self.valid_moves = false; - } - }, - Err(e) => { - eprintln!("failed to parse FEN: {}", e); - self.pos = Chess::default(); - self.valid_moves = false; - } - } - } else { - self.pos = Chess::default(); - } - - self.push_castling_bitboards(); - if self.store_legal_moves { - self.push_legal_moves(); - } - // Record initial board state for flat output (only if enabled) - if self.board_state_data.is_some() { - self.push_board_state(); - } - ControlFlow::Continue(()) - } - - // Roughly half the time during parsing is spent here in san() - fn san( - &mut self, - _movetext: &mut Self::Movetext, - san_plus: SanPlus, - ) -> ControlFlow { - if self.valid_moves { - // Most of the function time is spent calculating to_move() - match san_plus.san.to_move(&self.pos) { - Ok(m) => { - self.pos.play_unchecked(m); - if self.store_legal_moves { - self.push_legal_moves(); - } - // Record board state after move for flat output (only if enabled) - if self.board_state_data.is_some() { - self.push_board_state(); - } - let uci_move_obj = UciMove::from_standard(m); - - match uci_move_obj { - UciMove::Normal { - from, - to, - promotion: promo_opt, - } => { - let py_uci_move = PyUciMove { - from_square: from as u8, - to_square: to as u8, - promotion: promo_opt.map(|p_role| p_role as u8), - }; - self.moves.push(py_uci_move); - self.push_castling_bitboards(); - - // Push placeholders to keep vectors in sync - self.comments.push(None); - self.evals.push(None); - self.clock_times.push(None); - } - _ => { - // This case handles UciMove::Put and UciMove::Null, - // which are not expected from standard PGN moves - // that PyUciMove is designed to represent. - eprintln!( - "Unexpected UCI move type from standard PGN move: {:?}. Game moves might be invalid.", - uci_move_obj - ); - self.valid_moves = false; - } - } - } - Err(err) => { - eprintln!("error in game: {} {}", err, san_plus); - self.valid_moves = false; - } - } - } - ControlFlow::Continue(()) - } - - fn comment( - &mut self, - _movetext: &mut Self::Movetext, - _comment: RawComment<'_>, - ) -> ControlFlow { - match parse_comments(_comment.as_bytes()) { - Ok((remaining_input, parsed_comments)) => { - if !remaining_input.is_empty() { - eprintln!("Unparsed remaining input: {:?}", remaining_input); - return ControlFlow::Continue(()); - } - - let mut move_comments = String::new(); - - for content in parsed_comments { - match content { - CommentContent::Text(text) => { - if !text.trim().is_empty() { - if !move_comments.is_empty() { - move_comments.push(' '); - } - move_comments.push_str(&text); - } - } - CommentContent::Tag(tag_content) => match tag_content { - ParsedTag::Eval(eval_value) => { - if let Some(last_eval) = self.evals.last_mut() { - *last_eval = Some(eval_value); - } - } - ParsedTag::Mate(mate_value) => { - if !move_comments.is_empty() && !move_comments.ends_with(' ') { - move_comments.push(' '); - } - move_comments.push_str(&format!("[Mate {}]", mate_value)); - } - ParsedTag::ClkTime { - hours, - minutes, - seconds, - } => { - if let Some(last_clk) = self.clock_times.last_mut() { - *last_clk = Some((hours, minutes, seconds)); - } - } - }, - } - } - - if let Some(last_comment) = self.comments.last_mut() { - *last_comment = Some(move_comments); - } - } - Err(e) => { - eprintln!("Error parsing comment: {:?}", e); - } - } - ControlFlow::Continue(()) - } - - fn begin_variation( - &mut self, - _movetext: &mut Self::Movetext, - ) -> ControlFlow { - ControlFlow::Continue(Skip(true)) // stay in the mainline - } - - fn outcome( - &mut self, - _movetext: &mut Self::Movetext, - _outcome: Outcome, - ) -> ControlFlow { - self.outcome = Some(match _outcome { - Outcome::Known(known) => match known { - KnownOutcome::Decisive { winner } => format!("{:?}", winner), - KnownOutcome::Draw => "Draw".to_string(), - }, - Outcome::Unknown => "Unknown".to_string(), - }); - self.update_position_status(); - ControlFlow::Continue(()) - } - - fn end_game(&mut self, _movetext: Self::Movetext) -> Self::Output { - self.valid_moves - } -} From e2a4ab7e5268010d19d8a90e6f1a05d1f09d8b63 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Sun, 8 Feb 2026 11:30:24 -0500 Subject: [PATCH 27/31] More tests --- src/flat_visitor.rs | 121 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/src/flat_visitor.rs b/src/flat_visitor.rs index 823b352..b78fa0a 100644 --- a/src/flat_visitor.rs +++ b/src/flat_visitor.rs @@ -722,4 +722,125 @@ mod tests { // Total legal moves stored assert_eq!(buffers.legal_move_from_squares.len(), 40); } + + #[test] + fn test_parse_game_without_headers() { + let pgn = "1. Nf3 d5 2. e4 c5 3. exd5 e5 4. dxe6 0-1"; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert!(result.unwrap()); // valid game + assert_eq!(buffers.num_games(), 1); + assert_eq!(buffers.total_moves(), 7); + assert_eq!(buffers.outcome[0], Some("Black".to_string())); + } + + #[test] + fn test_parse_game_with_standard_fen() { + // A game starting from a mid-game position + let pgn = r#"[FEN "r1bqkbnr/pppp1ppp/2n5/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3"] + +3. Bb5 a6 4. Ba4 Nf6 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert!(result.unwrap()); // valid game + assert_eq!(buffers.total_moves(), 4); + } + + #[test] + fn test_parse_chess960_game() { + // Chess960 game with custom starting position + let pgn = r#"[Variant "chess960"] +[FEN "brkrqnnb/pppppppp/8/8/8/8/PPPPPPPP/BRKRQNNB w KQkq - 0 1"] + +1. g3 d5 2. d4 g6 3. b3 Nf6 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert!( + result.unwrap(), + "Chess960 moves should be valid with proper FEN" + ); + assert_eq!(buffers.total_moves(), 6); + } + + #[test] + fn test_parse_chess960_variant_case_insensitive() { + // Test that variant detection is case-insensitive + let pgn = r#"[Variant "Chess960"] +[FEN "brkrqnnb/pppppppp/8/8/8/8/PPPPPPPP/BRKRQNNB w KQkq - 0 1"] + +1. g3 d5 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert!(result.unwrap(), "Should handle Chess960 case variations"); + } + + #[test] + fn test_parse_invalid_fen_falls_back() { + // Invalid FEN should fall back to default and mark invalid + let pgn = r#"[FEN "invalid fen string"] + +1. e4 e5 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert!( + !result.unwrap(), + "Should mark as invalid when FEN parsing fails" + ); + } + + #[test] + fn test_fen_header_case_insensitive() { + // FEN header key should be case-insensitive + let pgn = r#"[fen "r1bqkbnr/pppp1ppp/2n5/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3"] + +3. Bb5 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert!(result.unwrap(), "Should handle lowercase 'fen' header"); + } + + #[test] + fn test_parse_game_with_custom_fen_no_variant() { + // Standard chess from a mid-game position (no Variant header) + // Position after 1.e4 e5 2.Nf3 Nc6 3.Bb5 (Ruy Lopez) + let pgn = r#"[Event "Test Game"] +[FEN "r1bqkbnr/pppp1ppp/2n5/1B2p3/4P3/5N2/PPPP1PPP/RNBQK2R b KQkq - 3 3"] + +3... a6 4. Ba4 Nf6 5. O-O Be7 1-0"#; + + let config = default_config(); + let mut buffers = Buffers::with_capacity(1, 70, &config); + let result = parse_game_to_buffers(pgn, &mut buffers, &config); + + assert!(result.is_ok()); + assert!( + result.unwrap(), + "Standard game with custom FEN should be valid" + ); + assert_eq!(buffers.total_moves(), 5); // a6, Ba4, Nf6, O-O, Be7 + } } From e645ffa617c7c0087f8a8013e9e77bc08f69ddb3 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Sun, 8 Feb 2026 11:30:39 -0500 Subject: [PATCH 28/31] Update readme, simplify benchmark --- README.md | 11 +- src/bench_parse_games.py | 297 +++++---------------------------------- 2 files changed, 46 insertions(+), 262 deletions(-) diff --git a/README.md b/README.md index aa9fe12..df0012a 100644 --- a/README.md +++ b/README.md @@ -27,12 +27,21 @@ Below are some benchmarks on Lichess's 2013-07 chess games (293,459 games) on a |----------------------------------------------------------------------------|-------------|--------| | [rust-pgn-reader](https://github.com/niklasf/rust-pgn-reader/tree/master) | PGN | 1s | | rust_pgn_reader_python_binding, parse_games (multithreaded) | parquet | 0.35s | +| rust_pgn_reader_python_binding, parse_games_from_strings (multithreaded) | parquet | 0.5s | +| rust_pgn_reader_python_binding, parse_game (single-threaded) | parquet | 3.3s | +| rust_pgn_reader_python_binding, parse_game (single-threaded) | PGN | 4.7s | | [chess-library](https://github.com/Disservin/chess-library) | PGN | 2s | | [python-chess](https://github.com/niklasf/python-chess) | PGN | 3+ min | To replicate, download `2013-07-train-00000-of-00001.parquet` and then run: -`python src/bench_parse_games.py 2013-07-train-00000-of-00001.parquet` +`python src/bench_parquet.py` (single-threaded parse_game from parquet) + +`python src/bench_parquet_parallel.py` (multithreaded parse_games_from_strings) + +`python src/bench_parquet_arrow.py` (multithreaded parse_games from Arrow) + +`python src/bench_parse_games.py 2013-07-train-00000-of-00001.parquet` (full benchmark with data access) ## Building `maturin develop` diff --git a/src/bench_parse_games.py b/src/bench_parse_games.py index eecaa5e..cafebe2 100644 --- a/src/bench_parse_games.py +++ b/src/bench_parse_games.py @@ -1,292 +1,73 @@ """ -Benchmark for parse_games() PGN parsing. - -This benchmark measures: -1. Parsing speed (games/second) -2. Memory efficiency (bytes per position) -3. Data access patterns for ML workloads +Benchmark for parse_games() — parsing speed and data access patterns. Usage: - python bench_parse_games.py [parquet_file] - -If no parquet file is provided, synthetic PGN data will be generated. + python bench_parse_games.py 2013-07-train-00000-of-00001.parquet """ import sys import time -import argparse -from typing import Optional import numpy as np -import pyarrow as pa +import pyarrow.parquet as pq import rust_pgn_reader_python_binding as pgn -def generate_synthetic_pgns(num_games: int, moves_per_game: int = 40) -> list[str]: - """Generate synthetic PGN games for benchmarking.""" - move_pairs = [ - ("e4", "e5"), - ("Nf3", "Nc6"), - ("Bb5", "a6"), - ("Ba4", "Nf6"), - ("O-O", "Be7"), - ("Re1", "b5"), - ("Bb3", "d6"), - ("c3", "O-O"), - ("h3", "Nb8"), - ("d4", "Nbd7"), - ("Nbd2", "Bb7"), - ("Bc2", "Re8"), - ("Nf1", "Bf8"), - ("Ng3", "g6"), - ("Bg5", "h6"), - ("Bd2", "Bg7"), - ("a4", "c5"), - ("d5", "c4"), - ("b4", "Nc5"), - ("Be3", "Qc7"), - ] - - pgns = [] - for i in range(num_games): - moves = [] - num_pairs = min(moves_per_game // 2, len(move_pairs)) - for j in range(num_pairs): - white_move, black_move = move_pairs[j] - moves.append(f"{j + 1}. {white_move} {black_move}") - - movetext = " ".join(moves) - result = ["1-0", "0-1", "1/2-1/2"][i % 3] - - pgn_str = f"""[Event "Synthetic Game {i}"] -[White "Player{i * 2}"] -[Black "Player{i * 2 + 1}"] -[Result "{result}"] - -{movetext} {result}""" - pgns.append(pgn_str) - - return pgns - +def main(): + if len(sys.argv) < 2: + print("Usage: python bench_parse_games.py ") + return 1 -def load_parquet_pgns(file_path: str, limit: Optional[int] = None) -> pa.ChunkedArray: - """Load PGN strings from a parquet file.""" - import pyarrow.parquet as pq + file_path = sys.argv[1] + print(f"Loading from: {file_path}") pf = pq.ParquetFile(file_path) + chunked_array = pf.read(columns=["movetext"]).column("movetext") + print(f"Loaded {len(chunked_array):,} games") - table = pf.read() - for col_name in ["movetext", "pgn", "moves", "game"]: - if col_name in table.column_names: - arr = table.column(col_name) - if limit: - arr = arr.slice(0, limit) - return arr - - raise ValueError( - f"Could not find PGN column. Available columns: {table.column_names}" - ) - - -def benchmark_parse_games( - chunked_array: pa.ChunkedArray, num_threads: Optional[int] = None, warmup: int = 1 -) -> dict: - """Benchmark parse_games().""" # Warmup - for _ in range(warmup): - _ = pgn.parse_games(chunked_array, num_threads=num_threads) + _ = pgn.parse_games(chunked_array) - # Timed run + # Timed parse start = time.perf_counter() - result = pgn.parse_games(chunked_array, num_threads=num_threads) + result = pgn.parse_games(chunked_array) elapsed = time.perf_counter() - start - return { - "method": "parse_games", - "elapsed_seconds": elapsed, - "num_games": result.num_games, - "num_moves": result.num_moves, - "num_positions": result.num_positions, - "games_per_second": result.num_games / elapsed, - "moves_per_second": result.num_moves / elapsed, - "positions_per_second": result.num_positions / elapsed, - "valid_games": int(sum(chunk.valid.sum() for chunk in result.chunks)), - "result": result, - } - + print(f"\nParsing: {elapsed:.3f}s") + print( + f" {result.num_games:,} games, {result.num_moves:,} moves, {result.num_positions:,} positions" + ) + print(f" {result.num_games / elapsed:,.0f} games/sec") + print(f" {result.num_chunks} chunks") -def benchmark_data_access(result) -> dict: - """Benchmark data access patterns for parse_games result.""" + # Data access: chunk-level array access start = time.perf_counter() - - # Simulate ML data loading: access all boards via chunks for chunk in result.chunks: _ = chunk.boards.sum() - - # Access moves via chunks - for chunk in result.chunks: _ = chunk.from_squares.sum() _ = chunk.to_squares.sum() - elapsed = time.perf_counter() - start - return {"access_time": elapsed} + print(f"\nChunk array access: {elapsed:.3f}s") - -def benchmark_per_game_access(result) -> dict: - """Benchmark per-game access pattern.""" + # Data access: per-game views + n_access = min(1000, result.num_games) start = time.perf_counter() - - for i in range(min(1000, result.num_games)): - game = result[i] - _ = game.boards - + for i in range(n_access): + _ = result[i].boards elapsed = time.perf_counter() - start - return { - "access_time": elapsed, - "games_accessed": min(1000, result.num_games), - } - - -def benchmark_position_to_game_mapping(result) -> dict: - """Benchmark position-to-game mapping.""" - start = time.perf_counter() + print( + f"Per-game access ({n_access} games): {elapsed:.3f}s ({elapsed / n_access * 1000:.3f}ms/game)" + ) + # Position-to-game mapping indices = np.random.randint(0, result.num_positions, size=1000, dtype=np.int64) + start = time.perf_counter() _ = result.position_to_game(indices) - elapsed = time.perf_counter() - start - return { - "access_time": elapsed, - "positions_accessed": 1000, - } - + print(f"Position-to-game (1000 lookups): {elapsed * 1000:.3f}ms") -def format_number(n: float) -> str: - """Format large numbers with K/M suffix.""" - if n >= 1_000_000: - return f"{n / 1_000_000:.2f}M" - elif n >= 1_000: - return f"{n / 1_000:.2f}K" - else: - return f"{n:.2f}" - - -def print_results(results: dict, label: str): - """Print benchmark results.""" - print(f"\n{label}") - print("-" * 50) - print(f" Time: {results['elapsed_seconds']:.3f}s") - print(f" Games: {results['num_games']:,}") - print(f" Moves: {results['num_moves']:,}") - print(f" Positions: {results['num_positions']:,}") - print(f" Valid games: {results['valid_games']:,}") - print(f" Games/sec: {format_number(results['games_per_second'])}") - print(f" Moves/sec: {format_number(results['moves_per_second'])}") - print(f" Positions/sec: {format_number(results['positions_per_second'])}") - - -def main(): - parser = argparse.ArgumentParser(description="Benchmark parse_games()") - parser.add_argument( - "parquet_file", - nargs="?", - help="Path to parquet file with PGN data (optional)", - ) - parser.add_argument( - "--num-games", - type=int, - default=10000, - help="Number of synthetic games to generate if no parquet file", - ) - parser.add_argument( - "--limit", - type=int, - default=None, - help="Limit number of games from parquet file", - ) - parser.add_argument( - "--threads", - type=int, - default=None, - help="Number of threads (default: all cores)", - ) - parser.add_argument( - "--warmup", - type=int, - default=1, - help="Number of warmup iterations", - ) - parser.add_argument( - "--skip-access", - action="store_true", - help="Skip data access benchmarks", - ) - args = parser.parse_args() - - print("=" * 60) - print("parse_games() Benchmark") - print("=" * 60) - - # Load or generate data - if args.parquet_file: - print(f"\nLoading from: {args.parquet_file}") - try: - chunked_array = load_parquet_pgns(args.parquet_file, limit=args.limit) - print(f"Loaded {len(chunked_array):,} games") - except Exception as e: - print(f"Error loading parquet: {e}") - return 1 - else: - print(f"\nGenerating {args.num_games:,} synthetic games...") - pgns = generate_synthetic_pgns(args.num_games) - chunked_array = pa.chunked_array([pa.array(pgns)]) - print(f"Generated {len(chunked_array):,} games") - - print(f"Threads: {args.threads or 'all cores'}") - print(f"Warmup iterations: {args.warmup}") - - # Benchmark - print("\n" + "=" * 60) - print("PARSING BENCHMARK") - print("=" * 60) - - results = benchmark_parse_games( - chunked_array, num_threads=args.threads, warmup=args.warmup - ) - print_results(results, "parse_games()") - - # Data access benchmarks - if not args.skip_access: - print("\n" + "=" * 60) - print("DATA ACCESS BENCHMARKS") - print("=" * 60) - - data_access = benchmark_data_access(results["result"]) - print(f"\nArray access time: {data_access['access_time']:.3f}s") - - per_game_access = benchmark_per_game_access(results["result"]) - print(f"\nPer-game access time: {per_game_access['access_time']:.3f}s") - print(f"Games accessed: {per_game_access['games_accessed']:,}") - print( - f"Time per game: {per_game_access['access_time'] / per_game_access['games_accessed'] * 1000:.3f}ms" - ) - - position_mapping = benchmark_position_to_game_mapping(results["result"]) - print(f"\nPosition-to-game mapping: {position_mapping['access_time']:.3f}s") - print( - f"Positions mapped: {position_mapping['positions_accessed']:,}" - ) - print( - f"Time per lookup: {position_mapping['access_time'] / position_mapping['positions_accessed'] * 1000000:.3f}us" - ) - - # Memory usage (approximate) - print("\n" + "=" * 60) - print("MEMORY USAGE (approximate)") - print("=" * 60) - - result = results["result"] + # Memory usage total_bytes = 0 for chunk in result.chunks: total_bytes += ( @@ -303,15 +84,9 @@ def main(): + chunk.move_offsets.nbytes + chunk.position_offsets.nbytes ) - - print(f"\nArrays total: {total_bytes / 1024 / 1024:.2f} MB") - print(f"Bytes per position: {total_bytes / result.num_positions:.1f}") - print(f"Bytes per move: {total_bytes / result.num_moves:.1f}") - print(f"Number of chunks: {result.num_chunks}") - - print("\n" + "=" * 60) - print("Benchmark complete!") - print("=" * 60) + print( + f"\nMemory: {total_bytes / 1024 / 1024:.1f} MB ({total_bytes / result.num_positions:.0f} bytes/position)" + ) return 0 From c2b3ddd1b84f258285b90e330bb6b8d406c925a4 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Sun, 8 Feb 2026 11:40:49 -0500 Subject: [PATCH 29/31] Rename bench_parse_games.py -> bench_data_access.py and simplify --- src/{bench_parse_games.py => bench_data_access.py} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename src/{bench_parse_games.py => bench_data_access.py} (95%) diff --git a/src/bench_parse_games.py b/src/bench_data_access.py similarity index 95% rename from src/bench_parse_games.py rename to src/bench_data_access.py index cafebe2..19b9f68 100644 --- a/src/bench_parse_games.py +++ b/src/bench_data_access.py @@ -2,7 +2,7 @@ Benchmark for parse_games() — parsing speed and data access patterns. Usage: - python bench_parse_games.py 2013-07-train-00000-of-00001.parquet + python bench_data_access.py 2013-07-train-00000-of-00001.parquet """ import sys @@ -16,7 +16,7 @@ def main(): if len(sys.argv) < 2: - print("Usage: python bench_parse_games.py ") + print("Usage: python bench_data_access.py ") return 1 file_path = sys.argv[1] From 9150f6078a177724c4768739c8dfa028f44b2431 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Sun, 8 Feb 2026 11:41:14 -0500 Subject: [PATCH 30/31] Rename bench files to match API function names and update README --- README.md | 10 ++++++---- src/{bench_parquet.py => bench_parse_game.py} | 0 src/{bench_pgn_file.py => bench_parse_game_pgn.py} | 0 src/{bench_parquet_arrow.py => bench_parse_games.py} | 0 ...t_parallel.py => bench_parse_games_from_strings.py} | 0 5 files changed, 6 insertions(+), 4 deletions(-) rename src/{bench_parquet.py => bench_parse_game.py} (100%) rename src/{bench_pgn_file.py => bench_parse_game_pgn.py} (100%) rename src/{bench_parquet_arrow.py => bench_parse_games.py} (100%) rename src/{bench_parquet_parallel.py => bench_parse_games_from_strings.py} (100%) diff --git a/README.md b/README.md index df0012a..d4c2261 100644 --- a/README.md +++ b/README.md @@ -35,13 +35,15 @@ Below are some benchmarks on Lichess's 2013-07 chess games (293,459 games) on a To replicate, download `2013-07-train-00000-of-00001.parquet` and then run: -`python src/bench_parquet.py` (single-threaded parse_game from parquet) +`python src/bench_parse_games.py` (recommended — multithreaded parse_games via Arrow) -`python src/bench_parquet_parallel.py` (multithreaded parse_games_from_strings) +`python src/bench_parse_games_from_strings.py` (multithreaded parse_games_from_strings) -`python src/bench_parquet_arrow.py` (multithreaded parse_games from Arrow) +`python src/bench_parse_game.py` (single-threaded parse_game from parquet) -`python src/bench_parse_games.py 2013-07-train-00000-of-00001.parquet` (full benchmark with data access) +`python src/bench_parse_game_pgn.py` (single-threaded parse_game from .pgn file) + +`python src/bench_data_access.py 2013-07-train-00000-of-00001.parquet` (parsing + data access + memory) ## Building `maturin develop` diff --git a/src/bench_parquet.py b/src/bench_parse_game.py similarity index 100% rename from src/bench_parquet.py rename to src/bench_parse_game.py diff --git a/src/bench_pgn_file.py b/src/bench_parse_game_pgn.py similarity index 100% rename from src/bench_pgn_file.py rename to src/bench_parse_game_pgn.py diff --git a/src/bench_parquet_arrow.py b/src/bench_parse_games.py similarity index 100% rename from src/bench_parquet_arrow.py rename to src/bench_parse_games.py diff --git a/src/bench_parquet_parallel.py b/src/bench_parse_games_from_strings.py similarity index 100% rename from src/bench_parquet_parallel.py rename to src/bench_parse_games_from_strings.py From 1e28b3cf5d16f79d976984378261506831e3cc40 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Sun, 8 Feb 2026 11:43:25 -0500 Subject: [PATCH 31/31] Remove the last "flat" reference --- src/lib.rs | 4 ++-- src/{flat_visitor.rs => visitor.rs} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename src/{flat_visitor.rs => visitor.rs} (100%) diff --git a/src/lib.rs b/src/lib.rs index 97d328d..2a6f5a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,11 +7,11 @@ use rayon::prelude::*; mod board_serialization; mod comment_parsing; -mod flat_visitor; mod python_bindings; +mod visitor; -pub use flat_visitor::{Buffers, ParseConfig, parse_game_to_buffers}; use python_bindings::{ChunkData, ParsedGames, ParsedGamesIter, PyChunkView, PyGameView}; +pub use visitor::{Buffers, ParseConfig, parse_game_to_buffers}; /// Parse games from Arrow chunked array into a chunked ParsedGames container. /// diff --git a/src/flat_visitor.rs b/src/visitor.rs similarity index 100% rename from src/flat_visitor.rs rename to src/visitor.rs