From 4f567c7f950e93fe08fbd8510ac6d1c17fd8f0bf Mon Sep 17 00:00:00 2001 From: loren-ac Date: Mon, 23 Mar 2026 18:21:11 -0700 Subject: [PATCH] Add shuriken process: 3-state binary nonunifilar HMM Implements the shuriken process as a new generative process with parameters p, r, u, v. Includes minimality determinant test and default Hydra config. Co-Authored-By: Claude Opus 4.6 (1M context) --- .gitignore | 1 + .../transition_matrices.py | 46 +++++++++++++++++++ .../configs/generative_process/shuriken.yaml | 15 ++++++ .../test_transition_matrices.py | 46 +++++++++++++++++++ 4 files changed, 108 insertions(+) create mode 100644 tests/end_to_end/configs/generative_process/shuriken.yaml diff --git a/.gitignore b/.gitignore index 494ef98f..721c0ceb 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ __pycache__/ data/ multirun/ outputs/ +references/ config.ini *temp* # MLflow tracking data diff --git a/simplexity/generative_processes/transition_matrices.py b/simplexity/generative_processes/transition_matrices.py index d71aa2ed..4667067d 100644 --- a/simplexity/generative_processes/transition_matrices.py +++ b/simplexity/generative_processes/transition_matrices.py @@ -342,6 +342,51 @@ def tom_quantum(alpha: float, beta: float) -> jax.Array: return transition_matrices +def shuriken(p: float = 0.72, r: float = 0.24, u: float = 0.36, v: float = 0.52) -> jax.Array: + """Creates transition matrices for the Shuriken Process. + + A parameterized family of 3-state binary edge-emitting nonunifilar HMMs. + + States: A, B, C + Alphabet: {0, 1} + + Symbol-labeled transition matrices where T[x, i, j] = P(next_state=j, emit x | current_state=i): + + T0 = [[u, p-u, 0 ], T1 = [[0, 0, 1-p], + [0, v, p-v], [1-p, 0, 0 ], + [r, 0, 0 ]] [0, 1-r, 0 ]] + + The generator is minimal (all 3 hidden states are needed) when the minimality determinant + det(M) = -(p - r)^2 * (p - v) is nonzero, i.e. when p != r and p != v. Keeping these + differences large also improves numerical conditioning. The suggested parameter constraints + 0 < r < p < 1, 0 < u < p, 0 < v < p + ensure minimality and that all matrix entries are non-negative. + + Args: + p: P(emit 0) from states A and B. Also 1 - P(emit 1) from those states. + r: P(emit 0) from state C. Must differ from p for minimality. + u: P(emit 0, stay in A | state A). Controls the A/B split when emitting 0 from A. + v: P(emit 0, stay in B | state B). Must differ from p for minimality. + + Returns: + Transition matrices of shape (2, 3, 3). + """ + return jnp.array( + [ + [ + [u, p - u, 0], + [0, v, p - v], + [r, 0, 0], + ], + [ + [0, 0, 1 - p], + [1 - p, 0, 0], + [0, 1 - r, 0], + ], + ] + ) + + def zero_one_random(p: float) -> jax.Array: """Creates a transition matrix for the Zero One Random (Z1R) Process. @@ -375,6 +420,7 @@ def zero_one_random(p: float) -> jax.Array: "mr_name": mr_name, "no_consecutive_ones": no_consecutive_ones, "rrxor": rrxor, + "shuriken": shuriken, "sns": sns, "zero_one_random": zero_one_random, } diff --git a/tests/end_to_end/configs/generative_process/shuriken.yaml b/tests/end_to_end/configs/generative_process/shuriken.yaml new file mode 100644 index 00000000..31e70034 --- /dev/null +++ b/tests/end_to_end/configs/generative_process/shuriken.yaml @@ -0,0 +1,15 @@ +name: shuriken +instance: + _target_: simplexity.generative_processes.builder.build_hidden_markov_model + process_name: shuriken + process_params: + p: 0.72 + r: 0.24 + u: 0.36 + v: 0.52 + device: ${device} + +base_vocab_size: ??? +bos_token: ??? +eos_token: null +vocab_size: ??? diff --git a/tests/generative_processes/test_transition_matrices.py b/tests/generative_processes/test_transition_matrices.py index 60a0d2b0..13ab39dc 100644 --- a/tests/generative_processes/test_transition_matrices.py +++ b/tests/generative_processes/test_transition_matrices.py @@ -17,6 +17,7 @@ no_consecutive_ones, post_quantum, rrxor, + shuriken, sns, tom_quantum, zero_one_random, @@ -204,6 +205,51 @@ def test_rrxor(): assert jnp.allclose(stationary_distribution, jnp.array([2, 1, 1, 1, 1]) / 6) +def test_shuriken(): + """Test the shuriken transition matrices.""" + transition_matrices = shuriken() + assert transition_matrices.shape == (2, 3, 3) + validate_hmm_transition_matrices(transition_matrices) + + +def test_shuriken_custom_params(): + """Test the shuriken transition matrices with custom parameters.""" + transition_matrices = shuriken(p=0.8, r=0.3, u=0.4, v=0.5) + assert transition_matrices.shape == (2, 3, 3) + validate_hmm_transition_matrices(transition_matrices) + + +def test_shuriken_minimality_determinant(): + """Test that the minimality determinant matches the closed-form expression. + + For pure states A=[1,0,0], B=[0,1,0], C=[0,0,1], construct + M = [[1, P(0|A), P(00|A)], + [1, P(0|B), P(00|B)], + [1, P(0|C), P(00|C)]] + and verify |det(M)| = (p-r)^2 * (p-v). + + The nonzero determinant confirms that the three pure-state predictive distributions + are linearly independent, establishing that the model is minimal (3 states are needed). + """ + p, r, u, v = 0.72, 0.24, 0.36, 0.52 + transition_matrices = shuriken(p=p, r=r, u=u, v=v) + t0 = transition_matrices[0] + + pure_states = jnp.eye(3) + p0 = jnp.sum(pure_states @ t0, axis=-1) + next_states = pure_states @ t0 + next_states_normalized = next_states / p0[:, None] + p00 = p0 * jnp.sum(next_states_normalized @ t0, axis=-1) + + m = jnp.stack([jnp.ones(3), p0, p00], axis=-1) + actual_abs_det = jnp.abs(jnp.linalg.det(m)) + expected_abs_det = (p - r) ** 2 * (p - v) + + assert jnp.isclose(actual_abs_det, expected_abs_det, atol=1e-6), ( + f"Minimality determinant |det(M)|={actual_abs_det} != expected {expected_abs_det}" + ) + + def test_sns(): """Test the sns transition matrices.""" transition_matrices = sns(p=0.5, q=0.5)