diff --git a/simplexity/generative_processes/builder.py b/simplexity/generative_processes/builder.py index 8331c9c1..ddcf1851 100644 --- a/simplexity/generative_processes/builder.py +++ b/simplexity/generative_processes/builder.py @@ -10,6 +10,7 @@ # the problematic imports checker that would crash during AST traversal. import inspect +import random from collections.abc import Callable, Mapping, Sequence from typing import Any, Literal @@ -18,7 +19,11 @@ from simplexity.generative_processes.factored_generative_process import ComponentType, FactoredGenerativeProcess from simplexity.generative_processes.generalized_hidden_markov_model import GeneralizedHiddenMarkovModel +from simplexity.generative_processes.generative_process import GenerativeProcess from simplexity.generative_processes.hidden_markov_model import HiddenMarkovModel +from simplexity.generative_processes.independent_factored_generative_process import IndependentFactoredGenerativeProcess +from simplexity.generative_processes.inflated_vocabulary_process import InflatedVocabularyProcess +from simplexity.generative_processes.nonergodic_generative_process import NonErgodicGenerativeProcess from simplexity.generative_processes.structures import ( ConditionalTransitions, FullyConditional, @@ -192,7 +197,16 @@ def build_factored_process( if structure_type == "independent": structure = IndependentStructure() - elif structure_type == "chain": + return IndependentFactoredGenerativeProcess( + component_types=component_types, + transition_matrices=transition_matrices, + normalizing_eigenvectors=normalizing_eigenvectors, + initial_states=initial_states, + structure=structure, + noise_epsilon=noise_epsilon, + ) + + if structure_type == "chain": if "control_maps" not in structure_kwargs: raise ValueError("Missing required argument 'control_maps' for chain structure") structure = SequentialConditional(control_maps=tuple(structure_kwargs["control_maps"]), vocab_sizes=vocab_sizes) @@ -634,3 +648,306 @@ def 
build_transition_coupled_from_spec( emission_variant_indices_array, emission_control_maps_arrays, ) + + +def _build_components_from_spec( + components: Sequence[dict[str, Any]], + device: str | None = None, +) -> list[GenerativeProcess]: + """Build component GenerativeProcess instances from specifications. + + Args: + components: List of component specs. Each spec has: + - component_type: "hmm", "ghmm", or "factored" + - For hmm/ghmm: process_name, process_params + - For factored: structure_type, spec, and structure-specific params + device: Device placement. + + Returns: + List of built GenerativeProcess instances. + + Raises: + ValueError: If component_type is unknown. + """ + built_components = [] + + for comp_spec in components: + comp_type = comp_spec.get("component_type", "hmm") + + if comp_type == "hmm": + process: GenerativeProcess = build_hidden_markov_model( + process_name=comp_spec["process_name"], + process_params=comp_spec.get("process_params", {}), + device=device, + ) + elif comp_type == "ghmm": + process = build_generalized_hidden_markov_model( + process_name=comp_spec["process_name"], + process_params=comp_spec.get("process_params", {}), + device=device, + ) + elif comp_type == "factored": + factored_kwargs = {k: v for k, v in comp_spec.items() if k not in ("component_type", "vocab_map")} + process = build_factored_process_from_spec(**factored_kwargs) + else: + raise ValueError(f"Unknown component_type: {comp_type}") + + built_components.append(process) + + return built_components + + +def build_nonergodic_process_from_spec( + components: Sequence[dict[str, Any]], + component_weights: Sequence[float], + vocab_maps: Sequence[Sequence[int]] | None = None, + device: str | None = None, +) -> NonErgodicGenerativeProcess: + """Build a nonergodic process from component specifications. + + Creates a NonErgodicGenerativeProcess that composes multiple GenerativeProcess + instances into a truly nonergodic mixture with block diagonal structure. 
+ + Args: + components: List of component specs. Each spec has: + - component_type: "hmm", "ghmm", or "factored" + - For hmm/ghmm: process_name, process_params + - For factored: structure_type, spec, and structure-specific params + - vocab_map: Optional per-component vocab mapping + component_weights: Mixture weights for components (will be normalized). + vocab_maps: Optional global vocab maps (overrides per-component). + device: Device placement. + + Returns: + NonErgodicGenerativeProcess instance. + + Example: + ```yaml + instance: + _target_: simplexity.generative_processes.builder.build_nonergodic_process_from_spec + components: + - component_type: hmm + process_name: mess3 + process_params: {x: 0.15, a: 0.6} + - component_type: ghmm + process_name: tom_quantum + process_params: {alpha: 1.0, beta: 4.0} + - component_type: factored + structure_type: independent + spec: + - component_type: hmm + variants: + - process_name: coin + process_params: {p: 0.5} + component_weights: [0.5, 0.3, 0.2] + vocab_maps: + - [0, 1, 2] + - [0, 1, 2] + - [0, 1] + ``` + + Raises: + ValueError: If component_type is unknown. + """ + built_components = _build_components_from_spec(components, device=device) + + if vocab_maps is None: + inferred_vocab_maps = [] + for comp_spec, process in zip(components, built_components, strict=True): + comp_vocab_map = comp_spec.get("vocab_map", list(range(process.vocab_size))) + inferred_vocab_maps.append(comp_vocab_map) + final_vocab_maps: Sequence[Sequence[int]] = inferred_vocab_maps + else: + final_vocab_maps = vocab_maps + + return NonErgodicGenerativeProcess( + components=built_components, + component_weights=component_weights, + vocab_maps=final_vocab_maps, + device=device, + ) + + +def build_nonergodic_disjoint_vocab( + components: Sequence[dict[str, Any]], + component_weights: Sequence[float], + device: str | None = None, +) -> NonErgodicGenerativeProcess: + """Build a nonergodic process where each component has a fully disjoint alphabet. 
+ + Builds each component once to discover its vocab_size, then assigns + non-overlapping vocab_maps: C0 -> [0..V0-1], C1 -> [V0..V0+V1-1], etc. + + Args: + components: List of component specs (same format as build_nonergodic_process_from_spec). + component_weights: Mixture weights for components. + device: Device placement. + + Returns: + NonErgodicGenerativeProcess with disjoint per-component vocabularies. + """ + built_components = _build_components_from_spec(components, device=device) + + vocab_maps: list[list[int]] = [] + offset = 0 + for c in built_components: + vocab_maps.append(list(range(offset, offset + c.vocab_size))) + offset += c.vocab_size + + return NonErgodicGenerativeProcess( + components=built_components, + component_weights=component_weights, + vocab_maps=vocab_maps, + device=device, + ) + + +def _build_prefix_vocab_maps(n_components: int, v: int, n_shared: int, n_unique: int) -> list[list[int]]: + """Build vocab maps using the prefix strategy. + + C0 gets [0..V-1]. Ci>0 gets shared [0..n_shared-1] + unique tokens above V. + """ + return [list(range(v))] + [ + list(range(n_shared)) + list(range(v + i * n_unique, v + (i + 1) * n_unique)) for i in range(n_components - 1) + ] + + +def _build_sliding_vocab_maps(n_components: int, v: int, n_unique: int) -> list[list[int]]: + """Build vocab maps using the sliding/offset strategy. + + Ci gets [i*offset..i*offset+V-1] where offset = max(1, n_unique). + """ + offset = max(1, n_unique) + return [list(range(i * offset, i * offset + v)) for i in range(n_components)] + + +def _build_random_vocab_maps(n_components: int, v: int, n_unique: int, seed: int) -> list[list[int]]: + """Build vocab maps by having each component randomly sample V tokens from the global pool. + + The global vocab size is the same as in prefix mode: + V + (n_components - 1) * n_unique. 
+ """ + global_vocab_size = v + (n_components - 1) * n_unique + rng = random.Random(seed) + return [sorted(rng.sample(range(global_vocab_size), v)) for _ in range(n_components)] + + +def build_nonergodic_partial_overlap( + components: Sequence[dict[str, Any]], + component_weights: Sequence[float], + overlap_frac: float = 0.7, + mode: Literal["prefix", "sliding", "random"] = "prefix", + seed: int | None = None, + device: str | None = None, +) -> NonErgodicGenerativeProcess: + """Build a nonergodic process with partially overlapping alphabets. + + Args: + components: List of component specs (same format as build_nonergodic_process_from_spec). + component_weights: Mixture weights for components. + overlap_frac: Fraction of tokens shared between components (0.0 = disjoint, 1.0 = full overlap). + mode: Strategy for assigning vocab maps: + - "prefix": C0 gets [0..V-1], Ci>0 gets shared prefix + unique suffix above V. + - "sliding": Each component's vocab is offset by V * (1 - overlap_frac) from the previous. + - "random": Each component independently samples V tokens from the global pool. + Global pool size matches prefix mode. Requires the ``seed`` parameter. + seed: Random seed for reproducibility. Required when mode="random". + device: Device placement. + + Returns: + NonErgodicGenerativeProcess with partially overlapping vocabularies. + + Raises: + ValueError: If mode is unknown or seed is missing for random mode. 
+ """ + if mode == "random" and seed is None: + raise ValueError("seed is required when mode='random'") + + built_components = _build_components_from_spec(components, device=device) + comp_vocab_sizes = [c.vocab_size for c in built_components] + if len(set(comp_vocab_sizes)) != 1: + raise ValueError(f"All components must have equal vocab_size for partial_overlap, got {comp_vocab_sizes}") + v = comp_vocab_sizes[0] + n_shared = int(v * overlap_frac) + n_unique = v - n_shared + n_components = len(components) + + if mode == "prefix": + vocab_maps = _build_prefix_vocab_maps(n_components, v, n_shared, n_unique) + elif mode == "sliding": + vocab_maps = _build_sliding_vocab_maps(n_components, v, n_unique) + elif mode == "random": + if seed is None: + raise ValueError("seed is required when mode='random'") + vocab_maps = _build_random_vocab_maps(n_components, v, n_unique, seed) + else: + raise ValueError(f"Unknown mode '{mode}'. Must be 'prefix', 'sliding', or 'random'.") + + return NonErgodicGenerativeProcess( + components=built_components, + component_weights=component_weights, + vocab_maps=vocab_maps, + device=device, + ) + + +def build_inflated_process( + base_process: GenerativeProcess, + inflation_factor: int, +) -> InflatedVocabularyProcess: + """Build an inflated vocabulary process wrapping a base process. + + Args: + base_process: Any GenerativeProcess to wrap. + inflation_factor: Number of noise variants per base token (K >= 2). + + Returns: + InflatedVocabularyProcess with vocab_size = K * base_process.vocab_size. + """ + return InflatedVocabularyProcess(base_process, inflation_factor) + + +def build_inflated_process_from_spec( + base_spec: dict[str, Any], + inflation_factor: int, + device: str | None = None, +) -> InflatedVocabularyProcess: + """Build an inflated vocabulary process from a base process specification. + + Args: + base_spec: Specification for the base process. 
Must include: + - component_type: "hmm", "ghmm", or "factored" + - For hmm/ghmm: process_name, process_params + - For factored: structure_type, spec, and structure-specific params + inflation_factor: Number of noise variants per base token (K >= 2). + device: Device placement. + + Returns: + InflatedVocabularyProcess wrapping the built base process. + + Raises: + ValueError: If component_type is unknown. + """ + comp_type = base_spec.get("component_type", "hmm") + + if comp_type == "hmm": + base_process: GenerativeProcess = build_hidden_markov_model( + process_name=base_spec["process_name"], + process_params=base_spec.get("process_params", {}), + device=device, + noise_epsilon=base_spec.get("noise_epsilon", 0.0), + ) + elif comp_type == "ghmm": + base_process = build_generalized_hidden_markov_model( + process_name=base_spec["process_name"], + process_params=base_spec.get("process_params", {}), + device=device, + noise_epsilon=base_spec.get("noise_epsilon", 0.0), + ) + elif comp_type == "factored": + factored_kwargs = {k: v for k, v in base_spec.items() if k not in ("component_type",)} + base_process = build_factored_process_from_spec(**factored_kwargs) + else: + raise ValueError(f"Unknown base component_type: {comp_type}") + + return InflatedVocabularyProcess(base_process, inflation_factor) diff --git a/simplexity/generative_processes/generator.py b/simplexity/generative_processes/generator.py index 15c0ddf1..32dcd766 100644 --- a/simplexity/generative_processes/generator.py +++ b/simplexity/generative_processes/generator.py @@ -16,6 +16,7 @@ import jax.numpy as jnp from simplexity.generative_processes.generative_process import GenerativeProcess +from simplexity.generative_processes.nonergodic_generative_process import NonErgodicState @eqx.filter_jit @@ -77,16 +78,10 @@ def generate_data_batch_with_full_history( if bos_token is None: # Drop first belief state since it's the initial state before any token - if isinstance(belief_states, tuple): - belief_states = 
tuple(b[:, 1:, ...] for b in belief_states) - else: - belief_states = belief_states[:, 1:, ...] + belief_states = _slice_belief_states(belief_states, slice(1, None)) input_len = inputs.shape[1] - if isinstance(belief_states, tuple): - belief_states = tuple(b[:, :input_len, ...] for b in belief_states) - else: - belief_states = belief_states[:, :input_len, ...] + belief_states = _slice_belief_states(belief_states, slice(None, input_len)) result = { "belief_states": belief_states, @@ -98,6 +93,31 @@ def generate_data_batch_with_full_history( return result +def _slice_belief_states( + belief_states: jax.Array | tuple[jax.Array, ...] | NonErgodicState, + seq_slice: slice, +) -> jax.Array | tuple[jax.Array, ...] | NonErgodicState: + """Slice belief states along the sequence dimension (axis 1). + + Handles different state representations: + - Plain array: slice directly + - Tuple of arrays: slice each element + - NonErgodicState: slice both component_beliefs and component_states + """ + if isinstance(belief_states, NonErgodicState): + return NonErgodicState( + component_beliefs=belief_states.component_beliefs[:, seq_slice, ...], + component_states=tuple( + tuple(s[:, seq_slice, ...] for s in cs) if isinstance(cs, tuple) else cs[:, seq_slice, ...] + for cs in belief_states.component_states + ), + ) + elif isinstance(belief_states, tuple): + return tuple(b[:, seq_slice, ...] for b in belief_states) + else: + return belief_states[:, seq_slice, ...] 
+ + def _compute_prefix_probabilities( data_generator: GenerativeProcess, initial_states: jax.Array | tuple[jax.Array, ...], diff --git a/simplexity/generative_processes/independent_factored_generative_process.py b/simplexity/generative_processes/independent_factored_generative_process.py index acb83710..adeca737 100644 --- a/simplexity/generative_processes/independent_factored_generative_process.py +++ b/simplexity/generative_processes/independent_factored_generative_process.py @@ -49,6 +49,7 @@ def __init__( initial_states: Sequence[jax.Array], structure: ConditionalStructure, device: str | None = None, + noise_epsilon: float = 0.0, frozen_factor_indices: frozenset[int] = frozenset(), frozen_key: jax.Array | None = None, ) -> None: @@ -63,6 +64,7 @@ def __init__( initial_states: Initial state per factor (shape [S_i]) structure: Conditional structure defining factor interactions device: Device to place arrays on (e.g., "cpu", "gpu") + noise_epsilon: Noisy channel epsilon value frozen_factor_indices: Indices of factors whose sequences are frozen across batch frozen_key: JAX random key for frozen sequence generation. Required if frozen_factor_indices is non-empty. @@ -78,6 +80,7 @@ def __init__( initial_states=initial_states, structure=structure, device=device, + noise_epsilon=noise_epsilon, ) num_factors = len(component_types) diff --git a/simplexity/generative_processes/inflated_vocabulary_process.py b/simplexity/generative_processes/inflated_vocabulary_process.py new file mode 100644 index 00000000..abe54ade --- /dev/null +++ b/simplexity/generative_processes/inflated_vocabulary_process.py @@ -0,0 +1,103 @@ +"""Inflated vocabulary generative process wrapper. + +Wraps any GenerativeProcess by adding a uniform noise dimension to the vocabulary, +increasing vocab size by a multiplicative factor K. The noise dimension is stateless: +state dynamics depend only on the base token. 
+ +Token encoding: inflated_token = noise_prefix * V_base + base_token +- base_token extraction: inflated_token % V_base +- noise_prefix extraction: inflated_token // V_base +""" + +from __future__ import annotations + +import chex +import equinox as eqx +import jax +import jax.numpy as jnp + +from simplexity.generative_processes.generative_process import GenerativeProcess + + +class InflatedVocabularyProcess[State](GenerativeProcess[State]): + """Wraps a GenerativeProcess by adding a uniform noise dimension to inflate the vocabulary. + + For a base process with vocab size V and inflation factor K: + - New vocab size is K * V + - inflated_token = noise_prefix * V + base_token + - P(inflated_token | state) = P(base_token | state) / K + - State dynamics only depend on base_token (noise is stateless) + + This increases optimal per-token loss by exactly log(K) nats. + + Args: + base_process: The generative process to wrap. + inflation_factor: Number of noise variants per base token (K >= 2). 
+ """ + + base_process: GenerativeProcess[State] + inflation_factor: int + _base_vocab_size: int + _inflated_vocab_size: int + + def __init__( + self, + base_process: GenerativeProcess[State], + inflation_factor: int, + ) -> None: + if inflation_factor < 2: + raise ValueError(f"inflation_factor must be >= 2, got {inflation_factor}") + self.base_process = base_process + self.inflation_factor = inflation_factor + self._base_vocab_size = base_process.vocab_size + self._inflated_vocab_size = inflation_factor * base_process.vocab_size + + @property + def vocab_size(self) -> int: + """The number of inflated observations: K * base vocab size.""" + return self._inflated_vocab_size + + @property + def initial_state(self) -> State: + """The initial state, identical to the base process.""" + return self.base_process.initial_state + + @eqx.filter_jit + def emit_observation(self, state: State, key: chex.PRNGKey) -> chex.Array: + """Emit an inflated observation: sample base token then add uniform noise prefix.""" + k1, k2 = jax.random.split(key) + base_obs = self.base_process.emit_observation(state, k1) + noise_prefix = jax.random.randint(k2, (), 0, self.inflation_factor) + return noise_prefix * self._base_vocab_size + base_obs + + @eqx.filter_jit + def transition_states(self, state: State, obs: chex.Array) -> State: + """Update state using only the base token (noise prefix is discarded).""" + base_obs = jnp.mod(obs, self._base_vocab_size) + return self.base_process.transition_states(state, base_obs) + + @eqx.filter_jit + def observation_probability_distribution(self, state: State) -> jax.Array: + """Compute P(inflated_obs | state) = P(base_obs | state) / K for each noise variant.""" + base_dist = self.base_process.observation_probability_distribution(state) + return jnp.tile(base_dist / self.inflation_factor, self.inflation_factor) + + @eqx.filter_jit + def log_observation_probability_distribution(self, log_belief_state: State) -> jax.Array: + """Compute log P(inflated_obs | 
state) = log P(base_obs | state) - log(K).""" + base_log_dist = self.base_process.log_observation_probability_distribution(log_belief_state) + return jnp.tile(base_log_dist - jnp.log(self.inflation_factor), self.inflation_factor) + + @eqx.filter_jit + def probability(self, observations: jax.Array) -> jax.Array: + """Compute P(inflated_seq) = P(base_seq) * (1/K)^T.""" + base_obs = jnp.mod(observations, self._base_vocab_size) + base_prob = self.base_process.probability(base_obs) + return base_prob * (1.0 / self.inflation_factor) ** observations.shape[0] + + @eqx.filter_jit + def log_probability(self, observations: jax.Array) -> jax.Array: + """Compute log P(inflated_seq) = log P(base_seq) - T * log(K).""" + base_obs = jnp.mod(observations, self._base_vocab_size) + base_log_prob = self.base_process.log_probability(base_obs) + return base_log_prob - observations.shape[0] * jnp.log(self.inflation_factor) diff --git a/simplexity/generative_processes/nonergodic_generative_process.py b/simplexity/generative_processes/nonergodic_generative_process.py new file mode 100644 index 00000000..a45eabf7 --- /dev/null +++ b/simplexity/generative_processes/nonergodic_generative_process.py @@ -0,0 +1,526 @@ +"""Nonergodic generative process that composes multiple GenerativeProcess components.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import partial +from typing import NamedTuple + +import chex +import equinox as eqx +import jax +import jax.numpy as jnp + +from simplexity.generative_processes.generative_process import GenerativeProcess +from simplexity.utils.jnp_utils import resolve_jax_device + +ComponentState = jax.Array | tuple[jax.Array, ...] + + +class _GenerationLayout(NamedTuple): + """Static metadata needed to flatten and restore heterogeneous states.""" + + flat_sizes: tuple[int, ...] + state_templates: tuple[ComponentState, ...] 
+ max_flat_size: int + + +def _get_flat_size(state: ComponentState) -> int: + """Get total number of elements in a component state. + + Args: + state: Either a flat jax.Array or a tuple of arrays (FactoredState) + + Returns: + Total element count across all arrays in the state + """ + if isinstance(state, tuple): + return sum(arr.size for arr in state) + return state.size + + +def _flatten_state(state: ComponentState) -> jax.Array: + """Flatten a component state to a 1D array. + + Args: + state: Either a flat jax.Array or a tuple of arrays (FactoredState) + + Returns: + 1D array containing all elements from the state + """ + if isinstance(state, tuple): + return jnp.concatenate([arr.ravel() for arr in state]) + return state.ravel() + + +def _flatten_and_pad_state(state: ComponentState, max_flat_size: int) -> jax.Array: + """Flatten a state and pad it to the shared switch-compatible size.""" + flat = _flatten_state(state) + return jnp.pad(flat, (0, max_flat_size - flat.size)) + + +def _unflatten_state(flat: jax.Array, template: ComponentState) -> ComponentState: + """Restore original state structure from a flattened 1D array. + + Uses the template to determine: + - For flat arrays: the target shape + - For tuples: the number of arrays, each array's shape, and split points + + Args: + flat: 1D array containing state data + template: Original state (used only for shape/structure, not values) + + Returns: + State with same structure as template, populated with data from flat + + Note: + Uses dynamic_slice instead of split to avoid ConcretizationTypeError + inside jax.lax.switch. The template shapes are concrete (known at trace + time), so we can compute offsets as Python ints. 
+ """ + if isinstance(template, tuple): + offset = 0 + parts = [] + for t in template: + part = jax.lax.dynamic_slice(flat, (offset,), (t.size,)) + parts.append(part.reshape(t.shape)) + offset += t.size + return tuple(parts) + return flat.reshape(template.shape) + + +def _unpad_and_unflatten_state(padded: jax.Array, original_size: int, template: ComponentState) -> ComponentState: + """Remove padding and restore the component state structure.""" + return _unflatten_state(padded[:original_size], template) + + +def _keep_state(state: ComponentState, _obs: chex.Array) -> ComponentState: + """Return the existing state unchanged.""" + return state + + +class NonErgodicState(NamedTuple): + """State for nonergodic generative process. + + Attributes: + component_beliefs: P(component_i | observations_so_far), shape [num_components]. + Sums to 1. For generation, becomes one-hot after first emission. + component_states: Tuple of per-component state arrays. Each element has the + shape expected by that component's GenerativeProcess. + """ + + component_beliefs: jax.Array + component_states: tuple[ComponentState, ...] + + +class NonErgodicGenerativeProcess(GenerativeProcess[NonErgodicState]): + """A nonergodic mixture of generative processes. + + Composes multiple GenerativeProcess instances into a block diagonal structure + where no transitions occur between components. The process maintains belief + over which component generated the sequence, updated via Bayes rule. + + Key efficiency: Does NOT materialize a full block diagonal matrix. Instead, + it stores component processes directly and updates only the relevant beliefs. + + For generation: A single component is sampled at the start of each sequence + based on component_weights, and all observations come from that component. + + For inference: Beliefs are tracked across all components via Bayesian filtering. + + Attributes: + components: Tuple of component GenerativeProcess instances. 
+ component_weights: Initial mixture weights (normalized to sum to 1). + vocab_maps: Per-component mapping from local vocab to global vocab. + _vocab_size: Unified vocabulary size across all components. + _inverse_vocab_maps: Per-component mapping from global vocab to local vocab. + device: JAX device for arrays. + """ + + components: tuple[GenerativeProcess, ...] + component_weights: jax.Array + vocab_maps: tuple[jax.Array, ...] + _vocab_size: int + _inverse_vocab_maps: tuple[jax.Array, ...] + device: jax.Device # type: ignore[valid-type] + + def __init__( + self, + components: Sequence[GenerativeProcess], + component_weights: jax.Array | Sequence[float], + vocab_maps: Sequence[Sequence[int]] | None = None, + device: str | None = None, + ) -> None: + """Initialize nonergodic generative process. + + Args: + components: Sequence of GenerativeProcess instances to compose. + component_weights: Initial mixture weights. Will be normalized to sum to 1. + vocab_maps: Optional per-component vocab mappings. vocab_maps[i] maps + component i's local token indices to global token indices. + If None, assumes all components share the same vocab [0, 1, ..., V-1]. + device: Device to place arrays on (e.g., "cpu", "gpu"). + + Raises: + ValueError: If components is empty, weights don't match component count, + vocab map count doesn't match component count, or a component + vocab_map contains duplicate global token indices. 
+ """ + if len(components) == 0: + raise ValueError("Must provide at least one component") + + self.device = resolve_jax_device(device) + self.components = tuple(components) + + weights = jnp.array(component_weights) + if weights.shape[0] != len(components): + raise ValueError( + f"Number of weights ({weights.shape[0]}) must match number of components ({len(components)})" + ) + if jnp.any(weights < 0): + raise ValueError("Component weights must be non-negative") + self.component_weights = weights / jnp.sum(weights) + self.component_weights = jax.device_put(self.component_weights, self.device) + + if vocab_maps is None: + vocab_maps = [list(range(c.vocab_size)) for c in components] + elif len(vocab_maps) != len(self.components): + raise ValueError("Length of vocab maps must equal length of components.") + + for i, vm in enumerate(vocab_maps): + if len(set(vm)) != len(vm): + raise ValueError(f"vocab_maps[{i}] must not contain duplicate global token indices") + + self.vocab_maps = tuple(jax.device_put(jnp.array(vm, dtype=jnp.int32), self.device) for vm in vocab_maps) + self._vocab_size = max(max(vm) for vm in vocab_maps) + 1 + + inverse_maps = [] + for vm in vocab_maps: + inv = jnp.full((self._vocab_size,), -1, dtype=jnp.int32) + for local_idx, global_idx in enumerate(vm): + inv = inv.at[global_idx].set(local_idx) + inverse_maps.append(jax.device_put(inv, self.device)) + self._inverse_vocab_maps = tuple(inverse_maps) + + @property + def vocab_size(self) -> int: + """Unified vocabulary size across all components.""" + return self._vocab_size + + @property + def initial_state(self) -> NonErgodicState: + """Initial state with component weights and per-component initial states.""" + return NonErgodicState( + component_beliefs=self.component_weights, + component_states=tuple(c.initial_state for c in self.components), + ) + + @eqx.filter_jit + def observation_probability_distribution(self, state: NonErgodicState) -> jax.Array: + """Compute P(global_obs | state) as weighted 
sum over components. + + For each global observation token: + P(obs | state) = sum_i P(component_i | state) * P(obs | component_i, state_i) + + Where P(obs | component_i, state_i) is computed by: + 1. Getting the probability from component i's distribution + 2. Mapping to global vocab via vocab_map + 3. Returning 0 if the global obs is not in component i's vocab + """ + global_dist = jnp.zeros(self._vocab_size) + + for i, (component, vm) in enumerate(zip(self.components, self.vocab_maps, strict=True)): + comp_state = state.component_states[i] + local_dist = component.observation_probability_distribution(comp_state) + component_contrib = jnp.zeros(self._vocab_size).at[vm].add(local_dist) + global_dist += state.component_beliefs[i] * component_contrib + + return global_dist + + @eqx.filter_jit + def log_observation_probability_distribution(self, log_belief_state: NonErgodicState) -> jax.Array: + """Compute log P(global_obs | state). + + Expects log-space component_beliefs and component_states. Unmapped tokens + get -inf. Component beliefs weight via addition in log space, then combined + via logsumexp across components. + """ + log_probs = [] + + for i, (component, vm) in enumerate(zip(self.components, self.vocab_maps, strict=True)): + comp_log_state = log_belief_state.component_states[i] + comp_log_belief = log_belief_state.component_beliefs[i] + + local_log_dist = component.log_observation_probability_distribution(comp_log_state) + global_log_dist = jnp.full(self._vocab_size, -jnp.inf) + global_log_dist = global_log_dist.at[vm].set(local_log_dist) + log_probs.append(comp_log_belief + global_log_dist) + + log_probs_stacked = jnp.stack(log_probs, axis=0) + return jax.nn.logsumexp(log_probs_stacked, axis=0) + + @eqx.filter_jit + def emit_observation(self, state: NonErgodicState, key: chex.PRNGKey) -> chex.Array: + """Emit an observation by sampling from the mixture distribution. 
+ + First samples a component based on component_beliefs, then emits from + that component and maps to global vocab. + """ + key1, key2 = jax.random.split(key) + component_idx = jax.random.categorical(key1, jnp.log(state.component_beliefs)) + + def emit_from_component(i: int, k: chex.PRNGKey) -> chex.Array: + comp_state = state.component_states[i] + local_obs = self.components[i].emit_observation(comp_state, k) + return self.vocab_maps[i][local_obs] + + global_obs = jax.lax.switch( + component_idx, + [partial(emit_from_component, i) for i in range(len(self.components))], + key2, + ) + + return global_obs + + def _update_component_for_observation( + self, + component: GenerativeProcess, + inv_map: jax.Array, + comp_state: ComponentState, + obs: chex.Array, + ) -> tuple[ComponentState, jax.Array]: + """Update one component state and return its observation likelihood.""" + local_obs = inv_map[obs] + local_dist = component.observation_probability_distribution(comp_state) + likelihood = jnp.where( + local_obs >= 0, + local_dist[jnp.clip(local_obs, 0, local_dist.shape[0] - 1)], + 0.0, + ) + + def transition_component(state: ComponentState, mapped_obs: chex.Array) -> ComponentState: + return component.transition_states(state, mapped_obs) + + new_comp_state = jax.lax.cond( + likelihood > 0, + transition_component, + _keep_state, + comp_state, + local_obs, + ) + return new_comp_state, likelihood + + @eqx.filter_jit + def transition_states(self, state: NonErgodicState, obs: chex.Array) -> NonErgodicState: + """Update state given observation using Bayesian filtering. + + For each component: computes P(obs | component_i) as the likelihood + (0 if obs not in that component's vocab), conditionally updates the + component's internal state only when likelihood > 0, then applies + Bayes rule to update component_beliefs. Falls back to prior beliefs + if all likelihoods are 0. 
+ """ + new_component_states = [] + likelihoods = [] + + for i, (component, inv_map) in enumerate(zip(self.components, self._inverse_vocab_maps, strict=True)): + comp_state = state.component_states[i] + new_comp_state, likelihood = self._update_component_for_observation(component, inv_map, comp_state, obs) + likelihoods.append(likelihood) + new_component_states.append(new_comp_state) + + likelihoods_arr = jnp.array(likelihoods) + unnorm_beliefs = state.component_beliefs * likelihoods_arr + normalizer = jnp.sum(unnorm_beliefs) + new_beliefs = jnp.where( + normalizer > 0, + unnorm_beliefs / normalizer, + state.component_beliefs, + ) + + return NonErgodicState( + component_beliefs=new_beliefs, + component_states=tuple(new_component_states), + ) + + @eqx.filter_jit + def probability(self, observations: jax.Array) -> jax.Array: + """Compute P(observations) by marginalizing over components. + + P(obs_1:T) = sum_i P(component_i) * P(obs_1:T | component_i) + """ + + def compute_component_prob(i: int) -> jax.Array: + component = self.components[i] + inv_map = self._inverse_vocab_maps[i] + local_obs = inv_map[observations] + all_valid = jnp.all(local_obs >= 0) + + def compute_prob(lo: jax.Array) -> jax.Array: + return component.probability(lo) + + prob = jax.lax.cond( + all_valid, + compute_prob, + lambda lo: jnp.array(0.0), + local_obs, + ) + return self.component_weights[i] * prob + + total_prob = jnp.array(0.0) + for i in range(len(self.components)): + total_prob = total_prob + compute_component_prob(i) + + return total_prob + + @eqx.filter_jit + def log_probability(self, observations: jax.Array) -> jax.Array: + """Compute log P(observations) using logsumexp for numerical stability.""" + + def compute_component_log_prob(i: int) -> jax.Array: + component = self.components[i] + inv_map = self._inverse_vocab_maps[i] + local_obs = inv_map[observations] + all_valid = jnp.all(local_obs >= 0) + + def compute_log_prob(lo: jax.Array) -> jax.Array: + return 
component.log_probability(lo) + + log_prob = jax.lax.cond( + all_valid, + compute_log_prob, + lambda lo: jnp.array(-jnp.inf), + local_obs, + ) + return jnp.log(self.component_weights[i]) + log_prob + + log_probs = jnp.array([compute_component_log_prob(i) for i in range(len(self.components))]) + return jax.nn.logsumexp(log_probs) + + def _generate_component_step( + self, + i: int, + padded_state: jax.Array, + step_key: chex.PRNGKey, + layout: _GenerationLayout, + ) -> tuple[jax.Array, chex.Array]: + """Advance one selected component by a single generation step.""" + real_state = _unpad_and_unflatten_state(padded_state, layout.flat_sizes[i], layout.state_templates[i]) + local_obs = self.components[i].emit_observation(real_state, step_key) + new_real_state = self.components[i].transition_states(real_state, local_obs) + new_padded_state = _flatten_and_pad_state(new_real_state, layout.max_flat_size) + global_obs = self.vocab_maps[i][local_obs] + return new_padded_state, global_obs + + def _scan_component_generation( + self, + component_idx: jax.Array, + padded_states: tuple[jax.Array, ...], + keys: jax.Array, + layout: _GenerationLayout, + ) -> tuple[tuple[jax.Array, ...], chex.Array]: + """Generate observations while updating only the sampled component state.""" + num_components = len(self.components) + + def scan_step( + carry: tuple[jax.Array, tuple[jax.Array, ...]], step_key: chex.PRNGKey + ) -> tuple[tuple[jax.Array, tuple[jax.Array, ...]], chex.Array]: + idx, padded_comp_states = carry + + new_padded_state, global_obs = jax.lax.switch( + idx, + [ + partial( + self._generate_component_step, + i, + padded_comp_states[i], + step_key, + layout, + ) + for i in range(num_components) + ], + ) + + new_padded_comp_states = tuple( + jax.lax.select(idx == i, new_padded_state, padded_comp_states[i]) for i in range(num_components) + ) + + return (idx, new_padded_comp_states), global_obs + + init_carry = (component_idx, padded_states) + (_, final_padded_states), observations = 
jax.lax.scan(scan_step, init_carry, keys) + return final_padded_states, observations + + def _generate_state_trajectory( + self, state: NonErgodicState, observations: chex.Array + ) -> tuple[NonErgodicState, chex.Array]: + """Reconstruct the per-token belief trajectory from generated observations.""" + + def inference_step(carry_state: NonErgodicState, obs: chex.Array) -> tuple[NonErgodicState, NonErgodicState]: + new_state = self.transition_states(carry_state, obs) + return new_state, carry_state + + _, state_trajectory = jax.lax.scan(inference_step, state, observations) + return state_trajectory, observations + + @eqx.filter_vmap(in_axes=(None, 0, 0, None, None)) + def generate( + self, + state: NonErgodicState, + key: chex.PRNGKey, + sequence_len: int, + return_all_states: bool, + ) -> tuple[NonErgodicState, chex.Array]: + """Generate a sequence from a single sampled component. + + Unlike inference (which tracks beliefs across all components), generation + samples ONE component at the start and generates entirely from that component. + + This method is vmapped, so inside the function body we work with unbatched + (single-element) states and keys. We cannot call component.generate() here + because that method is also vmapped and expects batched inputs. Instead, + we implement generation directly using jax.lax.scan over emit_observation + and transition_states. + + Because jax.lax.switch requires all branches to return the same shape, and + components may have different state types (HMM: flat array vs Factored: tuple + of arrays), we flatten each state to 1D, pad to a common max size for switch + compatibility, and unflatten back to native structures after processing. + + Args: + state: Initial NonErgodicState with component_beliefs and component_states. + The batch dimension is handled by vmap. + key: Random key for this sequence. + sequence_len: Length of sequence to generate. + return_all_states: If True, return state trajectory at each timestep. 
+ + Returns: + Tuple of (final_state or state_trajectory, observations). + States are NonErgodicState. Observations are in global vocab space. + """ + key1, key2 = jax.random.split(key) + keys = jax.random.split(key2, sequence_len) + + component_idx = jax.random.categorical(key1, jnp.log(state.component_beliefs)) + layout = _GenerationLayout( + flat_sizes=tuple(_get_flat_size(s) for s in state.component_states), + state_templates=state.component_states, + max_flat_size=max(_get_flat_size(s) for s in state.component_states), + ) + padded_states = tuple(_flatten_and_pad_state(s, layout.max_flat_size) for s in state.component_states) + final_padded_states, observations = self._scan_component_generation(component_idx, padded_states, keys, layout) + + final_comp_states = tuple( + _unpad_and_unflatten_state(final_padded_states[i], layout.flat_sizes[i], layout.state_templates[i]) + for i in range(len(self.components)) + ) + + one_hot_beliefs = jax.nn.one_hot(component_idx, len(self.components), dtype=self.component_weights.dtype) + + if return_all_states: + return self._generate_state_trajectory(state, observations) + + return NonErgodicState( + component_beliefs=one_hot_beliefs, + component_states=final_comp_states, + ), observations diff --git a/tests/end_to_end/configs/generative_process/nonergodic_example.yaml b/tests/end_to_end/configs/generative_process/nonergodic_example.yaml new file mode 100644 index 00000000..199656d9 --- /dev/null +++ b/tests/end_to_end/configs/generative_process/nonergodic_example.yaml @@ -0,0 +1,39 @@ +# Nonergodic Generative Process Example +# A mixture of independent generative processes with block diagonal structure. +# No transitions occur between components - the process "picks" a component +# at the start and stays with it forever. + +name: nonergodic_example +base_vocab_size: ??? +vocab_size: ??? 
+ +instance: + _target_: simplexity.generative_processes.builder.build_nonergodic_process_from_spec + + components: + # Component 0: mess3 HMM with specific parameters + - component_type: hmm + process_name: mess3 + process_params: + x: 0.15 + a: 0.6 + + # Component 1: mess3 HMM with different parameters + - component_type: hmm + process_name: mess3 + process_params: + x: 0.5 + a: 0.6 + + # Initial mixture weights (will be normalized) + # 60% chance of starting in component 0, 40% in component 1 + component_weights: [0.6, 0.4] + +bos_token: ??? +eos_token: null + +# Interpretation: +# - With probability 0.6, sequences come from mess3(x=0.15, a=0.6) +# - With probability 0.4, sequences come from mess3(x=0.5, a=0.6) +# - Once a component is "chosen" by the initial state, all future observations +# come from that component (truly nonergodic - no mixing/switching) diff --git a/tests/generative_processes/test_builder.py b/tests/generative_processes/test_builder.py index bd80643b..03b4386e 100644 --- a/tests/generative_processes/test_builder.py +++ b/tests/generative_processes/test_builder.py @@ -23,8 +23,10 @@ build_generalized_hidden_markov_model, build_hidden_markov_model, build_matrices_from_spec, + build_nonergodic_disjoint_vocab, build_nonergodic_hidden_markov_model, build_nonergodic_initial_state, + build_nonergodic_partial_overlap, build_nonergodic_transition_matrices, build_symmetric_from_spec, build_transition_coupled_from_spec, @@ -603,3 +605,225 @@ def test_build_chain_from_spec_empty_chain_raises(): """Empty chain should raise ValueError.""" with pytest.raises(ValueError, match="chain must contain at least one node"): build_chain_from_spec([]) + + +# --- Tests for IndependentFactoredGenerativeProcess in build_factored_process --- + + +def test_build_factored_process_independent_returns_independent_subclass(components_spec): + """build_factored_process with independent structure should return IndependentFactoredGenerativeProcess.""" + from 
simplexity.generative_processes.independent_factored_generative_process import ( + IndependentFactoredGenerativeProcess, + ) + + component_types, transition_matrices, normalizing_eigenvectors, initial_states = build_matrices_from_spec( + components_spec + ) + process = build_factored_process( + structure_type="independent", + component_types=component_types, + transition_matrices=transition_matrices, + normalizing_eigenvectors=normalizing_eigenvectors, + initial_states=initial_states, + ) + assert isinstance(process, IndependentFactoredGenerativeProcess) + assert isinstance(process.structure, IndependentStructure) + + +def test_build_factored_process_independent_passes_noise_epsilon(components_spec): + """noise_epsilon should be propagated to IndependentFactoredGenerativeProcess.""" + component_types, transition_matrices, normalizing_eigenvectors, initial_states = build_matrices_from_spec( + components_spec + ) + process = build_factored_process( + structure_type="independent", + component_types=component_types, + transition_matrices=transition_matrices, + normalizing_eigenvectors=normalizing_eigenvectors, + initial_states=initial_states, + noise_epsilon=0.05, + ) + assert process.noise_epsilon == 0.05 + + +def test_build_factored_process_from_spec_independent_returns_independent_subclass(components_spec): + """build_factored_process_from_spec with independent structure returns IndependentFactoredGenerativeProcess.""" + from simplexity.generative_processes.independent_factored_generative_process import ( + IndependentFactoredGenerativeProcess, + ) + + process = build_factored_process_from_spec(structure_type="independent", spec=components_spec) + assert isinstance(process, IndependentFactoredGenerativeProcess) + + +# --- Tests for build_nonergodic_disjoint_vocab --- + + +TWO_COINS = [ + {"component_type": "hmm", "process_name": "coin", "process_params": {"p": 0.6}}, + {"component_type": "hmm", "process_name": "coin", "process_params": {"p": 0.4}}, +] + + +class 
TestBuildNonErgodicDisjointVocab: + """Tests for build_nonergodic_disjoint_vocab.""" + + def test_vocab_maps_are_non_overlapping(self): + """Each component should get a unique, non-overlapping vocab range.""" + process = build_nonergodic_disjoint_vocab(components=TWO_COINS, component_weights=[0.5, 0.5]) + vm0 = set(process.vocab_maps[0].tolist()) + vm1 = set(process.vocab_maps[1].tolist()) + assert vm0 == {0, 1} + assert vm1 == {2, 3} + assert vm0.isdisjoint(vm1) + + def test_vocab_size_is_sum_of_components(self): + """Total vocab size should be sum of all component vocab sizes.""" + process = build_nonergodic_disjoint_vocab(components=TWO_COINS, component_weights=[0.5, 0.5]) + assert process.vocab_size == 4 + + def test_distribution_sums_to_one(self): + """Observation distribution should be a valid probability distribution.""" + process = build_nonergodic_disjoint_vocab(components=TWO_COINS, component_weights=[0.5, 0.5]) + dist = process.observation_probability_distribution(process.initial_state) + chex.assert_trees_all_close(jnp.sum(dist), 1.0, atol=1e-6) + + def test_three_components_disjoint(self): + """Three-component disjoint should produce three non-overlapping ranges.""" + components = [ + {"component_type": "hmm", "process_name": "coin", "process_params": {"p": 0.5}}, + {"component_type": "hmm", "process_name": "mess3", "process_params": {"x": 0.15, "a": 0.6}}, + {"component_type": "hmm", "process_name": "coin", "process_params": {"p": 0.3}}, + ] + process = build_nonergodic_disjoint_vocab(components=components, component_weights=[0.4, 0.3, 0.3]) + assert process.vocab_size == 7 # 2 + 3 + 2 + all_tokens = set() + for vm in process.vocab_maps: + tokens = set(vm.tolist()) + assert all_tokens.isdisjoint(tokens) + all_tokens.update(tokens) + + +# --- Tests for build_nonergodic_partial_overlap --- + + +class TestBuildNonErgodicPartialOverlap: + """Tests for build_nonergodic_partial_overlap.""" + + def test_prefix_mode_shared_and_unique_tokens(self): + """Prefix 
mode: components should share some tokens and have unique tokens.""" + process = build_nonergodic_partial_overlap( + components=TWO_COINS, component_weights=[0.5, 0.5], overlap_frac=0.5, mode="prefix" + ) + vm0 = process.vocab_maps[0].tolist() + vm1 = process.vocab_maps[1].tolist() + assert vm0 == [0, 1] + assert vm1 == [0, 2] + + def test_prefix_mode_full_overlap(self): + """overlap_frac=1.0 should give fully shared vocabularies.""" + process = build_nonergodic_partial_overlap( + components=TWO_COINS, component_weights=[0.5, 0.5], overlap_frac=1.0, mode="prefix" + ) + vm0 = process.vocab_maps[0].tolist() + vm1 = process.vocab_maps[1].tolist() + assert vm0 == vm1 + + def test_prefix_mode_zero_overlap_is_disjoint(self): + """overlap_frac=0.0 with prefix mode should produce disjoint vocabs.""" + process = build_nonergodic_partial_overlap( + components=TWO_COINS, component_weights=[0.5, 0.5], overlap_frac=0.0, mode="prefix" + ) + vm0 = set(process.vocab_maps[0].tolist()) + vm1 = set(process.vocab_maps[1].tolist()) + assert vm0.isdisjoint(vm1) + + def test_sliding_mode_produces_offset_maps(self): + """Sliding mode should produce overlapping ranges offset by V * (1 - overlap_frac).""" + components = [ + {"component_type": "hmm", "process_name": "mess3", "process_params": {"x": 0.15, "a": 0.6}}, + {"component_type": "hmm", "process_name": "mess3", "process_params": {"x": 0.35, "a": 0.6}}, + {"component_type": "hmm", "process_name": "mess3", "process_params": {"x": 0.5, "a": 0.6}}, + ] + # V=3, overlap_frac=2/3 => n_unique=1, offset=1 + process = build_nonergodic_partial_overlap( + components=components, component_weights=[0.333, 0.333, 0.334], overlap_frac=2.0 / 3.0, mode="sliding" + ) + assert process.vocab_maps[0].tolist() == [0, 1, 2] + assert process.vocab_maps[1].tolist() == [1, 2, 3] + assert process.vocab_maps[2].tolist() == [2, 3, 4] + + def test_sliding_mode_full_overlap(self): + """Sliding with overlap_frac=1.0 should still offset by at least 1.""" + process = 
build_nonergodic_partial_overlap( + components=TWO_COINS, component_weights=[0.5, 0.5], overlap_frac=1.0, mode="sliding" + ) + vm0 = process.vocab_maps[0].tolist() + vm1 = process.vocab_maps[1].tolist() + # offset = max(1, 0) = 1, so C0=[0,1], C1=[1,2] + assert vm0 == [0, 1] + assert vm1 == [1, 2] + + def test_random_mode_independent_sampling(self): + """Random mode should independently sample V tokens per component from the global pool.""" + process = build_nonergodic_partial_overlap( + components=TWO_COINS, component_weights=[0.5, 0.5], overlap_frac=0.5, mode="random", seed=42 + ) + v = process.components[0].vocab_size + for vm in process.vocab_maps: + assert len(vm.tolist()) == v + assert len(set(vm.tolist())) == v # no duplicates within a component + + def test_random_mode_is_deterministic_with_seed(self): + """Same seed should produce identical vocab maps.""" + p1 = build_nonergodic_partial_overlap( + components=TWO_COINS, component_weights=[0.5, 0.5], overlap_frac=0.5, mode="random", seed=123 + ) + p2 = build_nonergodic_partial_overlap( + components=TWO_COINS, component_weights=[0.5, 0.5], overlap_frac=0.5, mode="random", seed=123 + ) + for vm1, vm2 in zip(p1.vocab_maps, p2.vocab_maps, strict=True): + assert vm1.tolist() == vm2.tolist() + + def test_random_mode_different_seeds_differ(self): + """Different seeds should produce different vocab maps.""" + three_mess3 = [ + {"component_type": "hmm", "process_name": "mess3", "process_params": {"x": 0.15, "a": 0.6}}, + {"component_type": "hmm", "process_name": "mess3", "process_params": {"x": 0.35, "a": 0.6}}, + {"component_type": "hmm", "process_name": "mess3", "process_params": {"x": 0.5, "a": 0.6}}, + ] + weights = [0.333, 0.333, 0.334] + p1 = build_nonergodic_partial_overlap( + components=three_mess3, component_weights=weights, overlap_frac=0.5, mode="random", seed=1 + ) + p2 = build_nonergodic_partial_overlap( + components=three_mess3, component_weights=weights, overlap_frac=0.5, mode="random", seed=2 + ) + 
maps_differ = any(vm1.tolist() != vm2.tolist() for vm1, vm2 in zip(p1.vocab_maps, p2.vocab_maps, strict=True)) + assert maps_differ + + def test_random_mode_requires_seed(self): + """Random mode without seed should raise ValueError.""" + with pytest.raises(ValueError, match="seed is required"): + build_nonergodic_partial_overlap( + components=TWO_COINS, component_weights=[0.5, 0.5], overlap_frac=0.5, mode="random" + ) + + def test_distribution_sums_to_one_all_modes(self): + """Observation distribution should be valid for all modes.""" + for mode, kwargs in [("prefix", {}), ("sliding", {}), ("random", {"seed": 42})]: + process = build_nonergodic_partial_overlap( + components=TWO_COINS, component_weights=[0.5, 0.5], overlap_frac=0.5, mode=mode, **kwargs + ) + dist = process.observation_probability_distribution(process.initial_state) + chex.assert_trees_all_close(jnp.sum(dist), 1.0, atol=1e-6) + + def test_unknown_mode_raises(self): + """Unknown mode should raise ValueError.""" + with pytest.raises(ValueError, match="Unknown mode"): + build_nonergodic_partial_overlap( + components=TWO_COINS, + component_weights=[0.5, 0.5], + overlap_frac=0.5, + mode="bogus", # type: ignore[arg-type] + ) diff --git a/tests/generative_processes/test_inflated_vocabulary_process.py b/tests/generative_processes/test_inflated_vocabulary_process.py new file mode 100644 index 00000000..793aba79 --- /dev/null +++ b/tests/generative_processes/test_inflated_vocabulary_process.py @@ -0,0 +1,349 @@ +"""Tests for InflatedVocabularyProcess.""" + +# pylint: disable-all +# Temporarily disable all pylint checkers during AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all +# Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) 
to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. + +import chex +import jax +import jax.numpy as jnp +import pytest + +from simplexity.generative_processes.builder import ( + build_generalized_hidden_markov_model, + build_hidden_markov_model, + build_inflated_process, + build_inflated_process_from_spec, +) +from simplexity.generative_processes.inflated_vocabulary_process import InflatedVocabularyProcess +from simplexity.generative_processes.nonergodic_generative_process import ( + NonErgodicGenerativeProcess, + NonErgodicState, +) + + +class TestBasicProperties: + """Tests for basic properties of InflatedVocabularyProcess.""" + + @pytest.fixture + def coin_k3(self): + """Coin process with K=3 inflation.""" + coin = build_hidden_markov_model("coin", {"p": 0.7}) + return InflatedVocabularyProcess(coin, inflation_factor=3) + + @pytest.fixture + def mess3_k3(self): + """Mess3 process with K=3 inflation.""" + mess3 = build_hidden_markov_model("mess3", {"x": 0.15, "a": 0.6}) + return InflatedVocabularyProcess(mess3, inflation_factor=3) + + def test_vocab_size_coin(self, coin_k3: InflatedVocabularyProcess): + assert coin_k3.vocab_size == 6 + + def test_vocab_size_mess3(self, mess3_k3: InflatedVocabularyProcess): + assert mess3_k3.vocab_size == 9 + + def test_initial_state_matches_base(self, coin_k3: InflatedVocabularyProcess): + chex.assert_trees_all_close(coin_k3.initial_state, coin_k3.base_process.initial_state) + + def test_inflation_factor_stored(self, coin_k3: InflatedVocabularyProcess): + assert coin_k3.inflation_factor == 3 + + def test_invalid_inflation_factor_raises(self): + coin = build_hidden_markov_model("coin", {"p": 0.7}) + with pytest.raises(ValueError, match="inflation_factor must be >= 2"): + InflatedVocabularyProcess(coin, inflation_factor=1) + + def test_invalid_inflation_factor_zero_raises(self): + coin = build_hidden_markov_model("coin", {"p": 0.7}) + with pytest.raises(ValueError, 
match="inflation_factor must be >= 2"): + InflatedVocabularyProcess(coin, inflation_factor=0) + + +class TestObservationDistribution: + """Tests for observation probability distribution.""" + + @pytest.fixture + def coin_k3(self): + coin = build_hidden_markov_model("coin", {"p": 0.7}) + return InflatedVocabularyProcess(coin, inflation_factor=3) + + def test_distribution_sums_to_one(self, coin_k3: InflatedVocabularyProcess): + state = coin_k3.initial_state + dist = coin_k3.observation_probability_distribution(state) + chex.assert_trees_all_close(jnp.sum(dist), 1.0, atol=1e-6) + + def test_distribution_has_correct_size(self, coin_k3: InflatedVocabularyProcess): + state = coin_k3.initial_state + dist = coin_k3.observation_probability_distribution(state) + assert dist.shape == (6,) + + def test_distribution_spreads_uniformly(self, coin_k3: InflatedVocabularyProcess): + """Each base token's prob is split equally among K noise variants.""" + state = coin_k3.initial_state + dist = coin_k3.observation_probability_distribution(state) + expected = jnp.array([0.7 / 3, 0.3 / 3, 0.7 / 3, 0.3 / 3, 0.7 / 3, 0.3 / 3]) + chex.assert_trees_all_close(dist, expected, atol=1e-6) + + def test_noise_variants_have_equal_probability(self): + """All K noise variants of the same base token should have identical probability.""" + mess3 = build_hidden_markov_model("mess3", {"x": 0.15, "a": 0.6}) + inflated = InflatedVocabularyProcess(mess3, inflation_factor=4) + state = inflated.initial_state + dist = inflated.observation_probability_distribution(state) + v_base = mess3.vocab_size + for base_tok in range(v_base): + probs = [float(dist[n * v_base + base_tok]) for n in range(4)] + for p in probs[1:]: + chex.assert_trees_all_close(p, probs[0], atol=1e-6) + + def test_log_distribution_consistent(self, coin_k3: InflatedVocabularyProcess): + state = coin_k3.initial_state + log_state = jnp.log(state) + dist = coin_k3.observation_probability_distribution(state) + log_dist = 
coin_k3.log_observation_probability_distribution(log_state) + chex.assert_trees_all_close(log_dist, jnp.log(dist), atol=1e-5) + + +class TestTransitionStates: + """Tests for state transitions.""" + + def test_noise_prefix_does_not_affect_state(self): + """All K noise variants of the same base token should produce identical states.""" + even_ones = build_hidden_markov_model("even_ones", {"p": 0.5}) + inflated = InflatedVocabularyProcess(even_ones, inflation_factor=3) + state = inflated.initial_state + v_base = even_ones.vocab_size + + for base_tok in range(v_base): + states = [inflated.transition_states(state, jnp.array(n * v_base + base_tok)) for n in range(3)] + for s in states[1:]: + chex.assert_trees_all_close(s, states[0], atol=1e-6) + + def test_transition_matches_base_process(self): + """Transitioning with an inflated token should match transitioning with the base token.""" + mess3 = build_hidden_markov_model("mess3", {"x": 0.15, "a": 0.6}) + inflated = InflatedVocabularyProcess(mess3, inflation_factor=3) + state = mess3.initial_state + + for base_tok in range(mess3.vocab_size): + base_state = mess3.transition_states(state, jnp.array(base_tok)) + inflated_state = inflated.transition_states(state, jnp.array(base_tok)) + chex.assert_trees_all_close(inflated_state, base_state, atol=1e-6) + + inflated_state_noisy = inflated.transition_states(state, jnp.array(2 * mess3.vocab_size + base_tok)) + chex.assert_trees_all_close(inflated_state_noisy, base_state, atol=1e-6) + + +class TestProbability: + """Tests for sequence probability computation.""" + + def test_probability_scales_by_inflation_penalty(self): + """P(inflated_seq) = P(base_seq) / K^T.""" + coin = build_hidden_markov_model("coin", {"p": 0.7}) + k = 3 + inflated = InflatedVocabularyProcess(coin, inflation_factor=k) + + base_seq = jnp.array([0, 1, 0]) + base_prob = coin.probability(base_seq) + inflated_prob = inflated.probability(base_seq) + + expected = base_prob / (k**3) + 
chex.assert_trees_all_close(inflated_prob, expected, atol=1e-6) + + def test_probability_same_base_different_noise(self): + """Different noise prefixes with same base sequence should have same probability.""" + coin = build_hidden_markov_model("coin", {"p": 0.7}) + inflated = InflatedVocabularyProcess(coin, inflation_factor=3) + + seq_noise0 = jnp.array([0, 1, 0]) + seq_noise1 = jnp.array([2, 3, 2]) + seq_noise2 = jnp.array([4, 5, 4]) + + p0 = inflated.probability(seq_noise0) + p1 = inflated.probability(seq_noise1) + p2 = inflated.probability(seq_noise2) + + chex.assert_trees_all_close(p0, p1, atol=1e-6) + chex.assert_trees_all_close(p1, p2, atol=1e-6) + + def test_log_probability_consistent(self): + coin = build_hidden_markov_model("coin", {"p": 0.7}) + inflated = InflatedVocabularyProcess(coin, inflation_factor=3) + + seq = jnp.array([0, 3, 1, 4]) + prob = inflated.probability(seq) + log_prob = inflated.log_probability(seq) + chex.assert_trees_all_close(log_prob, jnp.log(prob), atol=1e-5) + + def test_optimal_loss_increases_by_log_k(self): + """Average per-token loss should increase by exactly log(K).""" + mess3 = build_hidden_markov_model("mess3", {"x": 0.15, "a": 0.6}) + k = 4 + inflated = InflatedVocabularyProcess(mess3, inflation_factor=k) + + key = jax.random.PRNGKey(42) + state = mess3.initial_state + batch_state = jnp.broadcast_to(state, (100,) + state.shape) + keys = jax.random.split(key, 100) + + _, base_seqs = mess3.generate(batch_state, keys, 200, False) + base_log_probs = jax.vmap(mess3.log_probability)(base_seqs) + base_avg_loss = -jnp.mean(base_log_probs) / 200 + + inflated_log_probs = jax.vmap(inflated.log_probability)(base_seqs) + inflated_avg_loss = -jnp.mean(inflated_log_probs) / 200 + + expected_increase = jnp.log(jnp.array(k, dtype=jnp.float32)) + chex.assert_trees_all_close(inflated_avg_loss - base_avg_loss, expected_increase, atol=0.01) + + +class TestGeneration: + """Tests for sequence generation.""" + + def 
test_generate_valid_tokens(self): + coin = build_hidden_markov_model("coin", {"p": 0.7}) + inflated = InflatedVocabularyProcess(coin, inflation_factor=3) + + state = inflated.initial_state + batch_state = jnp.broadcast_to(state, (4,) + state.shape) + keys = jax.random.split(jax.random.PRNGKey(0), 4) + _, observations = inflated.generate(batch_state, keys, 20, False) + + assert observations.shape == (4, 20) + assert jnp.all(observations >= 0) + assert jnp.all(observations < inflated.vocab_size) + + def test_generate_covers_noise_variants(self): + """Generated tokens should use all noise variants over many samples.""" + coin = build_hidden_markov_model("coin", {"p": 0.5}) + inflated = InflatedVocabularyProcess(coin, inflation_factor=3) + + state = inflated.initial_state + batch_state = jnp.broadcast_to(state, (50,) + state.shape) + keys = jax.random.split(jax.random.PRNGKey(123), 50) + _, observations = inflated.generate(batch_state, keys, 100, False) + + unique_tokens = jnp.unique(observations.ravel()) + assert unique_tokens.shape[0] == 6 + + def test_generate_with_return_all_states(self): + """Generation with return_all_states=True should return state trajectory.""" + mess3 = build_hidden_markov_model("mess3", {"x": 0.15, "a": 0.6}) + inflated = InflatedVocabularyProcess(mess3, inflation_factor=2) + + state = inflated.initial_state + batch_state = jnp.broadcast_to(state, (4,) + state.shape) + keys = jax.random.split(jax.random.PRNGKey(0), 4) + states, observations = inflated.generate(batch_state, keys, 10, True) + + assert observations.shape == (4, 10) + assert states.shape == (4, 10) + state.shape + + def test_base_token_distribution_matches(self): + """Extracting base tokens from inflated generation should match base distribution.""" + coin = build_hidden_markov_model("coin", {"p": 0.8}) + inflated = InflatedVocabularyProcess(coin, inflation_factor=5) + + state = inflated.initial_state + batch_state = jnp.broadcast_to(state, (200,) + state.shape) + keys = 
jax.random.split(jax.random.PRNGKey(42), 200) + _, observations = inflated.generate(batch_state, keys, 500, False) + + base_tokens = observations % coin.vocab_size + base_freq = jnp.mean(base_tokens == 0) + chex.assert_trees_all_close(base_freq, 0.8, atol=0.03) + + +class TestWithDifferentBaseProcesses: + """Tests for wrapping different process types.""" + + def test_wrap_ghmm(self): + ghmm = build_generalized_hidden_markov_model("tom_quantum", {"alpha": 1.0, "beta": 1.0}) + inflated = InflatedVocabularyProcess(ghmm, inflation_factor=2) + assert inflated.vocab_size == 2 * ghmm.vocab_size + + state = inflated.initial_state + dist = inflated.observation_probability_distribution(state) + chex.assert_trees_all_close(jnp.sum(dist), 1.0, atol=1e-6) + + batch_state = jnp.broadcast_to(state, (4,) + state.shape) + keys = jax.random.split(jax.random.PRNGKey(0), 4) + _, observations = inflated.generate(batch_state, keys, 10, False) + assert jnp.all(observations >= 0) + assert jnp.all(observations < inflated.vocab_size) + + def test_wrap_nonergodic(self): + coin1 = build_hidden_markov_model("coin", {"p": 0.7}) + coin2 = build_hidden_markov_model("coin", {"p": 0.3}) + nonergodic = NonErgodicGenerativeProcess( + components=[coin1, coin2], + component_weights=[0.5, 0.5], + ) + inflated = InflatedVocabularyProcess(nonergodic, inflation_factor=3) + assert inflated.vocab_size == 6 + + state = inflated.initial_state + assert isinstance(state, NonErgodicState) + dist = inflated.observation_probability_distribution(state) + chex.assert_trees_all_close(jnp.sum(dist), 1.0, atol=1e-6) + + def test_double_inflation(self): + """Stacking inflation: K1 * K2 total inflation.""" + coin = build_hidden_markov_model("coin", {"p": 0.5}) + inflated1 = InflatedVocabularyProcess(coin, inflation_factor=2) + inflated2 = InflatedVocabularyProcess(inflated1, inflation_factor=3) + assert inflated2.vocab_size == 12 + + state = inflated2.initial_state + dist = 
inflated2.observation_probability_distribution(state) + chex.assert_trees_all_close(jnp.sum(dist), 1.0, atol=1e-6) + chex.assert_trees_all_close(dist, jnp.ones(12) / 12, atol=1e-6) + + +class TestBuilder: + """Tests for builder functions.""" + + def test_build_inflated_process(self): + coin = build_hidden_markov_model("coin", {"p": 0.7}) + inflated = build_inflated_process(coin, inflation_factor=3) + assert isinstance(inflated, InflatedVocabularyProcess) + assert inflated.vocab_size == 6 + + def test_build_inflated_process_from_spec_hmm(self): + inflated = build_inflated_process_from_spec( + base_spec={ + "component_type": "hmm", + "process_name": "mess3", + "process_params": {"x": 0.15, "a": 0.6}, + }, + inflation_factor=3, + ) + assert isinstance(inflated, InflatedVocabularyProcess) + assert inflated.vocab_size == 9 + + def test_build_inflated_process_from_spec_ghmm(self): + inflated = build_inflated_process_from_spec( + base_spec={ + "component_type": "ghmm", + "process_name": "tom_quantum", + "process_params": {"alpha": 1.0, "beta": 1.0}, + }, + inflation_factor=2, + ) + assert isinstance(inflated, InflatedVocabularyProcess) + state = inflated.initial_state + dist = inflated.observation_probability_distribution(state) + chex.assert_trees_all_close(jnp.sum(dist), 1.0, atol=1e-6) + + def test_build_inflated_process_from_spec_unknown_type_raises(self): + with pytest.raises(ValueError, match="Unknown base component_type"): + build_inflated_process_from_spec( + base_spec={"component_type": "unknown", "process_name": "coin"}, + inflation_factor=2, + ) diff --git a/tests/generative_processes/test_nonergodic_generative_process.py b/tests/generative_processes/test_nonergodic_generative_process.py new file mode 100644 index 00000000..7b72edfa --- /dev/null +++ b/tests/generative_processes/test_nonergodic_generative_process.py @@ -0,0 +1,549 @@ +"""Tests for NonErgodicGenerativeProcess.""" + +# pylint: disable-all +# Temporarily disable all pylint checkers during AST 
traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all +# Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. + +import chex +import jax +import jax.numpy as jnp +import pytest + +from simplexity.generative_processes.builder import ( + build_factored_process_from_spec, + build_generalized_hidden_markov_model, + build_hidden_markov_model, + build_nonergodic_process_from_spec, +) +from simplexity.generative_processes.generator import generate_data_batch_with_full_history +from simplexity.generative_processes.nonergodic_generative_process import ( + ComponentState, + NonErgodicGenerativeProcess, + NonErgodicState, +) + + +def _expand_component_state( + state: ComponentState, + batch_size: int, +) -> ComponentState: + """Expand a single component state to a batch of identical states.""" + if isinstance(state, tuple): + return tuple(jnp.repeat(s[None, :], batch_size, axis=0) for s in state) + return jnp.repeat(state[None, :], batch_size, axis=0) + + +def _expand_state(state: NonErgodicState, batch_size: int) -> NonErgodicState: + """Expand a single NonErgodicState to a batch of identical states.""" + return NonErgodicState( + component_beliefs=jnp.repeat(state.component_beliefs[None, :], batch_size, axis=0), + component_states=tuple(_expand_component_state(cs, batch_size) for cs in state.component_states), + ) + + +class TestNonErgodicState: + """Tests for NonErgodicState structure.""" + + def test_state_is_named_tuple(self): + """NonErgodicState should be a NamedTuple with named fields.""" + state = NonErgodicState( + component_beliefs=jnp.array([0.5, 0.5]), + component_states=(jnp.array([1.0, 0.0]), jnp.array([0.5, 0.5])), + ) + assert 
hasattr(state, "component_beliefs") + assert hasattr(state, "component_states") + assert isinstance(state, tuple) + + def test_state_is_pytree_compatible(self): + """NonErgodicState should be compatible with JAX pytree operations.""" + state = NonErgodicState( + component_beliefs=jnp.array([0.5, 0.5]), + component_states=(jnp.array([1.0, 0.0]), jnp.array([0.5, 0.5])), + ) + # Should work with tree_map + doubled = jax.tree_util.tree_map(lambda x: x * 2, state) + chex.assert_trees_all_close(doubled.component_beliefs, jnp.array([1.0, 1.0])) + + +class TestNonErgodicGenerativeProcess: + """Tests for NonErgodicGenerativeProcess class.""" + + @pytest.fixture + def two_coin_process(self): + """Two biased coins as a nonergodic mixture.""" + coin1 = build_hidden_markov_model("coin", {"p": 0.7}) + coin2 = build_hidden_markov_model("coin", {"p": 0.3}) + return NonErgodicGenerativeProcess( + components=[coin1, coin2], + component_weights=[0.6, 0.4], + ) + + def test_vocab_size_inferred_correctly(self, two_coin_process): + """Vocab size should be max of component vocab sizes.""" + assert two_coin_process.vocab_size == 2 + + def test_initial_state_has_correct_structure(self, two_coin_process): + """Initial state should have component beliefs and per-component states.""" + state = two_coin_process.initial_state + assert isinstance(state, NonErgodicState) + chex.assert_trees_all_close(state.component_beliefs, jnp.array([0.6, 0.4])) + assert len(state.component_states) == 2 + + def test_observation_distribution_is_mixture(self, two_coin_process): + """Observation dist should be weighted mixture of component dists.""" + state = two_coin_process.initial_state + dist = two_coin_process.observation_probability_distribution(state) + + # Expected: 0.6 * [0.7, 0.3] + 0.4 * [0.3, 0.7] = [0.54, 0.46] + expected = jnp.array([0.54, 0.46]) + chex.assert_trees_all_close(dist, expected, atol=1e-6) + chex.assert_trees_all_close(jnp.sum(dist), 1.0, atol=1e-6) + + def 
test_transition_updates_beliefs_correctly(self, two_coin_process): + """Observing a token should update component beliefs via Bayes rule.""" + state = two_coin_process.initial_state + + # Observe token 0 (heads) + new_state = two_coin_process.transition_states(state, jnp.array(0)) + + # Bayes update: P(comp | obs=0) proportional to P(obs=0 | comp) * P(comp) + # P(comp0 | obs=0) proportional to 0.7 * 0.6 = 0.42 + # P(comp1 | obs=0) proportional to 0.3 * 0.4 = 0.12 + # Normalized: [0.42, 0.12] / 0.54 = [0.778, 0.222] + expected_beliefs = jnp.array([0.42, 0.12]) + expected_beliefs = expected_beliefs / jnp.sum(expected_beliefs) + chex.assert_trees_all_close(new_state.component_beliefs, expected_beliefs, atol=1e-5) + + def test_probability_equals_mixture_probability(self, two_coin_process): + """P(sequence) should equal weighted sum of component probabilities.""" + observations = jnp.array([0, 0, 1]) # HHT + + prob = two_coin_process.probability(observations) + + # Manual calculation: + # P(HHT | coin1) = 0.7 * 0.7 * 0.3 = 0.147 + # P(HHT | coin2) = 0.3 * 0.3 * 0.7 = 0.063 + # P(HHT) = 0.6 * 0.147 + 0.4 * 0.063 = 0.0882 + 0.0252 = 0.1134 + expected = 0.6 * 0.147 + 0.4 * 0.063 + chex.assert_trees_all_close(prob, expected, atol=1e-6) + + def test_log_probability_consistent_with_probability(self, two_coin_process): + """log_probability should equal log of probability.""" + observations = jnp.array([0, 1, 0, 1]) + + prob = two_coin_process.probability(observations) + log_prob = two_coin_process.log_probability(observations) + + chex.assert_trees_all_close(log_prob, jnp.log(prob), atol=1e-5) + + def test_generate_produces_valid_sequences(self, two_coin_process): + """generate should produce sequences within vocab range.""" + state = two_coin_process.initial_state + # Batch the state + batch_size = 4 + batch_states = NonErgodicState( + component_beliefs=jnp.broadcast_to(state.component_beliefs, (batch_size,) + state.component_beliefs.shape), + 
component_states=tuple(jnp.broadcast_to(s, (batch_size,) + s.shape) for s in state.component_states), + ) + keys = jax.random.split(jax.random.PRNGKey(0), batch_size) + + final_states, observations = two_coin_process.generate(batch_states, keys, 10, False) + + assert observations.shape == (batch_size, 10) + assert jnp.all(observations >= 0) + assert jnp.all(observations < two_coin_process.vocab_size) + + def test_emit_observation_within_vocab(self, two_coin_process): + """emit_observation should return valid tokens.""" + state = two_coin_process.initial_state + key = jax.random.PRNGKey(42) + + obs = two_coin_process.emit_observation(state, key) + + assert obs.shape == () + assert 0 <= int(obs) < two_coin_process.vocab_size + + +class TestVocabMaps: + """Tests for vocabulary mapping functionality.""" + + def test_different_vocab_maps_work(self): + """Components with different vocab maps should be handled correctly.""" + coin1 = build_hidden_markov_model("coin", {"p": 0.7}) + coin2 = build_hidden_markov_model("coin", {"p": 0.3}) + + process = NonErgodicGenerativeProcess( + components=[coin1, coin2], + component_weights=[0.5, 0.5], + vocab_maps=[[0, 1], [0, 2]], # coin2 maps to tokens 0, 2 + ) + + assert process.vocab_size == 3 # tokens 0, 1, 2 + + state = process.initial_state + dist = process.observation_probability_distribution(state) + + # Token 0: both components can emit (0.5 * 0.7 + 0.5 * 0.3 = 0.5) + # Token 1: only component 0 (0.5 * 0.3 = 0.15) + # Token 2: only component 1 (0.5 * 0.7 = 0.35) + expected = jnp.array([0.5, 0.15, 0.35]) + chex.assert_trees_all_close(dist, expected, atol=1e-6) + + def test_unmapped_tokens_have_zero_probability(self): + """Tokens not in a component's vocab should contribute zero from that component.""" + coin = build_hidden_markov_model("coin", {"p": 0.5}) + + process = NonErgodicGenerativeProcess( + components=[coin], + component_weights=[1.0], + vocab_maps=[[0, 2]], # Component uses tokens 0, 2; token 1 is unmapped + ) + + 
state = process.initial_state + dist = process.observation_probability_distribution(state) + + assert process.vocab_size == 3 + assert dist[1] == 0.0 # Token 1 has zero probability + + +class TestMixedComponentTypes: + """Tests for mixing different GenerativeProcess types.""" + + def test_hmm_and_ghmm_mixture(self): + """Should handle mixing HMM and GHMM components.""" + hmm = build_hidden_markov_model("even_ones", {"p": 0.5}) + ghmm = build_generalized_hidden_markov_model("tom_quantum", {"alpha": 1.0, "beta": 1.0}) + + process = NonErgodicGenerativeProcess( + components=[hmm, ghmm], + component_weights=[0.7, 0.3], + ) + + state = process.initial_state + dist = process.observation_probability_distribution(state) + + chex.assert_trees_all_close(jnp.sum(dist), 1.0, atol=1e-6) + assert jnp.all(dist >= 0) + + +class TestBuilder: + """Tests for build_nonergodic_process_from_spec.""" + + def test_build_from_hmm_specs(self): + """Should build process from HMM specifications.""" + process = build_nonergodic_process_from_spec( + components=[ + { + "component_type": "hmm", + "process_name": "coin", + "process_params": {"p": 0.6}, + }, + { + "component_type": "hmm", + "process_name": "coin", + "process_params": {"p": 0.4}, + }, + ], + component_weights=[0.5, 0.5], + ) + + assert isinstance(process, NonErgodicGenerativeProcess) + assert len(process.components) == 2 + assert process.vocab_size == 2 + + def test_build_from_ghmm_specs(self): + """Should build process from GHMM specifications.""" + process = build_nonergodic_process_from_spec( + components=[ + { + "component_type": "ghmm", + "process_name": "tom_quantum", + "process_params": {"alpha": 1.0, "beta": 1.0}, + }, + ], + component_weights=[1.0], + ) + + assert isinstance(process, NonErgodicGenerativeProcess) + assert len(process.components) == 1 + + def test_build_with_vocab_maps(self): + """Should respect vocab_maps in spec.""" + process = build_nonergodic_process_from_spec( + components=[ + { + "component_type": 
"hmm", + "process_name": "coin", + "process_params": {"p": 0.5}, + }, + { + "component_type": "hmm", + "process_name": "coin", + "process_params": {"p": 0.5}, + }, + ], + component_weights=[0.5, 0.5], + vocab_maps=[[0, 1], [0, 2]], + ) + + assert process.vocab_size == 3 + + def test_invalid_component_type_raises(self): + """Should raise for unknown component type.""" + with pytest.raises(ValueError, match="Unknown component_type"): + build_nonergodic_process_from_spec( + components=[{"component_type": "invalid", "process_name": "coin"}], + component_weights=[1.0], + ) + + +class TestEdgeCases: + """Tests for edge cases and error handling.""" + + def test_single_component_degenerates_to_component(self): + """Single-component process should behave like the component.""" + coin = build_hidden_markov_model("coin", {"p": 0.7}) + + process = NonErgodicGenerativeProcess( + components=[coin], + component_weights=[1.0], + ) + + observations = jnp.array([0, 1, 0]) + + process_prob = process.probability(observations) + coin_prob = coin.probability(observations) + + chex.assert_trees_all_close(process_prob, coin_prob, atol=1e-6) + + def test_weights_are_normalized(self): + """Component weights should be normalized to sum to 1.""" + coin1 = build_hidden_markov_model("coin", {"p": 0.7}) + coin2 = build_hidden_markov_model("coin", {"p": 0.3}) + + # Provide unnormalized weights + process = NonErgodicGenerativeProcess( + components=[coin1, coin2], + component_weights=[2.0, 3.0], # Sum to 5, not 1 + ) + + chex.assert_trees_all_close(process.component_weights, jnp.array([0.4, 0.6]), atol=1e-6) + + def test_empty_components_raises(self): + """Should raise for empty component list.""" + with pytest.raises(ValueError, match="at least one component"): + NonErgodicGenerativeProcess( + components=[], + component_weights=[], + ) + + def test_mismatched_weights_raises(self): + """Should raise if weights don't match component count.""" + coin = build_hidden_markov_model("coin", {"p": 0.5}) + 
+ with pytest.raises(ValueError, match="must match"): + NonErgodicGenerativeProcess( + components=[coin, coin], + component_weights=[1.0], # Only 1 weight for 2 components + ) + + def test_mismatched_vocab_maps_raises(self): + """Should raise if vocab map count doesn't match component count.""" + coin = build_hidden_markov_model("coin", {"p": 0.5}) + + with pytest.raises(ValueError, match="Length of vocab maps"): + NonErgodicGenerativeProcess( + components=[coin, coin], + component_weights=[0.5, 0.5], + vocab_maps=[[0, 1]], + ) + + def test_duplicate_vocab_map_entries_raise(self): + """Should raise if a component vocab map reuses a global token index.""" + coin = build_hidden_markov_model("coin", {"p": 0.5}) + + with pytest.raises(ValueError, match="must not contain duplicate"): + NonErgodicGenerativeProcess( + components=[coin], + component_weights=[1.0], + vocab_maps=[[0, 0]], + ) + + +class TestGenerateReturnAllStates: + """Tests for generate with return_all_states=True.""" + + @pytest.fixture + def two_mess3_process(self): + """Two mess3 HMMs as a nonergodic mixture.""" + hmm1 = build_hidden_markov_model("mess3", {"x": 0.15, "a": 0.6}) + hmm2 = build_hidden_markov_model("mess3", {"x": 0.5, "a": 0.6}) + return NonErgodicGenerativeProcess( + components=[hmm1, hmm2], + component_weights=[0.6, 0.4], + ) + + def test_return_all_states_shapes(self, two_mess3_process): + """Both component_beliefs and component_states should have time dimension.""" + batch_size = 4 + seq_len = 8 + state = two_mess3_process.initial_state + batch_states = NonErgodicState( + component_beliefs=jnp.broadcast_to(state.component_beliefs, (batch_size,) + state.component_beliefs.shape), + component_states=tuple(jnp.broadcast_to(s, (batch_size,) + s.shape) for s in state.component_states), + ) + keys = jax.random.split(jax.random.PRNGKey(0), batch_size) + + trajectory, observations = two_mess3_process.generate(batch_states, keys, seq_len, True) + + assert observations.shape == (batch_size, 
seq_len) + assert trajectory.component_beliefs.shape == (batch_size, seq_len, 2) + for i, comp in enumerate(two_mess3_process.components): + assert trajectory.component_states[i].shape == (batch_size, seq_len, comp.initial_state.shape[0]) + + def test_return_all_states_beliefs_are_valid_distributions(self, two_mess3_process): + """Component beliefs at each timestep should sum to 1.""" + batch_size = 4 + seq_len = 8 + state = two_mess3_process.initial_state + batch_states = NonErgodicState( + component_beliefs=jnp.broadcast_to(state.component_beliefs, (batch_size,) + state.component_beliefs.shape), + component_states=tuple(jnp.broadcast_to(s, (batch_size,) + s.shape) for s in state.component_states), + ) + keys = jax.random.split(jax.random.PRNGKey(0), batch_size) + + trajectory, _ = two_mess3_process.generate(batch_states, keys, seq_len, True) + + belief_sums = jnp.sum(trajectory.component_beliefs, axis=-1) + chex.assert_trees_all_close(belief_sums, jnp.ones_like(belief_sums), atol=1e-5) + + +class TestFactoredComponent: + """Tests for FactoredGenerativeProcess as a NonErgodic component.""" + + @pytest.fixture + def hmm_factored_process(self): + """NonErgodic process with one HMM and one factored component.""" + hmm = build_hidden_markov_model("coin", {"p": 0.7}) + factored = build_factored_process_from_spec( + structure_type="independent", + spec=[ + {"component_type": "hmm", "variants": [{"process_name": "coin", "process_params": {"p": 0.6}}]}, + {"component_type": "hmm", "variants": [{"process_name": "coin", "process_params": {"p": 0.4}}]}, + ], + ) + return NonErgodicGenerativeProcess( + components=[hmm, factored], + component_weights=[0.5, 0.5], + ) + + def test_factored_component_generate(self, hmm_factored_process): + """NonErgodic with a factored component should generate valid sequences.""" + process = hmm_factored_process + batch_size = 4 + seq_len = 6 + batch_states = _expand_state(process.initial_state, batch_size) + keys = 
jax.random.split(jax.random.PRNGKey(42), batch_size) + + final_states, observations = process.generate(batch_states, keys, seq_len, False) + + assert observations.shape == (batch_size, seq_len) + assert jnp.all(observations >= 0) + assert jnp.all(observations < process.vocab_size) + + def test_factored_component_return_all_states(self, hmm_factored_process): + """Factored component state trajectory should have correct shapes.""" + process = hmm_factored_process + + batch_size = 4 + seq_len = 6 + batch_states = _expand_state(process.initial_state, batch_size) + keys = jax.random.split(jax.random.PRNGKey(42), batch_size) + + trajectory, observations = process.generate(batch_states, keys, seq_len, True) + + assert observations.shape == (batch_size, seq_len) + assert trajectory.component_beliefs.shape == (batch_size, seq_len, 2) + # HMM component state: flat array + assert trajectory.component_states[0].ndim == 3 # [batch, seq, state_dim] + # Factored component state: tuple of arrays + assert isinstance(trajectory.component_states[1], tuple) + for factor_state in trajectory.component_states[1]: + assert factor_state.ndim == 3 # [batch, seq, factor_dim] + + +class TestGenerateDataBatchWithFullHistory: + """Tests for generate_data_batch_with_full_history with NonErgodicGenerativeProcess.""" + + def test_full_history_shapes(self): + """Belief states should have consistent shapes after slicing.""" + coin1 = build_hidden_markov_model("coin", {"p": 0.7}) + coin2 = build_hidden_markov_model("coin", {"p": 0.3}) + process = NonErgodicGenerativeProcess( + components=[coin1, coin2], + component_weights=[0.6, 0.4], + ) + + batch_size = 4 + seq_len = 8 + batch_states = _expand_state(process.initial_state, batch_size) + + result = generate_data_batch_with_full_history( + batch_states, # type: ignore[arg-type] + process, + batch_size, + seq_len, + jax.random.PRNGKey(0), + ) + + belief_states = result["belief_states"] + inputs = result["inputs"] + assert isinstance(inputs, jax.Array) 
+ + assert isinstance(belief_states, NonErgodicState) + input_len = inputs.shape[1] + assert belief_states.component_beliefs.shape == (batch_size, input_len, 2) + for cs in belief_states.component_states: + assert not isinstance(cs, tuple) + assert cs.shape[0] == batch_size + assert cs.shape[1] == input_len + + def test_full_history_with_bos(self): + """Belief states should align with inputs when BOS token is used.""" + coin1 = build_hidden_markov_model("coin", {"p": 0.7}) + coin2 = build_hidden_markov_model("coin", {"p": 0.3}) + process = NonErgodicGenerativeProcess( + components=[coin1, coin2], + component_weights=[0.6, 0.4], + ) + + batch_size = 4 + seq_len = 8 + bos_token = process.vocab_size + batch_states = _expand_state(process.initial_state, batch_size) + + result = generate_data_batch_with_full_history( + batch_states, # type: ignore[arg-type] + process, + batch_size, + seq_len, + jax.random.PRNGKey(0), + bos_token=bos_token, + ) + + belief_states = result["belief_states"] + inputs = result["inputs"] + assert isinstance(inputs, jax.Array) + + assert isinstance(belief_states, NonErgodicState) + input_len = inputs.shape[1] + assert belief_states.component_beliefs.shape == (batch_size, input_len, 2) + for cs in belief_states.component_states: + assert not isinstance(cs, tuple) + assert cs.shape[0] == batch_size + assert cs.shape[1] == input_len diff --git a/walkthroughs/pr-172.json b/walkthroughs/pr-172.json new file mode 100644 index 00000000..2fe60f54 --- /dev/null +++ b/walkthroughs/pr-172.json @@ -0,0 +1,157 @@ +{ + "title": "PR #172: NonErgodicGenerativeProcess & InflatedVocabularyProcess", + "description": "Walkthrough of two new generative process types: a block-diagonal nonergodic mixture model and a vocabulary inflation wrapper, plus builder functions and comprehensive tests.", + "repository": { + "remote": "https://github.com/Astera-org/simplexity.git", + "commit": "HEAD" + }, + "metadata": { + "pr": 172, + "recommendation": "approve" + }, + 
"steps": [ + { + "id": 1, + "title": "Overview: Two new generative process abstractions", + "body": "This PR introduces two new `GenerativeProcess` subclasses:\n\n1. **NonErgodicGenerativeProcess** — A block-diagonal mixture model that composes multiple `GenerativeProcess` components with weighted probabilities. No transitions occur between components; beliefs are updated via Bayesian filtering.\n\n2. **InflatedVocabularyProcess** — A wrapper that multiplies vocabulary size by a factor K with uniform noise, increasing optimal per-token loss by exactly `log(K)` nats.\n\nThe PR also adds 9 builder functions for constructing these processes from YAML specs, supporting disjoint and partially-overlapping vocabulary configurations with three mapping strategies (prefix, sliding, random).\n\nKey design choices:\n- Does NOT materialize a full block-diagonal matrix — stores component processes directly for efficiency\n- Uses `IndependentFactoredGenerativeProcess` for independent structures to achieve O(sum V_i) sampling complexity\n- Handles heterogeneous state types (HMM vs Factored) via flatten/pad/unflatten for `jax.lax.switch` compatibility" + }, + { + "id": 2, + "title": "NonErgodicState: the composite state representation", + "body": "The state is a `NamedTuple` with two fields:\n- `component_beliefs`: a probability distribution over components, shape `[num_components]`. 
During generation this becomes one-hot after the first emission; during inference it's updated via Bayes rule.\n- `component_states`: a tuple of per-component states, where each element can be either a flat `jax.Array` (HMM) or a tuple of arrays (FactoredState).\n\nThis heterogeneous state design is central to the PR — it allows mixing different process types (HMM, GHMM, Factored) in a single mixture.", + "location": "simplexity/generative_processes/nonergodic_generative_process.py:78-90" + }, + { + "id": 3, + "title": "State flattening utilities for JAX compatibility", + "body": "Since `jax.lax.switch` requires all branches to return identically-shaped arrays, these three helper functions flatten heterogeneous component states into uniform 1D arrays:\n\n- `_get_flat_size`: counts total elements\n- `_flatten_state`: concatenates to 1D\n- `_unflatten_state`: reconstructs original structure using a template\n\nNote the use of `jax.lax.dynamic_slice` instead of Python slicing in `_unflatten_state` (line 71) — this avoids `ConcretizationTypeError` inside `jax.lax.switch` since template shapes are known at trace time.", + "location": "simplexity/generative_processes/nonergodic_generative_process.py:20-75" + }, + { + "id": 4, + "title": "Constructor: vocab maps and inverse maps", + "body": "The constructor normalizes component weights, builds forward vocab maps (local-to-global), and computes inverse maps (global-to-local) for efficient observation routing.\n\nKey details:\n- Weights are normalized to sum to 1 (line 156)\n- If no vocab maps provided, each component gets identity mapping `[0..V-1]` (line 160)\n- The unified vocab size is `max(all global tokens) + 1` (line 163)\n- Inverse maps use `-1` sentinel for unmapped tokens (line 167), which `transition_states` checks to determine if an observation belongs to a component", + "location": "simplexity/generative_processes/nonergodic_generative_process.py:123-171", + "comments": [ + { + "id": "mmtj452t4ne", + "author": 
"Eric Alt", + "body": "```python\nif vocab_maps is None:\n vocab_maps = [list(range(c.vocab_size)) for c in components]\nelif len(vocab_maps) != len(self.components):\n raise ValueError(\"Length of vocab maps must equal length of components.\")\n```" + }, + { + "id": "mmtjm7a5e79", + "author": "Eric Alt", + "body": "```python\n Raises:\n ValueError: If components is empty, weights don't match component count,\n or a component vocab_map contains duplicate global token indices\n\n...\n\n for i, vm in enumerate(vocab_maps):\n if len(set(vm)) != len(vm):\n raise ValueError(f\"vocab_maps[{i}] must not contain duplicate global token indices\")\n```\n\nCorresponding unit tests should also be added" + } + ] + }, + { + "id": 5, + "title": "Observation distribution: weighted mixture over components", + "body": "Computes `P(obs | state) = sum_i P(component_i | state) * P(obs | component_i, state_i)`.\n\nFor each component: gets the local distribution, scatters it into global vocab space via `vocab_maps[i]`, and weights by `component_beliefs[i]`. Tokens not in a component's vocab naturally get probability 0.\n\nThe log-space variant (line 208-228) uses `logsumexp` across stacked per-component log distributions for numerical stability, with `-inf` for unmapped tokens.", + "location": "simplexity/generative_processes/nonergodic_generative_process.py:186-228", + "comments": [ + { + "id": "mmtj5fpu16p", + "author": "Eric Alt", + "body": "`len(self.components)` should always equal `len(self.vocab_maps)` so `strict` should be `True` instead of `False`" + }, + { + "id": "mmtj8i7vy18", + "author": "Eric Alt", + "body": "```python\nglobal_dist += state.component_beliefs[i] * component_contrib\n```" + } + ] + }, + { + "id": 6, + "title": "Bayesian filtering in transition_states", + "body": "This is the core inference logic. For each observation:\n\n1. Map global token to each component's local space via inverse vocab maps (line 268)\n2. 
Compute likelihood `P(obs | component_i)` — 0 if token not in component's vocab (lines 271-275)\n3. Conditionally update each component's internal state only when likelihood > 0 (lines 278-284)\n4. Apply Bayes rule: `new_beliefs = beliefs * likelihoods / normalizer` (lines 288-294)\n5. Fall back to prior beliefs if all likelihoods are 0 (line 291-293)\n\nThe `jax.lax.cond` on line 278 avoids unnecessary state transitions for components that couldn't have generated the observation.", + "location": "simplexity/generative_processes/nonergodic_generative_process.py:253-299", + "comments": [ + { + "id": "mmtjcmycfty", + "author": "Eric Alt", + "body": "`strict=True`" + } + ] + }, + { + "id": 7, + "title": "Generation: sample one component, generate entire sequence", + "body": "Unlike inference (which tracks beliefs across all components), generation samples a single component at the start and generates entirely from it.\n\nThe implementation is notable for its complexity:\n- Cannot delegate to `component.generate()` because that method is also vmapped (line 363)\n- Uses flatten/pad to a common max size so `jax.lax.switch` can handle heterogeneous state types (lines 389-398)\n- `scan_step` (line 411) runs generation via `lax.switch` selecting the active component\n- Only the active component's state is updated per step (lines 424-431)\n- When `return_all_states=True`, a second inference scan reconstructs belief trajectories (lines 447-453)\n\nThis is the most intricate part of the PR — the flatten/pad/unflatten dance is the price paid for supporting mixed component types in a single JIT-compiled scan.", + "location": "simplexity/generative_processes/nonergodic_generative_process.py:347-459", + "comments": [ + { + "id": "mmtlfwdo7hn", + "author": "Eric Alt", + "body": "TODO: instead of having generative processes's `generate` function vmapped by default we should just have it function for a single sequence and define a separate `generate_batch` function that just wraps 
the generate function in a `vmap` (or just require the caller of generate to do that themselves) - outside the scope of this PR though" + } + ] + }, + { + "id": 8, + "title": "InflatedVocabularyProcess: controlled difficulty via noise", + "body": "A clean decorator pattern that wraps any `GenerativeProcess[State]` to inflate its vocabulary.\n\nToken encoding: `inflated_token = noise_prefix * V_base + base_token`\n- `emit_observation` samples a base token then adds a uniform random noise prefix (lines 68-71)\n- `transition_states` extracts the base token via modulo and discards noise (line 76)\n- Probability distributions are tiled K times and divided by K (line 83)\n- `probability` applies a `(1/K)^T` penalty for a sequence of length T (line 96)\n\nThis elegantly increases optimal per-token loss by exactly `log(K)` nats while preserving all state dynamics.", + "location": "simplexity/generative_processes/inflated_vocabulary_process.py:22-103" + }, + { + "id": 9, + "title": "Generator updates: slicing NonErgodicState belief trajectories", + "body": "The existing `generate_data_batch_with_full_history` function needed to handle `NonErgodicState` when slicing belief trajectories along the sequence dimension.\n\nA new `_slice_belief_states` helper (line 96) handles three state representations:\n- Plain arrays: slice directly\n- Tuples of arrays: slice each element\n- `NonErgodicState`: slice both `component_beliefs` and each entry in `component_states`, handling nested tuples for factored components (line 111)\n\nThis replaces the previous inline isinstance/tuple handling with a cleaner dispatch.", + "location": "simplexity/generative_processes/generator.py:96-118" + }, + { + "id": 10, + "title": "IndependentFactoredGenerativeProcess: noise_epsilon passthrough", + "body": "A small but important change: `noise_epsilon` is added as a constructor parameter (line 52) and forwarded to the parent `FactoredGenerativeProcess.__init__` (line 83).\n\nThis allows nonergodic 
processes to compose factored components that use noisy channels — previously the `IndependentFactoredGenerativeProcess` would silently ignore this parameter.", + "location": "simplexity/generative_processes/independent_factored_generative_process.py:43-84" + }, + { + "id": 11, + "title": "Builder: component factory and nonergodic process construction", + "body": "The private `_build_components_from_spec` helper (line 653) is the foundation for all nonergodic builders. It dispatches on `component_type` to build HMM, GHMM, or Factored processes.\n\n`build_nonergodic_process_from_spec` (line 700) is the main entry point, documented with a full YAML example showing how to compose HMM, GHMM, and factored components with explicit vocab maps.\n\nNote how vocab maps can be specified either per-component in the spec or globally as an override (lines 754-761).", + "location": "simplexity/generative_processes/builder.py:653-768" + }, + { + "id": 12, + "title": "Builder: vocabulary mapping strategies", + "body": "Three vocab map strategies for partially overlapping alphabets:\n\n1. **Prefix** (line 805): C0 gets `[0..V-1]`, subsequent components share a prefix of `n_shared` tokens plus unique tokens above V\n2. **Sliding** (line 820): Each component's vocab slides by `max(1, n_unique)` tokens — simple offset strategy\n3. **Random** (line 829): Each component independently samples V tokens from the global pool using a seeded RNG\n\n`build_nonergodic_partial_overlap` (line 842) orchestrates these, computing `n_shared = int(V * overlap_frac)` and `n_unique = V - n_shared` from the `overlap_frac` parameter. 
All components must have equal vocab size for this to work (validated on line 875-876).\n\n`build_nonergodic_disjoint_vocab` (line 771) is the simpler case: sequential non-overlapping ranges `[0..V0-1], [V0..V0+V1-1], ...`", + "location": "simplexity/generative_processes/builder.py:771-897", + "comments": [ + { + "id": "mmtncbw77fd", + "author": "Eric Alt", + "body": "```python\n def _build_prefix_vocab_maps(n_components: int, v: int, n_shared: int, n_unique: int) -> list[list[int]]:\n \"\"\"Build vocab maps using the prefix strategy.\"\"\"\n return [list(range(v))] + [\n list(range(n_shared)) + list(range(v + i * n_unique, v + (i + 1) * n_unique))\n for i in range(n_components - 1)\n ]\n```" + }, + { + "id": "mmtngqqdqwf", + "author": "Eric Alt", + "body": "`assert seed is not None` is redundant with earlier check" + }, + { + "id": "mmtnjikp3yy", + "author": "Eric Alt", + "body": "```python\n def _build_random_vocab_maps(n_components: int, v: int, n_unique: int, seed: int) -> list[list[int]]:\n \"\"\"Build vocab maps by having each component randomly sample V tokens from the global pool.\n\n The global vocab size is the same as in prefix mode:\n V + (n_components - 1) * n_unique.\n \"\"\"\n global_vocab_size = v + (n_components - 1) * n_unique\n rng = random.Random(seed)\n\n return [\n sorted(rng.sample(range(global_vocab_size), v))\n for _ in range(n_components)\n ]\n```" + } + ] + }, + { + "id": 13, + "title": "Builder: InflatedVocabularyProcess construction", + "body": "Two builder functions for the inflation wrapper:\n\n- `build_inflated_process` (line 900): Simple wrapper taking an existing `GenerativeProcess` and inflation factor\n- `build_inflated_process_from_spec` (line 916): Builds the base process from a spec dict first, then wraps it — supports HMM, GHMM, and factored base types\n\nBoth are thin wrappers that delegate to the `InflatedVocabularyProcess` constructor.", + "location": "simplexity/generative_processes/builder.py:900-959" + }, + { + "id": 
14, + "title": "build_factored_process now returns IndependentFactoredGenerativeProcess", + "body": "A structural change in the existing `build_factored_process` function: when `structure_type == \"independent\"`, it now returns an `IndependentFactoredGenerativeProcess` directly (line 200) with early return, rather than falling through to the generic `FactoredGenerativeProcess` constructor.\n\nThis ensures that nonergodic processes composing independent factored components get the specialized subclass with per-factor sampling and frozen factor support, plus the new `noise_epsilon` passthrough.", + "location": "simplexity/generative_processes/builder.py:198-207" + }, + { + "id": 15, + "title": "Summary and recommendations", + "body": "**Strengths:**\n- Clean abstractions: both new classes implement the full `GenerativeProcess` protocol\n- The flatten/pad/unflatten approach for heterogeneous states in `jax.lax.switch` is well-documented and correct\n- Comprehensive builder functions with three vocab mapping strategies cover real research use cases\n- Excellent test coverage (21+ tests for NonErgodic, tests for Inflated, builder tests)\n- The `InflatedVocabularyProcess` is a particularly elegant design — stateless noise with provable loss increase\n\n**Architecture notes:**\n- The `generate` method in `NonErgodicGenerativeProcess` is the most complex piece — the re-implementation of the generate loop (rather than delegating to components) is necessary due to vmap constraints but adds maintenance burden\n- The `_slice_belief_states` helper in `generator.py` adds a third branch for `NonErgodicState` — if more state types emerge, this could benefit from a protocol-based dispatch\n\n**Recommendation: Approve** — Well-structured addition with solid test coverage and clear documentation of the tricky JAX compatibility patterns." + } + ] +} \ No newline at end of file