diff --git a/docs/plans/refactor-gen-process-returns.md b/docs/plans/refactor-gen-process-returns.md new file mode 100644 index 00000000..b2872e17 --- /dev/null +++ b/docs/plans/refactor-gen-process-returns.md @@ -0,0 +1,220 @@ +# Refactor: Generative Process Returns from Tuples to TypedDicts + +## Context + +Generative processes return tuples from `generate()` and `generate_data_batch()`, making call sites fragile and unclear. This refactor: +1. Moves all public return types to **TypedDicts** for type safety, extensibility, and consistency +2. Consolidates `DataBatch` / `DataBatchWithHistory` into a **single type** with all fields always present +3. Uses **PEP 695 generic TypedDict** for `GenerateResult` to preserve the `State` type parameter +4. Keeps JAX and Torch TypedDicts **separate** for clean typing + +--- + +## TypedDict Definitions + +### `generative_process.py` — `GenerateResult[State]` (generic) +```python +class GenerateResult[State](TypedDict): + states: State # final post-transition state (always populated) + observations: chex.Array # emitted tokens (always populated) + all_states: State # pre-transition history; pytree-preserving empty (batch, 0) per leaf when not requested +``` + +### `generator.py` — `DataBatch` (non-generic, State erased at this level) +```python +class DataBatch(TypedDict): + gen_states: jax.Array | tuple[jax.Array, ...] # final post-transition state + inputs: jax.Array # input tokens + labels: jax.Array # label tokens + belief_states: jax.Array | tuple[jax.Array, ...] # pre-transition history; pytree-preserving empty (batch, 0) per leaf when N/A + prefix_probabilities: jax.Array # prefix probs; (batch, 0) when N/A +``` + +### `torch_generator.py` — `TorchDataBatch` +```python +class TorchDataBatch(TypedDict): + gen_states: jax.Array | tuple[jax.Array, ...] # final state (stays JAX) + inputs: torch.Tensor # input tokens (torch) + labels: torch.Tensor # label tokens (torch) + belief_states: jax.Array | tuple[jax.Array, ...] 
# state history (stays JAX); pytree-preserving empty (batch, 0) per leaf when N/A + prefix_probabilities: jax.Array # prefix probs (stays JAX); (batch, 0) when N/A +``` + +### Empty Field Contract + +All empty sentinels preserve the **batch dimension** and State pytree structure. Inside `generate()` (which is vmap'd), we use `jnp.empty(0, dtype=leaf.dtype)` — vmap adds the batch dim, yielding `(batch_size, 0)` per leaf. + +| Field | When empty | Internal value (pre-vmap) | Caller sees (post-vmap) | +|-------|-----------|--------------------------|------------------------| +| `all_states` in GenerateResult | `return_all_states=False` | `jax.tree.map(lambda leaf: jnp.empty(0, dtype=leaf.dtype), state)` | Same pytree, each leaf `(batch_size, 0)` preserving original dtype | +| `belief_states` in DataBatch | basic `generate_data_batch()` | Reuses `result["all_states"]` (already batched) | Same pytree, each leaf `(batch_size, 0)` preserving original dtype | +| `prefix_probabilities` in DataBatch | basic `generate_data_batch()` | `jnp.empty((batch_size, 0), dtype=jnp.float32)` (explicit, not vmap'd) | `(batch_size, 0)` float32 | + +--- + +## File-by-file Changes + +### 1. 
`simplexity/generative_processes/generative_process.py` +- Add `GenerateResult[State]` TypedDict using PEP 695 syntax +- Change `generate()` return type to `GenerateResult[State]` +- `return_all_states=True` — capture the carry (currently discarded as `_`): + ```python + final_state, (all_states, obs) = jax.lax.scan(gen_states_and_obs, state, keys) + return GenerateResult(states=final_state, observations=obs, all_states=all_states) + ``` +- `return_all_states=False` — create structure-preserving empty sentinel with original dtype: + ```python + final_state, obs = jax.lax.scan(gen_obs, state, keys) + empty_states = jax.tree.map(lambda leaf: jnp.empty(0, dtype=leaf.dtype), state) + return GenerateResult(states=final_state, observations=obs, all_states=empty_states) + ``` +- Internal scan helpers keep tuple returns (scan requires it) + +### 2. `simplexity/generative_processes/independent_factored_generative_process.py` +- Import `GenerateResult` +- Same pattern as base class — `jax.tree.map` naturally handles tuple state: + - `False`: `jax.tree.map(lambda leaf: jnp.empty(0, dtype=leaf.dtype), state)` → a tuple with one empty array per factor, i.e. `(jnp.empty(0, dtype=leaf.dtype), ...)` + - `True`: capture carry, populate all fields + +### 3. `simplexity/generative_processes/generator.py` +- Add `DataBatch` TypedDict +- `generate_data_batch()` → returns `DataBatch`: + - Consume `generate()` via `result["states"]`, `result["observations"]` + - Reuse `result["all_states"]` as `belief_states` (already batched empty from vmap) + - Set `prefix_probabilities=jnp.empty((batch_size, 0), dtype=jnp.float32)` (explicit batch dim) +- `generate_data_batch_with_full_history()` → returns `DataBatch`: + - Gets `gen_states` from `result["states"]` (final carry) + - Gets `belief_states` from `result["all_states"]` + - Populates all 5 fields + +### 4. 
`simplexity/generative_processes/torch_generator.py` +- Add `TorchDataBatch` TypedDict +- `generate_data_batch()` → returns `TorchDataBatch`: + - Consume JAX `DataBatch` via dict access + - Convert inputs/labels to torch, pass through JAX state/probability fields +- `generate_data_batch_with_full_history()` → returns `TorchDataBatch`: + - Consume JAX `DataBatch`, convert inputs/labels to torch + +### 5. Test updates — call site migration + +**`tests/generative_processes/test_hidden_markov_model.py`** +- `test_single_transition` (~lines 135-162): `result = z1r.generate(...)` then `result["states"]`, `result["observations"]` +- `test_generate` (~lines 165-179): same dict access + +**`tests/generative_processes/test_generalized_hidden_markov_model.py`** +- `test_hmm_single_transition` (~lines 148-175): dict access +- `test_generate` (~lines 179-194): dict access +- `test_generate_with_intermediate_states` (~lines 198-213): use `result["all_states"]` + +**`tests/generative_processes/test_generator.py`** +- Lines 31, 52, 81: `result = generate_data_batch(...)` then `result["gen_states"]`, `result["inputs"]`, `result["labels"]` +- Full history tests: update key names if changed + +**`tests/generative_processes/test_torch_generator.py`** +- Same pattern as test_generator + +**`tests/generative_processes/test_independent_factored_generative_process.py`** +- `process.generate()` → `result["states"]`, `result["observations"]` +- `return_all_states=True` → `result["all_states"]` + +### 6. `tests/end_to_end/training.py` +- Inner `generate()` function (~line 131): consume `generate_data_batch()` via dict access, still returns `(inputs, labels)` tuple internally + ```python + def generate(step: int) -> tuple[torch.Tensor, torch.Tensor]: + result = generate_data_batch(...) 
+ return result["inputs"], result["labels"] + ``` +- Line 238: `generate(0)[0]` still works (inner function returns tuple) +- `activation_tracker_step` (~line 189): `generate_data_batch_with_full_history()` now returns `TorchDataBatch` — access via same keys (no change needed) + +### 7. New semantic tests (add to existing test files) + +**In test_generalized_hidden_markov_model.py (or test_hidden_markov_model.py):** +- Verify `result["states"]` is the final post-transition carry when `return_all_states=True`: + ```python + result = model.generate(initial_states, keys, seq_len, True) + expected_final = eqx.filter_vmap(model.transition_states)( + result["all_states"][:, -1, :], result["observations"][:, -1] + ) + chex.assert_trees_all_close(result["states"], expected_final) + ``` +- Verify empty sentinel preserves batch dim (array-state case; tuple-state covered in factored tests below): + ```python + result = model.generate(initial_states, keys, seq_len, False) + assert result["all_states"].shape == (batch_size, 0) + ``` + +**In test_generator.py:** +- Verify `gen_states` exists and has correct shape from both `generate_data_batch()` and `generate_data_batch_with_full_history()` +- Verify empty fields from basic function preserve batch dim: + ```python + result = generate_data_batch(states, hmm, batch_size, seq_len, key) + assert result["belief_states"].shape == (batch_size, 0) + assert result["prefix_probabilities"].shape == (batch_size, 0) + ``` +- Verify `gen_states` from full history has correct shape + +**In test_independent_factored_generative_process.py:** +- Verify empty `all_states` is a tuple of empties (preserves FactoredState structure): + ```python + result = process.generate(batch_states, keys, seq_len, False) + assert isinstance(result["all_states"], tuple) + assert all(s.shape == (batch_size, 0) for s in result["all_states"]) + ``` + +--- + +## Migration Guide (Breaking Change) + +This is a breaking change for external callers. 
Summary of changes: + +### `GenerativeProcess.generate()` — tuple → `GenerateResult` dict +```python +# Before: +states, observations = process.generate(state, key, seq_len, False) +all_states, observations = process.generate(state, key, seq_len, True) + +# After: +result = process.generate(state, key, seq_len, False) +states = result["states"] +observations = result["observations"] + +result = process.generate(state, key, seq_len, True) +states = result["states"] # final post-transition state +observations = result["observations"] +all_states = result["all_states"] # pre-transition state history (NEW) +``` + +### `generate_data_batch()` — tuple → `DataBatch` / `TorchDataBatch` dict +```python +# Before: +gen_states, inputs, labels = generate_data_batch(...) + +# After: +result = generate_data_batch(...) +gen_states = result["gen_states"] +inputs = result["inputs"] +labels = result["labels"] +# Also available (empty when not from full_history): +# result["belief_states"], result["prefix_probabilities"] +``` + +### `generate_data_batch_with_full_history()` — now returns same `DataBatch` type +```python +# Before: +result = generate_data_batch_with_full_history(...) +belief_states = result["belief_states"] # same key +inputs = result["inputs"] # same key + +# After: same keys, plus gen_states is now also available +gen_states = result["gen_states"] # NEW: final state +``` + +--- + +## Verification + +1. `uv run --extra dev ruff check` — linting +2. `uv run --extra dev ruff format --check` — formatting +3. `uv run --extra dev --extra pytorch pyright` — type checking (generic TypedDict + all dict accesses verified) +4. 
`uv run --extra dev --extra pytorch pytest` — all tests pass (including new semantic tests) diff --git a/simplexity/generative_processes/generative_process.py b/simplexity/generative_processes/generative_process.py index e312591e..6fa8dc16 100644 --- a/simplexity/generative_processes/generative_process.py +++ b/simplexity/generative_processes/generative_process.py @@ -1,15 +1,24 @@ """Generative process interface.""" from abc import abstractmethod -from typing import TypeVar +from typing import TypedDict, TypeVar import chex import equinox as eqx import jax +import jax.numpy as jnp State = TypeVar("State") +class GenerateResult[State](TypedDict): + """Return payload for batched generation.""" + + states: State + observations: chex.Array + all_states: State + + class GenerativeProcess[State](eqx.Module): """A generative process is a probabilistic model that can be used to generate data.""" @@ -39,19 +48,24 @@ def transition_states(self, state: State, obs: chex.Array) -> State: @eqx.filter_vmap(in_axes=(None, 0, 0, None, None)) def generate( self, state: State, key: chex.PRNGKey, sequence_len: int, return_all_states: bool - ) -> tuple[State, chex.Array]: + ) -> GenerateResult[State]: """Generate a batch of sequences of observations from the generative process. 
Inputs: state: (batch_size, num_states) key: (batch_size, 2) - Returns: tuple of (belief_states, observations) where: + Returns: dict with: + states: final state after sequence generation + observations: emitted tokens + all_states: pre-transition state history if return_all_states else + a structure-preserving empty sentinel + if return_all_states is True: - belief_states is the sequence of belief states of shape: + all_states is the sequence of pre-transition states of shape: (batch_size, sequence_len, num_states) otherwise: - belief_states is the state of the final step: - (batch_size, num_states) + all_states is an empty array per state leaf of shape: + (batch_size, 0) observations is (batch_size, sequence_len) """ @@ -68,10 +82,12 @@ def gen_states_and_obs(state: State, key: chex.PRNGKey) -> tuple[State, tuple[St return new_state, (state, obs) if return_all_states: - _, (states, obs) = jax.lax.scan(gen_states_and_obs, state, keys) - return states, obs + final_state, (states, obs) = jax.lax.scan(gen_states_and_obs, state, keys) + return GenerateResult(states=final_state, observations=obs, all_states=states) - return jax.lax.scan(gen_obs, state, keys) + final_state, obs = jax.lax.scan(gen_obs, state, keys) + empty_states = jax.tree.map(lambda leaf: jnp.empty(0, dtype=leaf.dtype), state) + return GenerateResult(states=final_state, observations=obs, all_states=empty_states) @abstractmethod def observation_probability_distribution(self, state: State) -> jax.Array: diff --git a/simplexity/generative_processes/generator.py b/simplexity/generative_processes/generator.py index 15c0ddf1..b08c4be8 100644 --- a/simplexity/generative_processes/generator.py +++ b/simplexity/generative_processes/generator.py @@ -9,7 +9,7 @@ # (code quality, style, undefined names, etc.) to run normally while bypassing # the problematic imports checker that would crash during AST traversal. 
-from typing import Any +from typing import Any, TypedDict import equinox as eqx import jax @@ -18,6 +18,16 @@ from simplexity.generative_processes.generative_process import GenerativeProcess +class DataBatch(TypedDict): + """Unified generator payload for basic and full-history generation.""" + + gen_states: jax.Array | tuple[jax.Array, ...] + inputs: jax.Array + labels: jax.Array + belief_states: jax.Array | tuple[jax.Array, ...] + prefix_probabilities: jax.Array + + @eqx.filter_jit def generate_data_batch( gen_states: jax.Array | tuple[jax.Array, ...], @@ -27,10 +37,13 @@ def generate_data_batch( key: jax.Array, bos_token: int | None = None, eos_token: int | None = None, -) -> tuple[jax.Array | tuple[jax.Array, ...], jax.Array, jax.Array]: +) -> DataBatch: """Generate a batch of data without tracking intermediate beliefs.""" batch_keys = jax.random.split(key, batch_size) - gen_states, tokens = data_generator.generate(gen_states, batch_keys, sequence_len, False) + generate_result = data_generator.generate(gen_states, batch_keys, sequence_len, False) + tokens = generate_result["observations"] + final_states = generate_result["states"] + belief_states = generate_result["all_states"] if bos_token is not None: tokens = jnp.concatenate([jnp.full((batch_size, 1), bos_token), tokens], axis=1) @@ -39,7 +52,13 @@ def generate_data_batch( inputs = tokens[:, :-1] labels = tokens[:, 1:] - return gen_states, inputs, labels + return DataBatch( + gen_states=final_states, + inputs=inputs, + labels=labels, + belief_states=belief_states, + prefix_probabilities=jnp.empty((batch_size, 0), dtype=jnp.float32), + ) @eqx.filter_jit @@ -51,10 +70,13 @@ def generate_data_batch_with_full_history( key: jax.Array, bos_token: int | None = None, eos_token: int | None = None, -) -> dict[str, jax.Array | tuple[jax.Array, ...]]: +) -> DataBatch: """Generate sequences plus per-token belief states and prefix probabilities.""" batch_keys = jax.random.split(key, batch_size) - belief_states, tokens = 
data_generator.generate(gen_states, batch_keys, sequence_len, True) + generate_result = data_generator.generate(gen_states, batch_keys, sequence_len, True) + belief_states = generate_result["all_states"] + tokens = generate_result["observations"] + final_states = generate_result["states"] prefix_probs = _compute_prefix_probabilities(data_generator, gen_states, tokens) @@ -88,14 +110,13 @@ def generate_data_batch_with_full_history( else: belief_states = belief_states[:, :input_len, ...] - result = { - "belief_states": belief_states, - "prefix_probabilities": prefix_probs, - "inputs": inputs, - "labels": labels, - } - - return result + return DataBatch( + gen_states=final_states, + belief_states=belief_states, + prefix_probabilities=prefix_probs, + inputs=inputs, + labels=labels, + ) def _compute_prefix_probabilities( diff --git a/simplexity/generative_processes/independent_factored_generative_process.py b/simplexity/generative_processes/independent_factored_generative_process.py index acb83710..dd59a14c 100644 --- a/simplexity/generative_processes/independent_factored_generative_process.py +++ b/simplexity/generative_processes/independent_factored_generative_process.py @@ -14,6 +14,7 @@ FactoredGenerativeProcess, FactoredState, ) +from simplexity.generative_processes.generative_process import GenerateResult from simplexity.generative_processes.structures import ConditionalStructure from simplexity.generative_processes.structures.independent import IndependentStructure from simplexity.logger import SIMPLEXITY_LOGGER @@ -147,7 +148,7 @@ def emit_observation(self, state: FactoredState, key: jax.Array) -> jax.Array: @eqx.filter_vmap(in_axes=(None, 0, 0, None, None)) def generate( self, state: FactoredState, key: chex.PRNGKey, sequence_len: int, return_all_states: bool - ) -> tuple[FactoredState, chex.Array]: + ) -> GenerateResult[FactoredState]: """Generate sequences with frozen factor support. 
For frozen factors, the same key stream is used across all batch samples, @@ -161,7 +162,7 @@ def generate( return_all_states: Whether to return all intermediate states Returns: - Tuple of (final_states or all_states, observations) + Dict with final states, observations, and optional pre-transition state history """ keys = jax.random.split(key, sequence_len) frozen_keys = jax.random.split(self.frozen_key, sequence_len) if self.frozen_key is not None else keys @@ -183,7 +184,9 @@ def gen_states_and_obs( return new_state, (carry_state, obs) if return_all_states: - _, (states, obs) = jax.lax.scan(gen_states_and_obs, state, (keys, frozen_keys)) - return states, obs + final_state, (states, obs) = jax.lax.scan(gen_states_and_obs, state, (keys, frozen_keys)) + return GenerateResult(states=final_state, observations=obs, all_states=states) - return jax.lax.scan(gen_obs, state, (keys, frozen_keys)) + final_state, obs = jax.lax.scan(gen_obs, state, (keys, frozen_keys)) + empty_states = jax.tree.map(lambda leaf: jnp.empty(0, dtype=leaf.dtype), state) + return GenerateResult(states=final_state, observations=obs, all_states=empty_states) diff --git a/simplexity/generative_processes/torch_generator.py b/simplexity/generative_processes/torch_generator.py index 00f4211a..d6b8fb70 100644 --- a/simplexity/generative_processes/torch_generator.py +++ b/simplexity/generative_processes/torch_generator.py @@ -9,19 +9,30 @@ # (code quality, style, undefined names, etc.) to run normally while bypassing # the problematic imports checker that would crash during AST traversal. 
+from typing import TypedDict + import jax import torch from simplexity.generative_processes.generative_process import GenerativeProcess from simplexity.generative_processes.generator import ( + DataBatch, generate_data_batch as generate_jax_data_batch, -) -from simplexity.generative_processes.generator import ( generate_data_batch_with_full_history as generate_jax_data_batch_with_full_history, ) from simplexity.utils.pytorch_utils import jax_to_torch +class TorchDataBatch(TypedDict): + """Torch payload with tensor tokens and JAX states.""" + + gen_states: jax.Array | tuple[jax.Array, ...] + inputs: torch.Tensor + labels: torch.Tensor + belief_states: jax.Array | tuple[jax.Array, ...] + prefix_probabilities: jax.Array + + def generate_data_batch( gen_states: jax.Array | tuple[jax.Array, ...], data_generator: GenerativeProcess, @@ -31,7 +42,7 @@ def generate_data_batch( bos_token: int | None = None, eos_token: int | None = None, device: str | torch.device | None = None, -) -> tuple[jax.Array | tuple[jax.Array, ...], torch.Tensor, torch.Tensor]: +) -> TorchDataBatch: """Generate a batch of data. 
Args: @@ -45,9 +56,9 @@ def generate_data_batch( device: Optional target device for PyTorch tensors Returns: - Tuple of (generator states, inputs, labels) + Dict containing generator states, belief/prefix fields, and torch inputs/labels """ - gen_states, inputs, labels = generate_jax_data_batch( + result = generate_jax_data_batch( gen_states, data_generator, batch_size, @@ -56,7 +67,17 @@ def generate_data_batch( bos_token, eos_token, ) - return gen_states, jax_to_torch(inputs, device), jax_to_torch(labels, device) + inputs = result["inputs"] + labels = result["labels"] + assert isinstance(inputs, jax.Array) + assert isinstance(labels, jax.Array) + return TorchDataBatch( + gen_states=result["gen_states"], + belief_states=result["belief_states"], + prefix_probabilities=result["prefix_probabilities"], + inputs=jax_to_torch(inputs, device), + labels=jax_to_torch(labels, device), + ) def generate_data_batch_with_full_history( @@ -68,7 +89,7 @@ def generate_data_batch_with_full_history( bos_token: int | None = None, eos_token: int | None = None, device: str | torch.device | None = None, -) -> dict[str, jax.Array | torch.Tensor | tuple[jax.Array, ...]]: +) -> TorchDataBatch: """Generate data plus full belief/prefix histories. 
Args: @@ -82,13 +103,14 @@ def generate_data_batch_with_full_history( device: Optional target device for PyTorch tensors Returns: - Dict with keys: + TorchDataBatch with keys: + - gen_states: Final generator state (jax.Array or tuple[jax.Array, ...]) - belief_states: Belief states (jax.Array or tuple[jax.Array, ...]) - prefix_probabilities: Prefix probabilities (jax.Array) - inputs: Input tokens (torch.Tensor) - labels: Label tokens (torch.Tensor) """ - result = generate_jax_data_batch_with_full_history( + result: DataBatch = generate_jax_data_batch_with_full_history( gen_states, data_generator, batch_size, @@ -103,9 +125,10 @@ def generate_data_batch_with_full_history( assert isinstance(inputs, jax.Array) assert isinstance(labels, jax.Array) - return { - "belief_states": result["belief_states"], - "prefix_probabilities": result["prefix_probabilities"], - "inputs": jax_to_torch(inputs, device), - "labels": jax_to_torch(labels, device), - } + return TorchDataBatch( + gen_states=result["gen_states"], + belief_states=result["belief_states"], + prefix_probabilities=result["prefix_probabilities"], + inputs=jax_to_torch(inputs, device), + labels=jax_to_torch(labels, device), + ) diff --git a/tests/end_to_end/training.py b/tests/end_to_end/training.py index 20db7253..eac36407 100644 --- a/tests/end_to_end/training.py +++ b/tests/end_to_end/training.py @@ -130,7 +130,7 @@ def train(cfg: TrainingRunConfig, components: simplexity.Components) -> None: def generate(step: int) -> tuple[torch.Tensor, torch.Tensor]: key = jax.random.key(step) - _, inputs, labels = generate_data_batch( + result = generate_data_batch( gen_states, generative_process, cfg.training.batch_size, @@ -139,7 +139,7 @@ def generate(step: int) -> tuple[torch.Tensor, torch.Tensor]: device=device_arg, bos_token=cfg.generative_process.bos_token, ) - return inputs, labels + return result["inputs"], result["labels"] loss_fn = torch.nn.CrossEntropyLoss() diff --git 
a/tests/generative_processes/test_generalized_hidden_markov_model.py b/tests/generative_processes/test_generalized_hidden_markov_model.py index be7385f1..6f825078 100644 --- a/tests/generative_processes/test_generalized_hidden_markov_model.py +++ b/tests/generative_processes/test_generalized_hidden_markov_model.py @@ -145,20 +145,28 @@ def test_hmm_single_transition(z1r: GeneralizedHiddenMarkovModel): key = jax.random.PRNGKey(0)[None, :] single_transition = 1 - next_state, observation = z1r.generate(zero_state, key, single_transition, False) + result = z1r.generate(zero_state, key, single_transition, False) + next_state = result["states"] + observation = result["observations"] assert_proportional(probability(next_state), one_state) assert observation == jnp.array(0) - next_state, observation = z1r.generate(one_state, key, single_transition, False) + result = z1r.generate(one_state, key, single_transition, False) + next_state = result["states"] + observation = result["observations"] assert_proportional(probability(next_state), random_state) assert observation == jnp.array(1) - next_state, observation = z1r.generate(random_state, key, single_transition, False) + result = z1r.generate(random_state, key, single_transition, False) + next_state = result["states"] + observation = result["observations"] assert_proportional(probability(next_state), zero_state) mixed_state = jnp.array([[0.4, 0.4, 0.2]]) - next_state, observation = z1r.generate(mixed_state, key, single_transition, False) + result = z1r.generate(mixed_state, key, single_transition, False) + next_state = result["states"] + observation = result["observations"] # P(next=0 | obs=x) = P(prev=2 | obs=x) # P(next=1 | obs=x) = P(prev=0 | obs=x) # P(next=2 | obs=x) = P(prev=1 | obs=x) @@ -184,12 +192,17 @@ def test_generate(model_name: str, request: pytest.FixtureRequest): initial_states = jnp.repeat(model.initial_state[None, :], batch_size, axis=0) keys = jax.random.split(jax.random.PRNGKey(0), batch_size) - 
intermediate_states, intermediate_observations = model.generate(initial_states, keys, sequence_len, False) + result = model.generate(initial_states, keys, sequence_len, False) + intermediate_states = result["states"] + intermediate_observations = result["observations"] assert intermediate_states.shape == (batch_size, model.num_states) assert intermediate_observations.shape == (batch_size, sequence_len) + assert result["all_states"].shape == (batch_size, 0) keys = jax.random.split(jax.random.PRNGKey(1), batch_size) - final_states, final_observations = model.generate(intermediate_states, keys, sequence_len, False) + result = model.generate(intermediate_states, keys, sequence_len, False) + final_states = result["states"] + final_observations = result["observations"] assert final_states.shape == (batch_size, model.num_states) assert final_observations.shape == (batch_size, sequence_len) @@ -203,14 +216,24 @@ def test_generate_with_intermediate_states(model_name: str, request: pytest.Fixt initial_states = jnp.repeat(model.initial_state[None, :], batch_size, axis=0) keys = jax.random.split(jax.random.PRNGKey(0), batch_size) - intermediate_states, observations = model.generate(initial_states, keys, sequence_len, True) + result = model.generate(initial_states, keys, sequence_len, True) + intermediate_states = result["all_states"] + observations = result["observations"] + final_states = result["states"] assert intermediate_states.shape == (batch_size, sequence_len, model.num_states) assert observations.shape == (batch_size, sequence_len) + assert final_states.shape == (batch_size, model.num_states) + expected_final_states = eqx.filter_vmap(model.transition_states)(intermediate_states[:, -1, :], observations[:, -1]) + chex.assert_trees_all_close(final_states, expected_final_states) last_intermediate_states = intermediate_states[:, -1, :] - final_states, observations = model.generate(last_intermediate_states, keys, sequence_len, True) - assert final_states.shape == 
(batch_size, sequence_len, model.num_states) - assert observations.shape == (batch_size, sequence_len) + result = model.generate(last_intermediate_states, keys, sequence_len, True) + next_intermediate_states = result["all_states"] + next_observations = result["observations"] + next_final_states = result["states"] + assert next_intermediate_states.shape == (batch_size, sequence_len, model.num_states) + assert next_observations.shape == (batch_size, sequence_len) + assert next_final_states.shape == (batch_size, model.num_states) def test_hmm_observation_probability_distribution(z1r: GeneralizedHiddenMarkovModel): diff --git a/tests/generative_processes/test_generator.py b/tests/generative_processes/test_generator.py index 2da516bf..73a9bcb8 100644 --- a/tests/generative_processes/test_generator.py +++ b/tests/generative_processes/test_generator.py @@ -28,7 +28,12 @@ def test_generate_data_batch(): gen_state: jax.Array = hmm.initial_state states = jnp.repeat(gen_state[None, :], batch_size, axis=0) key = jax.random.PRNGKey(0) - gen_states, inputs, labels = generate_data_batch(states, hmm, batch_size, sequence_len, key) + result = generate_data_batch(states, hmm, batch_size, sequence_len, key) + gen_states = result["gen_states"] + inputs = result["inputs"] + labels = result["labels"] + belief_states = result["belief_states"] + prefix_probabilities = result["prefix_probabilities"] assert inputs.shape == (batch_size, sequence_len - 1) assert labels.shape == (batch_size, sequence_len - 1) assert jnp.all(inputs >= 0) @@ -37,7 +42,11 @@ def test_generate_data_batch(): assert jnp.all(labels < hmm.vocab_size) chex.assert_trees_all_equal(inputs[:, 1:], labels[:, :-1]) assert isinstance(gen_states, jax.Array) + assert isinstance(belief_states, jax.Array) + assert isinstance(prefix_probabilities, jax.Array) assert gen_states.shape == (batch_size, *gen_state.shape) + assert belief_states.shape == (batch_size, 0) + assert prefix_probabilities.shape == (batch_size, 0) def 
test_generate_data_batch_with_bos_token(): @@ -49,7 +58,7 @@ def test_generate_data_batch_with_bos_token(): states = jnp.repeat(gen_state[None, :], batch_size, axis=0) key = jax.random.PRNGKey(0) bos_token = hmm.vocab_size - gen_states, inputs, labels = generate_data_batch( + result = generate_data_batch( states, hmm, batch_size, @@ -57,6 +66,9 @@ def test_generate_data_batch_with_bos_token(): key, bos_token=bos_token, ) + gen_states = result["gen_states"] + inputs = result["inputs"] + labels = result["labels"] assert inputs.shape == (batch_size, sequence_len) assert labels.shape == (batch_size, sequence_len) assert jnp.all(inputs >= 0) @@ -78,7 +90,7 @@ def test_generate_data_batch_with_eos_token(): states = jnp.repeat(gen_state[None, :], batch_size, axis=0) key = jax.random.PRNGKey(0) eos_token = hmm.vocab_size - gen_states, inputs, labels = generate_data_batch( + result = generate_data_batch( states, hmm, batch_size, @@ -86,6 +98,9 @@ def test_generate_data_batch_with_eos_token(): key, eos_token=eos_token, ) + gen_states = result["gen_states"] + inputs = result["inputs"] + labels = result["labels"] assert inputs.shape == (batch_size, sequence_len) assert labels.shape == (batch_size, sequence_len) assert jnp.all(inputs >= 0) @@ -114,11 +129,13 @@ def test_generate_data_batch_with_full_history(): key, ) # Extract and type-check all fields + gen_states = result["gen_states"] belief_states = result["belief_states"] prefix_probs = result["prefix_probabilities"] inputs = result["inputs"] labels = result["labels"] + assert isinstance(gen_states, jax.Array) assert isinstance(belief_states, jax.Array) assert isinstance(prefix_probs, jax.Array) assert isinstance(inputs, jax.Array) @@ -128,6 +145,7 @@ def test_generate_data_batch_with_full_history(): assert belief_states.shape == (batch_size, sequence_len - 1, gen_state.shape[0]) assert prefix_probs.shape == (batch_size, inputs.shape[1]) assert labels.shape == inputs.shape + assert gen_states.shape == (batch_size, 
*gen_state.shape) def test_generate_data_batch_with_full_history_bos(): @@ -147,11 +165,13 @@ def test_generate_data_batch_with_full_history_bos(): key, bos_token=bos_token, ) + gen_states = result["gen_states"] belief_states = result["belief_states"] prefix_probs = result["prefix_probabilities"] inputs = result["inputs"] labels = result["labels"] + assert isinstance(gen_states, jax.Array) assert isinstance(belief_states, jax.Array) assert isinstance(prefix_probs, jax.Array) assert isinstance(inputs, jax.Array) @@ -163,5 +183,6 @@ def test_generate_data_batch_with_full_history_bos(): assert belief_states.shape == (batch_size, sequence_len, gen_state.shape[0]) assert prefix_probs.shape == (batch_size, inputs.shape[1]) assert labels.shape == inputs.shape + assert gen_states.shape == (batch_size, *gen_state.shape) # First input should be BOS token assert jnp.all(inputs[:, 0] == bos_token) diff --git a/tests/generative_processes/test_hidden_markov_model.py b/tests/generative_processes/test_hidden_markov_model.py index c36e7f09..9ab61c58 100644 --- a/tests/generative_processes/test_hidden_markov_model.py +++ b/tests/generative_processes/test_hidden_markov_model.py @@ -132,20 +132,28 @@ def test_single_transition(z1r: HiddenMarkovModel): key = jax.random.PRNGKey(0)[None, :] single_transition = 1 - next_state, observation = z1r.generate(zero_state, key, single_transition, False) + result = z1r.generate(zero_state, key, single_transition, False) + next_state = result["states"] + observation = result["observations"] assert_proportional(probability(next_state), one_state) assert observation == jnp.array(0) - next_state, observation = z1r.generate(one_state, key, single_transition, False) + result = z1r.generate(one_state, key, single_transition, False) + next_state = result["states"] + observation = result["observations"] assert_proportional(probability(next_state), random_state) assert observation == jnp.array(1) - next_state, observation = z1r.generate(random_state, key, 
single_transition, False) + result = z1r.generate(random_state, key, single_transition, False) + next_state = result["states"] + observation = result["observations"] assert_proportional(probability(next_state), zero_state) mixed_state = jnp.array([[0.4, 0.4, 0.2]]) - next_state, observation = z1r.generate(mixed_state, key, single_transition, False) + result = z1r.generate(mixed_state, key, single_transition, False) + next_state = result["states"] + observation = result["observations"] # P(next=0 | obs=x) = P(prev=2 | obs=x) # P(next=1 | obs=x) = P(prev=0 | obs=x) # P(next=2 | obs=x) = P(prev=1 | obs=x) @@ -169,16 +177,43 @@ def test_generate(z1r: HiddenMarkovModel): initial_states = jnp.repeat(z1r.normalizing_eigenvector[None, :], batch_size, axis=0) keys = jax.random.split(jax.random.PRNGKey(0), batch_size) - intermediate_states, intermediate_observations = z1r.generate(initial_states, keys, sequence_len, False) + result = z1r.generate(initial_states, keys, sequence_len, False) + intermediate_states = result["states"] + intermediate_observations = result["observations"] assert intermediate_states.shape == (batch_size, z1r.num_states) assert intermediate_observations.shape == (batch_size, sequence_len) + assert result["all_states"].shape == (batch_size, 0) keys = jax.random.split(jax.random.PRNGKey(1), batch_size) - final_states, final_observations = z1r.generate(intermediate_states, keys, sequence_len, False) + result = z1r.generate(intermediate_states, keys, sequence_len, False) + final_states = result["states"] + final_observations = result["observations"] assert final_states.shape == (batch_size, z1r.num_states) assert final_observations.shape == (batch_size, sequence_len) +def test_generate_with_all_states_returns_final_carry(z1r: HiddenMarkovModel): + """`states` should equal transitioning from the last intermediate state and last observation.""" + batch_size = 4 + sequence_len = 6 + initial_states = jnp.repeat(z1r.normalizing_eigenvector[None, :], 
batch_size, axis=0) + keys = jax.random.split(jax.random.PRNGKey(7), batch_size) + + result = z1r.generate(initial_states, keys, sequence_len, True) + all_states = result["all_states"] + observations = result["observations"] + final_states = result["states"] + assert all_states.shape == (batch_size, sequence_len, z1r.num_states) + assert observations.shape == (batch_size, sequence_len) + assert final_states.shape == (batch_size, z1r.num_states) + + expected_final_states = eqx.filter_vmap(z1r.transition_states)( + all_states[:, -1, :], + observations[:, -1], + ) + chex.assert_trees_all_close(final_states, expected_final_states) + + def test_observation_probability_distribution(z1r: HiddenMarkovModel): """Test probability-space observation distribution.""" state = jnp.array([0.3, 0.1, 0.6]) diff --git a/tests/generative_processes/test_independent_factored_generative_process.py b/tests/generative_processes/test_independent_factored_generative_process.py index bc1f4e41..5c8916fc 100644 --- a/tests/generative_processes/test_independent_factored_generative_process.py +++ b/tests/generative_processes/test_independent_factored_generative_process.py @@ -114,11 +114,16 @@ def test_generate_produces_correct_shapes(self, two_factor_independent_process): batch_states = tuple(jnp.tile(s[None, :], (batch_size, 1)) for s in process.initial_state) keys = jax.random.split(jax.random.PRNGKey(0), batch_size) - final_states, observations = process.generate(batch_states, keys, seq_len, False) + result = process.generate(batch_states, keys, seq_len, False) + final_states = result["states"] + observations = result["observations"] + all_states = result["all_states"] assert observations.shape == (batch_size, seq_len) assert final_states[0].shape == (batch_size, 1) assert final_states[1].shape == (batch_size, 1) + assert isinstance(all_states, tuple) + assert all(state.shape == (batch_size, 0) for state in all_states) def test_generate_returns_all_states_when_requested(self, 
two_factor_independent_process): """generate with return_all_states=True should return state sequences.""" @@ -129,7 +134,9 @@ def test_generate_returns_all_states_when_requested(self, two_factor_independent batch_states = tuple(jnp.tile(s[None, :], (batch_size, 1)) for s in process.initial_state) keys = jax.random.split(jax.random.PRNGKey(0), batch_size) - all_states, observations = process.generate(batch_states, keys, seq_len, True) + result = process.generate(batch_states, keys, seq_len, True) + all_states = result["all_states"] + observations = result["observations"] assert observations.shape == (batch_size, seq_len) assert all_states[0].shape == (batch_size, seq_len, 1) @@ -148,7 +155,7 @@ def test_frozen_factor_same_across_batch(self, three_factor_process_with_frozen) batch_states = tuple(jnp.tile(s[None, :], (batch_size, 1)) for s in process.initial_state) keys = jax.random.split(jax.random.PRNGKey(123), batch_size) - _, observations = process.generate(batch_states, keys, seq_len, False) + observations = process.generate(batch_states, keys, seq_len, False)["observations"] factor_tokens = jax.vmap(process.encoder.extract_factors_vectorized)(observations) # Factor 1 (index 1) is frozen - should be identical across batch @@ -165,7 +172,7 @@ def test_unfrozen_factors_vary_across_batch(self, three_factor_process_with_froz batch_states = tuple(jnp.tile(s[None, :], (batch_size, 1)) for s in process.initial_state) keys = jax.random.split(jax.random.PRNGKey(456), batch_size) - _, observations = process.generate(batch_states, keys, seq_len, False) + observations = process.generate(batch_states, keys, seq_len, False)["observations"] factor_tokens = jax.vmap(process.encoder.extract_factors_vectorized)(observations) # Factors 0 and 2 are unfrozen - should differ across batch @@ -186,12 +193,12 @@ def test_frozen_sequences_reproducible(self, three_factor_process_with_frozen): # First generation keys1 = jax.random.split(jax.random.PRNGKey(100), batch_size) - _, obs1 = 
process.generate(batch_states, keys1, seq_len, False) + obs1 = process.generate(batch_states, keys1, seq_len, False)["observations"] factor_tokens1 = jax.vmap(process.encoder.extract_factors_vectorized)(obs1) # Second generation with different sample keys keys2 = jax.random.split(jax.random.PRNGKey(200), batch_size) - _, obs2 = process.generate(batch_states, keys2, seq_len, False) + obs2 = process.generate(batch_states, keys2, seq_len, False)["observations"] factor_tokens2 = jax.vmap(process.encoder.extract_factors_vectorized)(obs2) # Frozen factor should be the same in both calls @@ -228,7 +235,7 @@ def test_all_factors_frozen(self): batch_states = tuple(jnp.tile(s[None, :], (batch_size, 1)) for s in process.initial_state) keys = jax.random.split(jax.random.PRNGKey(0), batch_size) - _, observations = process.generate(batch_states, keys, seq_len, False) + observations = process.generate(batch_states, keys, seq_len, False)["observations"] # All batch samples should be identical for i in range(1, batch_size): @@ -243,7 +250,7 @@ def test_no_frozen_factors_matches_normal_behavior(self, two_factor_independent_ batch_states = tuple(jnp.tile(s[None, :], (batch_size, 1)) for s in process.initial_state) keys = jax.random.split(jax.random.PRNGKey(0), batch_size) - _, observations = process.generate(batch_states, keys, seq_len, False) + observations = process.generate(batch_states, keys, seq_len, False)["observations"] # All tokens should be valid assert jnp.all(observations >= 0) @@ -349,7 +356,7 @@ def test_frozen_factor_states_match_across_batch(self, three_factor_process_with batch_states = tuple(jnp.tile(s[None, :], (batch_size, 1)) for s in process.initial_state) keys = jax.random.split(jax.random.PRNGKey(789), batch_size) - all_states, _ = process.generate(batch_states, keys, seq_len, True) + all_states = process.generate(batch_states, keys, seq_len, True)["all_states"] # Factor 1 states should be identical across batch frozen_factor_states = all_states[1] diff --git 
a/tests/generative_processes/test_torch_generator.py b/tests/generative_processes/test_torch_generator.py index ba6d4e0c..17d65b52 100644 --- a/tests/generative_processes/test_torch_generator.py +++ b/tests/generative_processes/test_torch_generator.py @@ -28,8 +28,15 @@ def test_generate_data_batch(): gen_state: jax.Array = hmm.initial_state states = jnp.repeat(gen_state[None, :], batch_size, axis=0) key = jax.random.PRNGKey(0) - gen_states, inputs, labels = generate_data_batch(states, hmm, batch_size, sequence_len, key) + result = generate_data_batch(states, hmm, batch_size, sequence_len, key) + gen_states = result["gen_states"] + belief_states = result["belief_states"] + prefix_probabilities = result["prefix_probabilities"] + inputs = result["inputs"] + labels = result["labels"] assert isinstance(gen_states, jax.Array) + assert isinstance(belief_states, jax.Array) + assert isinstance(prefix_probabilities, jax.Array) assert isinstance(inputs, torch.Tensor) assert isinstance(labels, torch.Tensor) assert inputs.shape == (batch_size, sequence_len - 1) @@ -40,6 +47,8 @@ def test_generate_data_batch(): assert torch.all(labels < hmm.vocab_size) assert torch.equal(inputs[:, 1:], labels[:, :-1]) assert gen_states.shape == (batch_size, *gen_state.shape) + assert belief_states.shape == (batch_size, 0) + assert prefix_probabilities.shape == (batch_size, 0) def test_generate_data_batch_with_bos_token(): @@ -51,7 +60,7 @@ def test_generate_data_batch_with_bos_token(): states = jnp.repeat(gen_state[None, :], batch_size, axis=0) key = jax.random.PRNGKey(0) bos_token = hmm.vocab_size - gen_states, inputs, labels = generate_data_batch( + result = generate_data_batch( states, hmm, batch_size, @@ -59,6 +68,9 @@ def test_generate_data_batch_with_bos_token(): key, bos_token=bos_token, ) + gen_states = result["gen_states"] + inputs = result["inputs"] + labels = result["labels"] assert isinstance(gen_states, jax.Array) assert isinstance(inputs, torch.Tensor) assert isinstance(labels, 
torch.Tensor) @@ -82,7 +94,7 @@ def test_generate_data_batch_with_eos_token(): states = jnp.repeat(gen_state[None, :], batch_size, axis=0) key = jax.random.PRNGKey(0) eos_token = hmm.vocab_size - gen_states, inputs, labels = generate_data_batch( + result = generate_data_batch( states, hmm, batch_size, @@ -90,6 +102,9 @@ def test_generate_data_batch_with_eos_token(): key, eos_token=eos_token, ) + gen_states = result["gen_states"] + inputs = result["inputs"] + labels = result["labels"] assert isinstance(gen_states, jax.Array) assert isinstance(inputs, torch.Tensor) assert isinstance(labels, torch.Tensor) @@ -120,17 +135,22 @@ def test_generate_data_batch_with_full_history(): key, ) # Extract and type-check all fields + gen_states = result["gen_states"] belief_states = result["belief_states"] prefix_probs = result["prefix_probabilities"] inputs = result["inputs"] + labels = result["labels"] + assert isinstance(gen_states, jax.Array) assert isinstance(belief_states, jax.Array) assert isinstance(prefix_probs, jax.Array) assert isinstance(inputs, torch.Tensor) + assert isinstance(labels, torch.Tensor) # Without BOS, belief_states is aligned with inputs (one less than sequence_len) assert belief_states.shape == (batch_size, sequence_len - 1, gen_state.shape[0]) assert prefix_probs.shape == (batch_size, inputs.shape[1]) + assert gen_states.shape == (batch_size, *gen_state.shape) def test_generate_data_batch_with_full_history_bos(): @@ -150,18 +170,23 @@ def test_generate_data_batch_with_full_history_bos(): key, bos_token=bos_token, ) + gen_states = result["gen_states"] belief_states = result["belief_states"] prefix_probs = result["prefix_probabilities"] inputs = result["inputs"] + labels = result["labels"] + assert isinstance(gen_states, jax.Array) assert isinstance(belief_states, jax.Array) assert isinstance(prefix_probs, jax.Array) assert isinstance(inputs, torch.Tensor) + assert isinstance(labels, torch.Tensor) # With BOS, inputs has sequence_len positions (BOS + 
sequence_len-1 tokens) # belief_states is aligned with inputs assert inputs.shape == (batch_size, sequence_len) assert belief_states.shape == (batch_size, sequence_len, gen_state.shape[0]) assert prefix_probs.shape == (batch_size, inputs.shape[1]) + assert gen_states.shape == (batch_size, *gen_state.shape) # First input should be BOS token assert torch.all(inputs[:, 0] == bos_token)