From 4f5702b61cc9ce7782dbb2333cb8ba8eca2970ee Mon Sep 17 00:00:00 2001 From: loren-ac Date: Tue, 24 Mar 2026 15:25:11 -0700 Subject: [PATCH] Add binary-encoded generative process wrapper Introduce BinaryEncodedProcess, which wraps any GenerativeProcess by encoding its vocabulary into fixed-width big-endian binary. The wrapper emits individual bits (vocab_size=2) and defers the underlying state transition until a complete binary word has been emitted. Incomplete sequences are supported in probability/log_probability computations. --- .../binary_encoded_process.py | 191 ++++++++++ .../test_binary_encoded_process.py | 353 ++++++++++++++++++ 2 files changed, 544 insertions(+) create mode 100644 simplexity/generative_processes/binary_encoded_process.py create mode 100644 tests/generative_processes/test_binary_encoded_process.py diff --git a/simplexity/generative_processes/binary_encoded_process.py b/simplexity/generative_processes/binary_encoded_process.py new file mode 100644 index 00000000..2f220a22 --- /dev/null +++ b/simplexity/generative_processes/binary_encoded_process.py @@ -0,0 +1,191 @@ +"""Binary-encoded generative process wrapper. + +Wraps any GenerativeProcess by encoding its vocabulary into binary. Each base token +is encoded as a fixed-width big-endian binary string of ceil(log2(vocab_size)) bits. +The wrapper emits individual bits and only transitions the underlying state after +a complete binary word has been emitted. + +Token encoding: token t is emitted as bits b_0, ..., b_{n-1} where +b_i = (t >> (n - 1 - i)) & 1 (big-endian / MSB first). +""" + +from __future__ import annotations + +import math +from typing import Any, NamedTuple + +import chex +import equinox as eqx +import jax +import jax.numpy as jnp + +from simplexity.generative_processes.generative_process import GenerativeProcess + + +class BinaryEncodedState(NamedTuple): + """State for a binary-encoded generative process. + + Attributes: + base_state: The wrapped process's belief state. 
+ bit_position: Index of the next bit to emit (0 to num_bits - 1). + accumulated_prefix: Integer built from bits emitted so far in the current word. + """ + + base_state: Any + bit_position: jax.Array + accumulated_prefix: jax.Array + + +class BinaryEncodedProcess[State](GenerativeProcess[BinaryEncodedState]): + """Wraps a GenerativeProcess by encoding its vocabulary into binary. + + For a base process with vocab size V: + - num_bits = ceil(log2(V)) binary digits per base token + - New vocab size is 2 (binary: 0 or 1) + - Each base token is emitted as num_bits binary digits (big-endian, MSB first) + - The underlying state only transitions after all num_bits bits are emitted + - Unused binary codes (when V is not a power of 2) have probability 0 + + This encoding preserves the causal state structure of the base process. The + additional state variables (bit_position, accumulated_prefix) are deterministic + given the history and add no genuine memory. + + Args: + base_process: The generative process to wrap. Must have vocab_size >= 2. 
+ """ + + base_process: GenerativeProcess[State] + num_bits: int + _token_indices: jax.Array + + def __init__(self, base_process: GenerativeProcess[State]) -> None: + if base_process.vocab_size < 2: + raise ValueError(f"base process vocab_size must be >= 2, got {base_process.vocab_size}") + self.base_process = base_process + self.num_bits = math.ceil(math.log2(base_process.vocab_size)) + self._token_indices = jnp.arange(2**self.num_bits) + + @property + def vocab_size(self) -> int: + """The number of observations: always 2 (binary).""" + return 2 + + @property + def initial_state(self) -> BinaryEncodedState: + """The initial state: base process initial state at bit position 0 with empty prefix.""" + return BinaryEncodedState( + base_state=self.base_process.initial_state, + bit_position=jnp.array(0, dtype=jnp.int32), + accumulated_prefix=jnp.array(0, dtype=jnp.int32), + ) + + @eqx.filter_jit + def emit_observation(self, state: BinaryEncodedState, key: chex.PRNGKey) -> chex.Array: + """Emit the next binary digit by sampling from the conditional bit distribution.""" + dist = self.observation_probability_distribution(state) + return jax.random.categorical(key, jnp.log(dist)) + + @eqx.filter_jit + def transition_states(self, state: BinaryEncodedState, obs: chex.Array) -> BinaryEncodedState: + """Update state: accumulate the observed bit, transitioning the base state when the word is complete. + + When bit_position < num_bits - 1, only the prefix and position are updated. + When bit_position == num_bits - 1, the complete token is decoded and the base + state is transitioned, then position and prefix are reset to 0. 
+ """ + new_prefix = state.accumulated_prefix * 2 + obs + new_position = state.bit_position + 1 + is_complete = new_position >= self.num_bits + + safe_token = jnp.clip(new_prefix, 0, self.base_process.vocab_size - 1) + transitioned_base = self.base_process.transition_states(state.base_state, safe_token) + + final_base_state = jax.tree.map( + lambda new, old: jnp.where(is_complete, new, old), + transitioned_base, + state.base_state, + ) + final_position = jnp.where(is_complete, jnp.array(0, dtype=jnp.int32), new_position) + final_prefix = jnp.where(is_complete, jnp.array(0, dtype=jnp.int32), new_prefix) + + return BinaryEncodedState(final_base_state, final_position, final_prefix) + + @eqx.filter_jit + def observation_probability_distribution(self, state: BinaryEncodedState) -> jax.Array: + """Compute P(next_bit | base_state, accumulated_prefix). + + Returns a distribution over {0, 1} conditioned on the current base state + and the bits already emitted for the current token. + """ + base_dist = self.base_process.observation_probability_distribution(state.base_state) + padded_dist = jnp.zeros(2**self.num_bits).at[: self.base_process.vocab_size].set(base_dist) + + shift_amount = self.num_bits - state.bit_position - 1 + shifted = jnp.right_shift(self._token_indices, shift_amount) + + p_0 = jnp.sum(padded_dist * (shifted == state.accumulated_prefix * 2)) + p_1 = jnp.sum(padded_dist * (shifted == state.accumulated_prefix * 2 + 1)) + + total = p_0 + p_1 + return jnp.array([p_0 / total, p_1 / total]) + + @eqx.filter_jit + def log_observation_probability_distribution(self, log_belief_state: BinaryEncodedState) -> jax.Array: + """Compute log P(next_bit | log_base_state, accumulated_prefix). + + The base_state component of log_belief_state should be in log space. 
+ """ + base_log_dist = self.base_process.log_observation_probability_distribution(log_belief_state.base_state) + padded_log_dist = jnp.full(2**self.num_bits, -jnp.inf).at[: self.base_process.vocab_size].set(base_log_dist) + + shift_amount = self.num_bits - log_belief_state.bit_position - 1 + shifted = jnp.right_shift(self._token_indices, shift_amount) + + mask_0 = jnp.where(shifted == log_belief_state.accumulated_prefix * 2, padded_log_dist, -jnp.inf) + mask_1 = jnp.where(shifted == log_belief_state.accumulated_prefix * 2 + 1, padded_log_dist, -jnp.inf) + + log_p_0 = jax.nn.logsumexp(mask_0) + log_p_1 = jax.nn.logsumexp(mask_1) + + log_total = jax.nn.logsumexp(jnp.array([log_p_0, log_p_1])) + return jnp.array([log_p_0 - log_total, log_p_1 - log_total]) + + @eqx.filter_jit + def probability(self, observations: jax.Array) -> jax.Array: + """Compute the probability of a binary sequence. + + Handles sequences of any length, including those that end mid-token. + For complete token sequences, the result equals the base process probability + of the decoded token sequence. + """ + + def _scan_fn( + state: BinaryEncodedState, bit: jax.Array + ) -> tuple[BinaryEncodedState, jax.Array]: + dist = self.observation_probability_distribution(state) + bit_prob = dist[bit] + new_state = self.transition_states(state, bit) + return new_state, bit_prob + + _, bit_probs = jax.lax.scan(_scan_fn, self.initial_state, observations) + return jnp.prod(bit_probs) + + @eqx.filter_jit + def log_probability(self, observations: jax.Array) -> jax.Array: + """Compute the log probability of a binary sequence. + + Handles sequences of any length, including those that end mid-token. + For complete token sequences, the result equals the base process log probability + of the decoded token sequence. 
+ """ + + def _scan_fn( + state: BinaryEncodedState, bit: jax.Array + ) -> tuple[BinaryEncodedState, jax.Array]: + dist = self.observation_probability_distribution(state) + log_bit_prob = jnp.log(dist[bit]) + new_state = self.transition_states(state, bit) + return new_state, log_bit_prob + + _, log_bit_probs = jax.lax.scan(_scan_fn, self.initial_state, observations) + return jnp.sum(log_bit_probs) diff --git a/tests/generative_processes/test_binary_encoded_process.py b/tests/generative_processes/test_binary_encoded_process.py new file mode 100644 index 00000000..7c4421c3 --- /dev/null +++ b/tests/generative_processes/test_binary_encoded_process.py @@ -0,0 +1,353 @@ +"""Tests for BinaryEncodedProcess.""" + +# pylint: disable-all +# Temporarily disable all pylint checkers during AST traversal to prevent crash. +# The imports checker crashes when resolving simplexity package imports due to a bug +# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185 +# pylint: enable=all +# Re-enable all pylint checkers for the checking phase. This allows other checks +# (code quality, style, undefined names, etc.) to run normally while bypassing +# the problematic imports checker that would crash during AST traversal. 
import chex
import jax
import jax.numpy as jnp
import pytest

