diff --git a/simplexity/generative_processes/builder.py b/simplexity/generative_processes/builder.py index 0dbb8f3a..ddcf1851 100644 --- a/simplexity/generative_processes/builder.py +++ b/simplexity/generative_processes/builder.py @@ -807,14 +807,9 @@ def _build_prefix_vocab_maps(n_components: int, v: int, n_shared: int, n_unique: C0 gets [0..V-1]. Ci>0 gets shared [0..n_shared-1] + unique tokens above V. """ - vocab_maps: list[list[int]] = [] - for i in range(n_components): - if i == 0: - vocab_maps.append(list(range(v))) - else: - unique_start = v + (i - 1) * n_unique - vocab_maps.append(list(range(n_shared)) + list(range(unique_start, unique_start + n_unique))) - return vocab_maps + return [list(range(v))] + [ + list(range(n_shared)) + list(range(v + i * n_unique, v + (i + 1) * n_unique)) for i in range(n_components - 1) + ] def _build_sliding_vocab_maps(n_components: int, v: int, n_unique: int) -> list[list[int]]: @@ -826,17 +821,15 @@ def _build_sliding_vocab_maps(n_components: int, v: int, n_unique: int) -> list[ return [list(range(i * offset, i * offset + v)) for i in range(n_components)] -def _build_random_vocab_maps(n_components: int, v: int, n_shared: int, n_unique: int, seed: int) -> list[list[int]]: +def _build_random_vocab_maps(n_components: int, v: int, n_unique: int, seed: int) -> list[list[int]]: """Build vocab maps by having each component randomly sample V tokens from the global pool. - The global vocab size is the same as in prefix mode (V + (n_components - 1) * n_unique), - and each component independently samples V tokens without replacement. + The global vocab size is the same as in prefix mode: + V + (n_components - 1) * n_unique. 
""" - prefix_maps = _build_prefix_vocab_maps(n_components, v, n_shared, n_unique) - global_vocab_size = max(max(vm) for vm in prefix_maps) + 1 + global_vocab_size = v + (n_components - 1) * n_unique rng = random.Random(seed) - global_tokens = list(range(global_vocab_size)) - return [sorted(rng.sample(global_tokens, v)) for _ in range(n_components)] + return [sorted(rng.sample(range(global_vocab_size), v)) for _ in range(n_components)] def build_nonergodic_partial_overlap( @@ -884,8 +877,9 @@ def build_nonergodic_partial_overlap( elif mode == "sliding": vocab_maps = _build_sliding_vocab_maps(n_components, v, n_unique) elif mode == "random": - assert seed is not None - vocab_maps = _build_random_vocab_maps(n_components, v, n_shared, n_unique, seed) + if seed is None: + raise ValueError("seed is required when mode='random'") + vocab_maps = _build_random_vocab_maps(n_components, v, n_unique, seed) else: raise ValueError(f"Unknown mode '{mode}'. Must be 'prefix', 'sliding', or 'random'.") diff --git a/simplexity/generative_processes/nonergodic_generative_process.py b/simplexity/generative_processes/nonergodic_generative_process.py index 0be9f1a4..a45eabf7 100644 --- a/simplexity/generative_processes/nonergodic_generative_process.py +++ b/simplexity/generative_processes/nonergodic_generative_process.py @@ -17,6 +17,14 @@ ComponentState = jax.Array | tuple[jax.Array, ...] +class _GenerationLayout(NamedTuple): + """Static metadata needed to flatten and restore heterogeneous states.""" + + flat_sizes: tuple[int, ...] + state_templates: tuple[ComponentState, ...] + max_flat_size: int + + def _get_flat_size(state: ComponentState) -> int: """Get total number of elements in a component state. 
@@ -45,6 +53,12 @@ def _flatten_state(state: ComponentState) -> jax.Array: return state.ravel() +def _flatten_and_pad_state(state: ComponentState, max_flat_size: int) -> jax.Array: + """Flatten a state and pad it to the shared switch-compatible size.""" + flat = _flatten_state(state) + return jnp.pad(flat, (0, max_flat_size - flat.size)) + + def _unflatten_state(flat: jax.Array, template: ComponentState) -> ComponentState: """Restore original state structure from a flattened 1D array. @@ -75,6 +89,16 @@ def _unflatten_state(flat: jax.Array, template: ComponentState) -> ComponentStat return flat.reshape(template.shape) +def _unpad_and_unflatten_state(padded: jax.Array, original_size: int, template: ComponentState) -> ComponentState: + """Remove padding and restore the component state structure.""" + return _unflatten_state(padded[:original_size], template) + + +def _keep_state(state: ComponentState, _obs: chex.Array) -> ComponentState: + """Return the existing state unchanged.""" + return state + + class NonErgodicState(NamedTuple): """State for nonergodic generative process. @@ -138,7 +162,9 @@ def __init__( device: Device to place arrays on (e.g., "cpu", "gpu"). Raises: - ValueError: If components is empty or weights don't match component count. + ValueError: If components is empty, weights don't match component count, + vocab map count doesn't match component count, or a component + vocab_map contains duplicate global token indices. 
""" if len(components) == 0: raise ValueError("Must provide at least one component") @@ -158,6 +184,12 @@ def __init__( if vocab_maps is None: vocab_maps = [list(range(c.vocab_size)) for c in components] + elif len(vocab_maps) != len(self.components): + raise ValueError("Length of vocab maps must equal length of components.") + + for i, vm in enumerate(vocab_maps): + if len(set(vm)) != len(vm): + raise ValueError(f"vocab_maps[{i}] must not contain duplicate global token indices") self.vocab_maps = tuple(jax.device_put(jnp.array(vm, dtype=jnp.int32), self.device) for vm in vocab_maps) self._vocab_size = max(max(vm) for vm in vocab_maps) + 1 @@ -197,11 +229,11 @@ def observation_probability_distribution(self, state: NonErgodicState) -> jax.Ar """ global_dist = jnp.zeros(self._vocab_size) - for i, (component, vm) in enumerate(zip(self.components, self.vocab_maps, strict=False)): + for i, (component, vm) in enumerate(zip(self.components, self.vocab_maps, strict=True)): comp_state = state.component_states[i] local_dist = component.observation_probability_distribution(comp_state) component_contrib = jnp.zeros(self._vocab_size).at[vm].add(local_dist) - global_dist = global_dist + state.component_beliefs[i] * component_contrib + global_dist += state.component_beliefs[i] * component_contrib return global_dist @@ -215,7 +247,7 @@ def log_observation_probability_distribution(self, log_belief_state: NonErgodicS """ log_probs = [] - for i, (component, vm) in enumerate(zip(self.components, self.vocab_maps, strict=False)): + for i, (component, vm) in enumerate(zip(self.components, self.vocab_maps, strict=True)): comp_log_state = log_belief_state.component_states[i] comp_log_belief = log_belief_state.component_beliefs[i] @@ -250,6 +282,34 @@ def emit_from_component(i: int, k: chex.PRNGKey) -> chex.Array: return global_obs + def _update_component_for_observation( + self, + component: GenerativeProcess, + inv_map: jax.Array, + comp_state: ComponentState, + obs: chex.Array, + ) -> 
tuple[ComponentState, jax.Array]: + """Update one component state and return its observation likelihood.""" + local_obs = inv_map[obs] + local_dist = component.observation_probability_distribution(comp_state) + likelihood = jnp.where( + local_obs >= 0, + local_dist[jnp.clip(local_obs, 0, local_dist.shape[0] - 1)], + 0.0, + ) + + def transition_component(state: ComponentState, mapped_obs: chex.Array) -> ComponentState: + return component.transition_states(state, mapped_obs) + + new_comp_state = jax.lax.cond( + likelihood > 0, + transition_component, + _keep_state, + comp_state, + local_obs, + ) + return new_comp_state, likelihood + @eqx.filter_jit def transition_states(self, state: NonErgodicState, obs: chex.Array) -> NonErgodicState: """Update state given observation using Bayesian filtering. @@ -263,25 +323,10 @@ def transition_states(self, state: NonErgodicState, obs: chex.Array) -> NonErgod new_component_states = [] likelihoods = [] - for i, (component, inv_map) in enumerate(zip(self.components, self._inverse_vocab_maps, strict=False)): + for i, (component, inv_map) in enumerate(zip(self.components, self._inverse_vocab_maps, strict=True)): comp_state = state.component_states[i] - local_obs = inv_map[obs] - - local_dist = component.observation_probability_distribution(comp_state) - likelihood = jnp.where( - local_obs >= 0, - local_dist[jnp.clip(local_obs, 0, local_dist.shape[0] - 1)], - 0.0, - ) + new_comp_state, likelihood = self._update_component_for_observation(component, inv_map, comp_state, obs) likelihoods.append(likelihood) - - new_comp_state = jax.lax.cond( - likelihood > 0, - lambda s, lo, c=component: c.transition_states(s, lo), - lambda s, lo, c=None: s, - comp_state, - local_obs, - ) new_component_states.append(new_comp_state) likelihoods_arr = jnp.array(likelihoods) @@ -310,9 +355,13 @@ def compute_component_prob(i: int) -> jax.Array: inv_map = self._inverse_vocab_maps[i] local_obs = inv_map[observations] all_valid = jnp.all(local_obs >= 0) + + def 
compute_prob(lo: jax.Array) -> jax.Array: + return component.probability(lo) + prob = jax.lax.cond( all_valid, - lambda lo: component.probability(lo), + compute_prob, lambda lo: jnp.array(0.0), local_obs, ) @@ -333,9 +382,13 @@ def compute_component_log_prob(i: int) -> jax.Array: inv_map = self._inverse_vocab_maps[i] local_obs = inv_map[observations] all_valid = jnp.all(local_obs >= 0) + + def compute_log_prob(lo: jax.Array) -> jax.Array: + return component.log_probability(lo) + log_prob = jax.lax.cond( all_valid, - lambda lo: component.log_probability(lo), + compute_log_prob, lambda lo: jnp.array(-jnp.inf), local_obs, ) @@ -344,6 +397,72 @@ def compute_component_log_prob(i: int) -> jax.Array: log_probs = jnp.array([compute_component_log_prob(i) for i in range(len(self.components))]) return jax.nn.logsumexp(log_probs) + def _generate_component_step( + self, + i: int, + padded_state: jax.Array, + step_key: chex.PRNGKey, + layout: _GenerationLayout, + ) -> tuple[jax.Array, chex.Array]: + """Advance one selected component by a single generation step.""" + real_state = _unpad_and_unflatten_state(padded_state, layout.flat_sizes[i], layout.state_templates[i]) + local_obs = self.components[i].emit_observation(real_state, step_key) + new_real_state = self.components[i].transition_states(real_state, local_obs) + new_padded_state = _flatten_and_pad_state(new_real_state, layout.max_flat_size) + global_obs = self.vocab_maps[i][local_obs] + return new_padded_state, global_obs + + def _scan_component_generation( + self, + component_idx: jax.Array, + padded_states: tuple[jax.Array, ...], + keys: jax.Array, + layout: _GenerationLayout, + ) -> tuple[tuple[jax.Array, ...], chex.Array]: + """Generate observations while updating only the sampled component state.""" + num_components = len(self.components) + + def scan_step( + carry: tuple[jax.Array, tuple[jax.Array, ...]], step_key: chex.PRNGKey + ) -> tuple[tuple[jax.Array, tuple[jax.Array, ...]], chex.Array]: + idx, 
padded_comp_states = carry + + new_padded_state, global_obs = jax.lax.switch( + idx, + [ + partial( + self._generate_component_step, + i, + padded_comp_states[i], + step_key, + layout, + ) + for i in range(num_components) + ], + ) + + new_padded_comp_states = tuple( + jax.lax.select(idx == i, new_padded_state, padded_comp_states[i]) for i in range(num_components) + ) + + return (idx, new_padded_comp_states), global_obs + + init_carry = (component_idx, padded_states) + (_, final_padded_states), observations = jax.lax.scan(scan_step, init_carry, keys) + return final_padded_states, observations + + def _generate_state_trajectory( + self, state: NonErgodicState, observations: chex.Array + ) -> tuple[NonErgodicState, chex.Array]: + """Reconstruct the per-token belief trajectory from generated observations.""" + + def inference_step(carry_state: NonErgodicState, obs: chex.Array) -> tuple[NonErgodicState, NonErgodicState]: + new_state = self.transition_states(carry_state, obs) + return new_state, carry_state + + _, state_trajectory = jax.lax.scan(inference_step, state, observations) + return state_trajectory, observations + @eqx.filter_vmap(in_axes=(None, 0, 0, None, None)) def generate( self, @@ -383,77 +502,25 @@ def generate( keys = jax.random.split(key2, sequence_len) component_idx = jax.random.categorical(key1, jnp.log(state.component_beliefs)) - - num_components = len(self.components) - state_templates = state.component_states - flat_sizes = [_get_flat_size(s) for s in state_templates] - max_flat_size = max(flat_sizes) - - def flatten_and_pad(s: ComponentState) -> jax.Array: - flat = _flatten_state(s) - return jnp.pad(flat, (0, max_flat_size - flat.size)) - - def unpad_and_unflatten(padded: jax.Array, original_size: int, template: ComponentState) -> ComponentState: - return _unflatten_state(padded[:original_size], template) - - padded_states = tuple(flatten_and_pad(s) for s in state.component_states) - - def gen_step_for_component( - i: int, padded_state: jax.Array, 
step_key: chex.PRNGKey - ) -> tuple[jax.Array, chex.Array]: - real_state = unpad_and_unflatten(padded_state, flat_sizes[i], state_templates[i]) - local_obs = self.components[i].emit_observation(real_state, step_key) - new_real_state = self.components[i].transition_states(real_state, local_obs) - new_padded_state = flatten_and_pad(new_real_state) - global_obs = self.vocab_maps[i][local_obs] - return new_padded_state, global_obs - - def scan_step( - carry: tuple[jax.Array, tuple[jax.Array, ...]], step_key: chex.PRNGKey - ) -> tuple[tuple[jax.Array, tuple[jax.Array, ...]], chex.Array]: - idx, padded_comp_states = carry - - def gen_from_i(i: int) -> tuple[jax.Array, chex.Array]: - return gen_step_for_component(i, padded_comp_states[i], step_key) - - new_padded_state, global_obs = jax.lax.switch( - idx, - [partial(gen_from_i, i) for i in range(num_components)], - ) - - new_padded_comp_states = tuple( - jax.lax.cond( - idx == i, - lambda ns=new_padded_state: ns, - lambda ps=padded_comp_states[i]: ps, - ) - for i in range(num_components) - ) - - return (idx, new_padded_comp_states), global_obs - - init_carry = (component_idx, padded_states) - (_, final_padded_states), observations = jax.lax.scan(scan_step, init_carry, keys) + layout = _GenerationLayout( + flat_sizes=tuple(_get_flat_size(s) for s in state.component_states), + state_templates=state.component_states, + max_flat_size=max(_get_flat_size(s) for s in state.component_states), + ) + padded_states = tuple(_flatten_and_pad_state(s, layout.max_flat_size) for s in state.component_states) + final_padded_states, observations = self._scan_component_generation(component_idx, padded_states, keys, layout) final_comp_states = tuple( - unpad_and_unflatten(final_padded_states[i], flat_sizes[i], state_templates[i]) - for i in range(num_components) + _unpad_and_unflatten_state(final_padded_states[i], layout.flat_sizes[i], layout.state_templates[i]) + for i in range(len(self.components)) ) one_hot_beliefs = 
jax.nn.one_hot(component_idx, len(self.components), dtype=self.component_weights.dtype) if return_all_states: + return self._generate_state_trajectory(state, observations) - def inference_step( - carry_state: NonErgodicState, obs: chex.Array - ) -> tuple[NonErgodicState, NonErgodicState]: - new_state = self.transition_states(carry_state, obs) - return new_state, carry_state - - _, state_trajectory = jax.lax.scan(inference_step, state, observations) - return state_trajectory, observations - else: - return NonErgodicState( - component_beliefs=one_hot_beliefs, - component_states=final_comp_states, - ), observations + return NonErgodicState( + component_beliefs=one_hot_beliefs, + component_states=final_comp_states, + ), observations diff --git a/tests/generative_processes/test_nonergodic_generative_process.py b/tests/generative_processes/test_nonergodic_generative_process.py index 51e93e47..7b72edfa 100644 --- a/tests/generative_processes/test_nonergodic_generative_process.py +++ b/tests/generative_processes/test_nonergodic_generative_process.py @@ -352,6 +352,28 @@ def test_mismatched_weights_raises(self): component_weights=[1.0], # Only 1 weight for 2 components ) + def test_mismatched_vocab_maps_raises(self): + """Should raise if vocab map count doesn't match component count.""" + coin = build_hidden_markov_model("coin", {"p": 0.5}) + + with pytest.raises(ValueError, match="Length of vocab maps"): + NonErgodicGenerativeProcess( + components=[coin, coin], + component_weights=[0.5, 0.5], + vocab_maps=[[0, 1]], + ) + + def test_duplicate_vocab_map_entries_raise(self): + """Should raise if a component vocab map reuses a global token index.""" + coin = build_hidden_markov_model("coin", {"p": 0.5}) + + with pytest.raises(ValueError, match="must not contain duplicate"): + NonErgodicGenerativeProcess( + components=[coin], + component_weights=[1.0], + vocab_maps=[[0, 0]], + ) + class TestGenerateReturnAllStates: """Tests for generate with return_all_states=True.""" diff 
--git a/walkthroughs/pr-172.json b/walkthroughs/pr-172.json new file mode 100644 index 00000000..2fe60f54 --- /dev/null +++ b/walkthroughs/pr-172.json @@ -0,0 +1,157 @@ +{ + "title": "PR #172: NonErgodicGenerativeProcess & InflatedVocabularyProcess", + "description": "Walkthrough of two new generative process types: a block-diagonal nonergodic mixture model and a vocabulary inflation wrapper, plus builder functions and comprehensive tests.", + "repository": { + "remote": "https://github.com/Astera-org/simplexity.git", + "commit": "HEAD" + }, + "metadata": { + "pr": 172, + "recommendation": "approve" + }, + "steps": [ + { + "id": 1, + "title": "Overview: Two new generative process abstractions", + "body": "This PR introduces two new `GenerativeProcess` subclasses:\n\n1. **NonErgodicGenerativeProcess** — A block-diagonal mixture model that composes multiple `GenerativeProcess` components with weighted probabilities. No transitions occur between components; beliefs are updated via Bayesian filtering.\n\n2. 
**InflatedVocabularyProcess** — A wrapper that multiplies vocabulary size by a factor K with uniform noise, increasing optimal per-token loss by exactly `log(K)` nats.\n\nThe PR also adds 9 builder functions for constructing these processes from YAML specs, supporting disjoint and partially-overlapping vocabulary configurations with three mapping strategies (prefix, sliding, random).\n\nKey design choices:\n- Does NOT materialize a full block-diagonal matrix — stores component processes directly for efficiency\n- Uses `IndependentFactoredGenerativeProcess` for independent structures to achieve O(sum V_i) sampling complexity\n- Handles heterogeneous state types (HMM vs Factored) via flatten/pad/unflatten for `jax.lax.switch` compatibility" + }, + { + "id": 2, + "title": "NonErgodicState: the composite state representation", + "body": "The state is a `NamedTuple` with two fields:\n- `component_beliefs`: a probability distribution over components, shape `[num_components]`. During generation this becomes one-hot after the first emission; during inference it's updated via Bayes rule.\n- `component_states`: a tuple of per-component states, where each element can be either a flat `jax.Array` (HMM) or a tuple of arrays (FactoredState).\n\nThis heterogeneous state design is central to the PR — it allows mixing different process types (HMM, GHMM, Factored) in a single mixture.", + "location": "simplexity/generative_processes/nonergodic_generative_process.py:78-90" + }, + { + "id": 3, + "title": "State flattening utilities for JAX compatibility", + "body": "Since `jax.lax.switch` requires all branches to return identically-shaped arrays, these three helper functions flatten heterogeneous component states into uniform 1D arrays:\n\n- `_get_flat_size`: counts total elements\n- `_flatten_state`: concatenates to 1D\n- `_unflatten_state`: reconstructs original structure using a template\n\nNote the use of `jax.lax.dynamic_slice` instead of Python slicing in `_unflatten_state` 
(line 71) — this avoids `ConcretizationTypeError` inside `jax.lax.switch` since template shapes are known at trace time.", + "location": "simplexity/generative_processes/nonergodic_generative_process.py:20-75" + }, + { + "id": 4, + "title": "Constructor: vocab maps and inverse maps", + "body": "The constructor normalizes component weights, builds forward vocab maps (local-to-global), and computes inverse maps (global-to-local) for efficient observation routing.\n\nKey details:\n- Weights are normalized to sum to 1 (line 156)\n- If no vocab maps provided, each component gets identity mapping `[0..V-1]` (line 160)\n- The unified vocab size is `max(all global tokens) + 1` (line 163)\n- Inverse maps use `-1` sentinel for unmapped tokens (line 167), which `transition_states` checks to determine if an observation belongs to a component", + "location": "simplexity/generative_processes/nonergodic_generative_process.py:123-171", + "comments": [ + { + "id": "mmtj452t4ne", + "author": "Eric Alt", + "body": "```python\nif vocab_maps is None:\n vocab_maps = [list(range(c.vocab_size)) for c in components]\nelif len(vocab_maps) != len(self.components):\n raise ValueError(\"Length of vocab maps must equal length of components.\")\n```" + }, + { + "id": "mmtjm7a5e79", + "author": "Eric Alt", + "body": "```python\n Raises:\n ValueError: If components is empty, weights don't match component count,\n or a component vocab_map contains duplicate global token indices\n\n...\n\n for i, vm in enumerate(vocab_maps):\n if len(set(vm)) != len(vm):\n raise ValueError(f\"vocab_maps[{i}] must not contain duplicate global token indices\")\n```\n\nCorresponding unit tests should also be added" + } + ] + }, + { + "id": 5, + "title": "Observation distribution: weighted mixture over components", + "body": "Computes `P(obs | state) = sum_i P(component_i | state) * P(obs | component_i, state_i)`.\n\nFor each component: gets the local distribution, scatters it into global vocab space via 
`vocab_maps[i]`, and weights by `component_beliefs[i]`. Tokens not in a component's vocab naturally get probability 0.\n\nThe log-space variant (line 208-228) uses `logsumexp` across stacked per-component log distributions for numerical stability, with `-inf` for unmapped tokens.", + "location": "simplexity/generative_processes/nonergodic_generative_process.py:186-228", + "comments": [ + { + "id": "mmtj5fpu16p", + "author": "Eric Alt", + "body": "`len(self.components)` should always equal `len(self.vocab_maps)` so `strict` should be `True` instead of `False`" + }, + { + "id": "mmtj8i7vy18", + "author": "Eric Alt", + "body": "```python\nglobal_dist += state.component_beliefs[i] * component_contrib\n```" + } + ] + }, + { + "id": 6, + "title": "Bayesian filtering in transition_states", + "body": "This is the core inference logic. For each observation:\n\n1. Map global token to each component's local space via inverse vocab maps (line 268)\n2. Compute likelihood `P(obs | component_i)` — 0 if token not in component's vocab (lines 271-275)\n3. Conditionally update each component's internal state only when likelihood > 0 (lines 278-284)\n4. Apply Bayes rule: `new_beliefs = beliefs * likelihoods / normalizer` (lines 288-294)\n5. 
Fall back to prior beliefs if all likelihoods are 0 (line 291-293)\n\nThe `jax.lax.cond` on line 278 avoids unnecessary state transitions for components that couldn't have generated the observation.", + "location": "simplexity/generative_processes/nonergodic_generative_process.py:253-299", + "comments": [ + { + "id": "mmtjcmycfty", + "author": "Eric Alt", + "body": "`strict=True`" + } + ] + }, + { + "id": 7, + "title": "Generation: sample one component, generate entire sequence", + "body": "Unlike inference (which tracks beliefs across all components), generation samples a single component at the start and generates entirely from it.\n\nThe implementation is notable for its complexity:\n- Cannot delegate to `component.generate()` because that method is also vmapped (line 363)\n- Uses flatten/pad to a common max size so `jax.lax.switch` can handle heterogeneous state types (lines 389-398)\n- `scan_step` (line 411) runs generation via `lax.switch` selecting the active component\n- Only the active component's state is updated per step (lines 424-431)\n- When `return_all_states=True`, a second inference scan reconstructs belief trajectories (lines 447-453)\n\nThis is the most intricate part of the PR — the flatten/pad/unflatten dance is the price paid for supporting mixed component types in a single JIT-compiled scan.", + "location": "simplexity/generative_processes/nonergodic_generative_process.py:347-459", + "comments": [ + { + "id": "mmtlfwdo7hn", + "author": "Eric Alt", + "body": "TODO: instead of having generative processes's `generate` function vmapped by default we should just have it function for a single sequence and define a separate `generate_batch` function that just wraps the generate function in a `vmap` (or just require the caller of generate to do that themselves) - outside the scope of this PR though" + } + ] + }, + { + "id": 8, + "title": "InflatedVocabularyProcess: controlled difficulty via noise", + "body": "A clean decorator pattern that wraps any 
`GenerativeProcess[State]` to inflate its vocabulary.\n\nToken encoding: `inflated_token = noise_prefix * V_base + base_token`\n- `emit_observation` samples a base token then adds a uniform random noise prefix (lines 68-71)\n- `transition_states` extracts the base token via modulo and discards noise (line 76)\n- Probability distributions are tiled K times and divided by K (line 83)\n- `probability` applies a `(1/K)^T` penalty for a sequence of length T (line 96)\n\nThis elegantly increases optimal per-token loss by exactly `log(K)` nats while preserving all state dynamics.", + "location": "simplexity/generative_processes/inflated_vocabulary_process.py:22-103" + }, + { + "id": 9, + "title": "Generator updates: slicing NonErgodicState belief trajectories", + "body": "The existing `generate_data_batch_with_full_history` function needed to handle `NonErgodicState` when slicing belief trajectories along the sequence dimension.\n\nA new `_slice_belief_states` helper (line 96) handles three state representations:\n- Plain arrays: slice directly\n- Tuples of arrays: slice each element\n- `NonErgodicState`: slice both `component_beliefs` and each entry in `component_states`, handling nested tuples for factored components (line 111)\n\nThis replaces the previous inline isinstance/tuple handling with a cleaner dispatch.", + "location": "simplexity/generative_processes/generator.py:96-118" + }, + { + "id": 10, + "title": "IndependentFactoredGenerativeProcess: noise_epsilon passthrough", + "body": "A small but important change: `noise_epsilon` is added as a constructor parameter (line 52) and forwarded to the parent `FactoredGenerativeProcess.__init__` (line 83).\n\nThis allows nonergodic processes to compose factored components that use noisy channels — previously the `IndependentFactoredGenerativeProcess` would silently ignore this parameter.", + "location": "simplexity/generative_processes/independent_factored_generative_process.py:43-84" + }, + { + "id": 11, + "title": 
"Builder: component factory and nonergodic process construction", + "body": "The private `_build_components_from_spec` helper (line 653) is the foundation for all nonergodic builders. It dispatches on `component_type` to build HMM, GHMM, or Factored processes.\n\n`build_nonergodic_process_from_spec` (line 700) is the main entry point, documented with a full YAML example showing how to compose HMM, GHMM, and factored components with explicit vocab maps.\n\nNote how vocab maps can be specified either per-component in the spec or globally as an override (lines 754-761).", + "location": "simplexity/generative_processes/builder.py:653-768" + }, + { + "id": 12, + "title": "Builder: vocabulary mapping strategies", + "body": "Three vocab map strategies for partially overlapping alphabets:\n\n1. **Prefix** (line 805): C0 gets `[0..V-1]`, subsequent components share a prefix of `n_shared` tokens plus unique tokens above V\n2. **Sliding** (line 820): Each component's vocab slides by `max(1, n_unique)` tokens — simple offset strategy\n3. **Random** (line 829): Each component independently samples V tokens from the global pool using a seeded RNG\n\n`build_nonergodic_partial_overlap` (line 842) orchestrates these, computing `n_shared = int(V * overlap_frac)` and `n_unique = V - n_shared` from the `overlap_frac` parameter. 
All components must have equal vocab size for this to work (validated on line 875-876).\n\n`build_nonergodic_disjoint_vocab` (line 771) is the simpler case: sequential non-overlapping ranges `[0..V0-1], [V0..V0+V1-1], ...`", + "location": "simplexity/generative_processes/builder.py:771-897", + "comments": [ + { + "id": "mmtncbw77fd", + "author": "Eric Alt", + "body": "```python\n def _build_prefix_vocab_maps(n_components: int, v: int, n_shared: int, n_unique: int) -> list[list[int]]:\n \"\"\"Build vocab maps using the prefix strategy.\"\"\"\n return [list(range(v))] + [\n list(range(n_shared)) + list(range(v + i * n_unique, v + (i + 1) * n_unique))\n for i in range(n_components - 1)\n ]\n```" + }, + { + "id": "mmtngqqdqwf", + "author": "Eric Alt", + "body": "`assert seed is not None` is redundant with earlier check" + }, + { + "id": "mmtnjikp3yy", + "author": "Eric Alt", + "body": "```python\n def _build_random_vocab_maps(n_components: int, v: int, n_unique: int, seed: int) -> list[list[int]]:\n \"\"\"Build vocab maps by having each component randomly sample V tokens from the global pool.\n\n The global vocab size is the same as in prefix mode:\n V + (n_components - 1) * n_unique.\n \"\"\"\n global_vocab_size = v + (n_components - 1) * n_unique\n rng = random.Random(seed)\n\n return [\n sorted(rng.sample(range(global_vocab_size), v))\n for _ in range(n_components)\n ]\n```" + } + ] + }, + { + "id": 13, + "title": "Builder: InflatedVocabularyProcess construction", + "body": "Two builder functions for the inflation wrapper:\n\n- `build_inflated_process` (line 900): Simple wrapper taking an existing `GenerativeProcess` and inflation factor\n- `build_inflated_process_from_spec` (line 916): Builds the base process from a spec dict first, then wraps it — supports HMM, GHMM, and factored base types\n\nBoth are thin wrappers that delegate to the `InflatedVocabularyProcess` constructor.", + "location": "simplexity/generative_processes/builder.py:900-959" + }, + { + "id": 
14, + "title": "build_factored_process now returns IndependentFactoredGenerativeProcess", + "body": "A structural change in the existing `build_factored_process` function: when `structure_type == \"independent\"`, it now returns an `IndependentFactoredGenerativeProcess` directly (line 200) with early return, rather than falling through to the generic `FactoredGenerativeProcess` constructor.\n\nThis ensures that nonergodic processes composing independent factored components get the specialized subclass with per-factor sampling and frozen factor support, plus the new `noise_epsilon` passthrough.", + "location": "simplexity/generative_processes/builder.py:198-207" + }, + { + "id": 15, + "title": "Summary and recommendations", + "body": "**Strengths:**\n- Clean abstractions: both new classes implement the full `GenerativeProcess` protocol\n- The flatten/pad/unflatten approach for heterogeneous states in `jax.lax.switch` is well-documented and correct\n- Comprehensive builder functions with three vocab mapping strategies cover real research use cases\n- Excellent test coverage (21+ tests for NonErgodic, tests for Inflated, builder tests)\n- The `InflatedVocabularyProcess` is a particularly elegant design — stateless noise with provable loss increase\n\n**Architecture notes:**\n- The `generate` method in `NonErgodicGenerativeProcess` is the most complex piece — the re-implementation of the generate loop (rather than delegating to components) is necessary due to vmap constraints but adds maintenance burden\n- The `_slice_belief_states` helper in `generator.py` adds a third branch for `NonErgodicState` — if more state types emerge, this could benefit from a protocol-based dispatch\n\n**Recommendation: Approve** — Well-structured addition with solid test coverage and clear documentation of the tricky JAX compatibility patterns." + } + ] +} \ No newline at end of file