from simplexity.generative_processes.binary_encoded_process import BinaryEncodedProcess, BinaryEncodedState
from simplexity.generative_processes.builder import (
    build_generalized_hidden_markov_model,
    build_hidden_markov_model,
)


class TestBasicProperties:
    """Tests for basic properties of BinaryEncodedProcess."""

    @pytest.fixture
    def mess3_binary(self) -> BinaryEncodedProcess:
        # mess3 has vocab_size 3, so the wrapper uses 2 bits per token.
        mess3 = build_hidden_markov_model("mess3", {"x": 0.15, "a": 0.6})
        return BinaryEncodedProcess(mess3)

    @pytest.fixture
    def coin_binary(self) -> BinaryEncodedProcess:
        # coin has vocab_size 2, so the binary encoding is the identity (1 bit).
        coin = build_hidden_markov_model("coin", {"p": 0.7})
        return BinaryEncodedProcess(coin)

    def test_vocab_size_is_two(self, mess3_binary: BinaryEncodedProcess):
        assert mess3_binary.vocab_size == 2

    def test_num_bits_mess3(self, mess3_binary: BinaryEncodedProcess):
        # ceil(log2(3)) == 2
        assert mess3_binary.num_bits == 2

    def test_num_bits_coin(self, coin_binary: BinaryEncodedProcess):
        # ceil(log2(2)) == 1
        assert coin_binary.num_bits == 1

    def test_initial_state_base_matches(self, mess3_binary: BinaryEncodedProcess):
        chex.assert_trees_all_close(
            mess3_binary.initial_state.base_state,
            mess3_binary.base_process.initial_state,
        )

    def test_initial_state_bit_position_is_zero(self, mess3_binary: BinaryEncodedProcess):
        assert mess3_binary.initial_state.bit_position == 0

    def test_initial_state_accumulated_prefix_is_zero(self, mess3_binary: BinaryEncodedProcess):
        assert mess3_binary.initial_state.accumulated_prefix == 0


class TestObservationDistribution:
    """Tests for observation probability distribution."""

    @pytest.fixture
    def mess3_binary(self) -> BinaryEncodedProcess:
        mess3 = build_hidden_markov_model("mess3", {"x": 0.15, "a": 0.6})
        return BinaryEncodedProcess(mess3)

    def test_distribution_sums_to_one(self, mess3_binary: BinaryEncodedProcess):
        state = mess3_binary.initial_state
        dist = mess3_binary.observation_probability_distribution(state)
        chex.assert_trees_all_close(jnp.sum(dist), 1.0, atol=1e-6)

    def test_distribution_has_correct_size(self, mess3_binary: BinaryEncodedProcess):
        state = mess3_binary.initial_state
        dist = mess3_binary.observation_probability_distribution(state)
        assert dist.shape == (2,)

    def test_first_bit_marginalizes_correctly(self, mess3_binary: BinaryEncodedProcess):
        """First bit: P(0) = P(token 0) + P(token 1), P(1) = P(token 2)."""
        # Big-endian codes for mess3: token 0 -> 00, token 1 -> 01, token 2 -> 10.
        state = mess3_binary.initial_state
        base_dist = mess3_binary.base_process.observation_probability_distribution(state.base_state)
        binary_dist = mess3_binary.observation_probability_distribution(state)

        expected_p0 = base_dist[0] + base_dist[1]
        expected_p1 = base_dist[2]
        chex.assert_trees_all_close(binary_dist[0], expected_p0, atol=1e-6)
        chex.assert_trees_all_close(binary_dist[1], expected_p1, atol=1e-6)

    def test_second_bit_given_first_zero(self, mess3_binary: BinaryEncodedProcess):
        """P(bit=0 | first=0) = P(token 0) / (P(token 0) + P(token 1))."""
        state = mess3_binary.initial_state
        base_dist = mess3_binary.base_process.observation_probability_distribution(state.base_state)

        # Manually construct the mid-word state after emitting a first bit of 0.
        state_after_0 = BinaryEncodedState(
            base_state=state.base_state,
            bit_position=jnp.array(1, dtype=jnp.int32),
            accumulated_prefix=jnp.array(0, dtype=jnp.int32),
        )
        dist = mess3_binary.observation_probability_distribution(state_after_0)

        denom = base_dist[0] + base_dist[1]
        chex.assert_trees_all_close(dist[0], base_dist[0] / denom, atol=1e-6)
        chex.assert_trees_all_close(dist[1], base_dist[1] / denom, atol=1e-6)

    def test_second_bit_given_first_one_is_deterministic(self, mess3_binary: BinaryEncodedProcess):
        """P(bit=0 | first=1) = 1.0 since code 11 is unused for mess3."""
        state = mess3_binary.initial_state

        state_after_1 = BinaryEncodedState(
            base_state=state.base_state,
            bit_position=jnp.array(1, dtype=jnp.int32),
            accumulated_prefix=jnp.array(1, dtype=jnp.int32),
        )
        dist = mess3_binary.observation_probability_distribution(state_after_1)

        chex.assert_trees_all_close(dist[0], 1.0, atol=1e-6)
        chex.assert_trees_all_close(dist[1], 0.0, atol=1e-6)

    def test_distribution_sums_to_one_at_intermediate_position(self, mess3_binary: BinaryEncodedProcess):
        state_after_0 = BinaryEncodedState(
            base_state=mess3_binary.initial_state.base_state,
            bit_position=jnp.array(1, dtype=jnp.int32),
            accumulated_prefix=jnp.array(0, dtype=jnp.int32),
        )
        dist = mess3_binary.observation_probability_distribution(state_after_0)
        chex.assert_trees_all_close(jnp.sum(dist), 1.0, atol=1e-6)

    def test_log_distribution_consistent(self, mess3_binary: BinaryEncodedProcess):
        # The log-space method expects the base_state component in log space.
        state = mess3_binary.initial_state
        log_state = BinaryEncodedState(
            base_state=jnp.log(state.base_state),
            bit_position=state.bit_position,
            accumulated_prefix=state.accumulated_prefix,
        )
        dist = mess3_binary.observation_probability_distribution(state)
        log_dist = mess3_binary.log_observation_probability_distribution(log_state)
        chex.assert_trees_all_close(log_dist, jnp.log(dist), atol=1e-5)


class TestTransitionStates:
    """Tests for state transitions."""

    @pytest.fixture
    def mess3_binary(self) -> BinaryEncodedProcess:
        mess3 = build_hidden_markov_model("mess3", {"x": 0.15, "a": 0.6})
        return BinaryEncodedProcess(mess3)

    def test_base_state_unchanged_during_partial_token(self, mess3_binary: BinaryEncodedProcess):
        state = mess3_binary.initial_state
        new_state = mess3_binary.transition_states(state, jnp.array(0))
        chex.assert_trees_all_close(new_state.base_state, state.base_state, atol=1e-6)

    def test_bit_position_increments(self, mess3_binary: BinaryEncodedProcess):
        state = mess3_binary.initial_state
        new_state = mess3_binary.transition_states(state, jnp.array(0))
        assert new_state.bit_position == 1

    def test_accumulated_prefix_updates(self, mess3_binary: BinaryEncodedProcess):
        state = mess3_binary.initial_state
        new_state = mess3_binary.transition_states(state, jnp.array(1))
        assert new_state.accumulated_prefix == 1

    def test_base_state_transitions_after_complete_token(self, mess3_binary: BinaryEncodedProcess):
        """After emitting bits 01 (token 1), base state should match base transition with token 1."""
        state = mess3_binary.initial_state
        state_after_0 = mess3_binary.transition_states(state, jnp.array(0))
        state_after_01 = mess3_binary.transition_states(state_after_0, jnp.array(1))

        expected_base = mess3_binary.base_process.transition_states(state.base_state, jnp.array(1))
        chex.assert_trees_all_close(state_after_01.base_state, expected_base, atol=1e-6)

    def test_position_resets_after_complete_token(self, mess3_binary: BinaryEncodedProcess):
        state = mess3_binary.initial_state
        state_after_0 = mess3_binary.transition_states(state, jnp.array(0))
        state_after_01 = mess3_binary.transition_states(state_after_0, jnp.array(1))
        assert state_after_01.bit_position == 0

    def test_prefix_resets_after_complete_token(self, mess3_binary: BinaryEncodedProcess):
        state = mess3_binary.initial_state
        state_after_0 = mess3_binary.transition_states(state, jnp.array(0))
        state_after_01 = mess3_binary.transition_states(state_after_0, jnp.array(1))
        assert state_after_01.accumulated_prefix == 0

    def test_token_0_transition(self, mess3_binary: BinaryEncodedProcess):
        """Bits 00 should decode to token 0."""
        state = mess3_binary.initial_state
        state_after_0 = mess3_binary.transition_states(state, jnp.array(0))
        state_after_00 = mess3_binary.transition_states(state_after_0, jnp.array(0))

        expected = mess3_binary.base_process.transition_states(state.base_state, jnp.array(0))
        chex.assert_trees_all_close(state_after_00.base_state, expected, atol=1e-6)

    def test_token_2_transition(self, mess3_binary: BinaryEncodedProcess):
        """Bits 10 should decode to token 2."""
        state = mess3_binary.initial_state
        state_after_1 = mess3_binary.transition_states(state, jnp.array(1))
        state_after_10 = mess3_binary.transition_states(state_after_1, jnp.array(0))

        expected = mess3_binary.base_process.transition_states(state.base_state, jnp.array(2))
        chex.assert_trees_all_close(state_after_10.base_state, expected, atol=1e-6)


class TestProbability:
    """Tests for sequence probability computation."""

    @pytest.fixture
    def mess3(self):
        return build_hidden_markov_model("mess3", {"x": 0.15, "a": 0.6})

    @pytest.fixture
    def mess3_binary(self, mess3) -> BinaryEncodedProcess:
        return BinaryEncodedProcess(mess3)

    def test_complete_sequence_matches_base(self, mess3, mess3_binary: BinaryEncodedProcess):
        """Binary sequence probability equals base sequence probability for complete tokens."""
        # Tokens [0, 1, 2, 0] encode to bit pairs 00 01 10 00.
        base_seq = jnp.array([0, 1, 2, 0])
        binary_seq = jnp.array([0, 0, 0, 1, 1, 0, 0, 0])

        base_prob = mess3.probability(base_seq)
        binary_prob = mess3_binary.probability(binary_seq)
        chex.assert_trees_all_close(binary_prob, base_prob, atol=1e-6)

    def test_incomplete_sequence_returns_valid_probability(self, mess3_binary: BinaryEncodedProcess):
        # Odd-length sequence ends mid-token.
        binary_seq = jnp.array([0, 0, 0])
        prob = mess3_binary.probability(binary_seq)
        assert prob > 0
        assert prob <= 1

    def test_incomplete_sequence_extends_complete(self, mess3_binary: BinaryEncodedProcess):
        """P(b0, b1, b2) = P(b0, b1) * P(b2 | b0, b1) and both should be valid."""
        complete = jnp.array([0, 0])
        incomplete = jnp.array([0, 0, 0])
        p_complete = mess3_binary.probability(complete)
        p_incomplete = mess3_binary.probability(incomplete)
        # Extending a sequence can only lower (or keep) its probability.
        assert p_incomplete <= p_complete

    def test_log_probability_consistent(self, mess3_binary: BinaryEncodedProcess):
        binary_seq = jnp.array([0, 1, 1, 0, 0, 0])
        prob = mess3_binary.probability(binary_seq)
        log_prob = mess3_binary.log_probability(binary_seq)
        chex.assert_trees_all_close(log_prob, jnp.log(prob), atol=1e-5)

    def test_multiple_complete_sequences(self, mess3, mess3_binary: BinaryEncodedProcess):
        """Verify several different token sequences."""
        for base_tokens in [[0], [1], [2], [2, 1, 0], [1, 1, 1]]:
            base_seq = jnp.array(base_tokens)
            bits = []
            for t in base_tokens:
                # Big-endian 2-bit encoding: MSB first, then LSB.
                bits.extend([(t >> 1) & 1, t & 1])
            binary_seq = jnp.array(bits)

            base_prob = mess3.probability(base_seq)
            binary_prob = mess3_binary.probability(binary_seq)
            chex.assert_trees_all_close(binary_prob, base_prob, atol=1e-6)


class TestGeneration:
    """Tests for sequence generation."""

    @pytest.fixture
    def mess3_binary(self) -> BinaryEncodedProcess:
        mess3 = build_hidden_markov_model("mess3", {"x": 0.15, "a": 0.6})
        return BinaryEncodedProcess(mess3)

    def test_generate_valid_tokens(self, mess3_binary: BinaryEncodedProcess):
        # Broadcast the initial state to a batch of 4 independent chains.
        state = mess3_binary.initial_state
        batch_state = jax.tree.map(lambda x: jnp.broadcast_to(x, (4,) + x.shape), state)
        keys = jax.random.split(jax.random.PRNGKey(0), 4)
        _, observations = mess3_binary.generate(batch_state, keys, 20, False)

        assert observations.shape == (4, 20)
        assert jnp.all(observations >= 0)
        assert jnp.all(observations < 2)

    def test_generate_with_return_all_states(self, mess3_binary: BinaryEncodedProcess):
        state = mess3_binary.initial_state
        batch_state = jax.tree.map(lambda x: jnp.broadcast_to(x, (4,) + x.shape), state)
        keys = jax.random.split(jax.random.PRNGKey(0), 4)
        states, observations = mess3_binary.generate(batch_state, keys, 10, True)

        assert observations.shape == (4, 10)
        assert states.base_state.shape == (4, 10) + mess3_binary.base_process.initial_state.shape

    def test_decoded_tokens_are_valid(self, mess3_binary: BinaryEncodedProcess):
        """All decoded binary pairs should map to valid tokens (0, 1, or 2)."""
        state = mess3_binary.initial_state
        batch_state = jax.tree.map(lambda x: jnp.broadcast_to(x, (50,) + x.shape), state)
        keys = jax.random.split(jax.random.PRNGKey(42), 50)
        _, observations = mess3_binary.generate(batch_state, keys, 200, False)

        # Group bits into (MSB, LSB) pairs and decode big-endian.
        pairs = observations.reshape(50, 100, 2)
        decoded = pairs[:, :, 0] * 2 + pairs[:, :, 1]
        assert jnp.all(decoded < 3)

    def test_decoded_generation_matches_base_distribution(self, mess3_binary: BinaryEncodedProcess):
        """Decoded token frequencies should approximate the stationary distribution."""
        state = mess3_binary.initial_state
        batch_state = jax.tree.map(lambda x: jnp.broadcast_to(x, (200,) + x.shape), state)
        keys = jax.random.split(jax.random.PRNGKey(42), 200)
        _, observations = mess3_binary.generate(batch_state, keys, 200, False)

        pairs = observations.reshape(200, 100, 2)
        decoded = pairs[:, :, 0] * 2 + pairs[:, :, 1]

        # mess3's stationary token distribution is uniform over the 3 tokens;
        # loose tolerance since this is a finite-sample frequency check.
        for token in range(3):
            freq = jnp.mean(decoded == token)
            chex.assert_trees_all_close(freq, 1.0 / 3.0, atol=0.05)


class TestWithDifferentBaseProcesses:
    """Tests for wrapping different process types."""

    def test_coin_has_one_bit(self):
        coin = build_hidden_markov_model("coin", {"p": 0.7})
        binary = BinaryEncodedProcess(coin)
        assert binary.num_bits == 1

    def test_coin_binary_encoding_is_identity(self):
        """With vocab_size=2, binary encoding should reproduce the base distribution exactly."""
        coin = build_hidden_markov_model("coin", {"p": 0.7})
        binary = BinaryEncodedProcess(coin)

        state = binary.initial_state
        dist = binary.observation_probability_distribution(state)
        base_dist = coin.observation_probability_distribution(coin.initial_state)
        chex.assert_trees_all_close(dist, base_dist, atol=1e-6)

    def test_coin_probability_matches_base(self):
        coin = build_hidden_markov_model("coin", {"p": 0.7})
        binary = BinaryEncodedProcess(coin)

        seq = jnp.array([0, 1, 0, 0, 1])
        chex.assert_trees_all_close(binary.probability(seq), coin.probability(seq), atol=1e-6)

    def test_wrap_ghmm(self):
        # The wrapper should work for generalized HMMs, not just HMMs.
        ghmm = build_generalized_hidden_markov_model("tom_quantum", {"alpha": 1.0, "beta": 1.0})
        binary = BinaryEncodedProcess(ghmm)
        assert binary.vocab_size == 2

        state = binary.initial_state
        dist = binary.observation_probability_distribution(state)
        chex.assert_trees_all_close(jnp.sum(dist), 1.0, atol=1e-6)

    def test_ghmm_generation(self):
        ghmm = build_generalized_hidden_markov_model("tom_quantum", {"alpha": 1.0, "beta": 1.0})
        binary = BinaryEncodedProcess(ghmm)

        state = binary.initial_state
        batch_state = jax.tree.map(lambda x: jnp.broadcast_to(x, (4,) + x.shape), state)
        keys = jax.random.split(jax.random.PRNGKey(0), 4)
        _, observations = binary.generate(batch_state, keys, 10, False)

        assert jnp.all(observations >= 0)
        assert jnp.all(observations < 2)