Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
3091ac5
Factored processes -> Dev (#131)
casperlchristensen Dec 9, 2025
2dd3e8c
Fix resolve base config bug (#134)
ealt Dec 10, 2025
21efdbb
Update github workflows for dev branch (#133)
ealt Dec 11, 2025
9514ce5
Activation visualizations -> dev (#132)
casperlchristensen Dec 16, 2025
e6c2a6c
Add subspace orthogonality analysis for factored processes (#136)
loren-ac Dec 16, 2025
44f53ee
Expose ability to compute subspace orthogonality in LinearRegressionA…
ealt Dec 17, 2025
6681393
Add CONTRIBUTING.md with PR requirements for dev and main (#138)
loren-ac Dec 17, 2025
7a62e72
Fix/dropdown slider interaction (#143)
casperlchristensen Dec 17, 2025
919adee
Automatically save log files at the end of managed runs (#142)
ealt Dec 17, 2025
0b1b6a2
Add simplexity-multirun CLI for parallel experiment execution (#144)
adamimos Dec 19, 2025
67ff8b0
save more path-specific visualizations (#145)
casperlchristensen Dec 19, 2025
9549554
Casper/generic resolution (#161)
casperlchristensen Jan 7, 2026
861bfbb
Improve metric naming for length and readability (#153)
loren-ac Jan 7, 2026
fb28491
reduce number of metrics returned from variance analysis (#162)
casperlchristensen Jan 7, 2026
4050f50
return targets (#163)
casperlchristensen Jan 7, 2026
1c06eaa
Extend format_layer_spec to handle all TransformerLens layer patterns…
loren-ac Jan 13, 2026
83a9528
option for no deduplication (#166)
casperlchristensen Jan 13, 2026
d4f6021
Add IndependentFactoredGenerativeProcess for frozen factor support (#…
loren-ac Jan 15, 2026
c982f44
noises process option (#165)
casperlchristensen Jan 16, 2026
cb7b034
Add NonErgodicGenerativeProcess for block diagonal mixture models
kylejray Feb 4, 2026
b4abda3
Fix off-by-one in NonErgodic generate inference pass
Feb 13, 2026
e8918ed
Add InflatedVocabularyProcess for vocabulary inflation with uniform n…
kylejray Feb 17, 2026
4436413
Add disjoint/partial-overlap vocab builders and use IndependentFactor…
Feb 24, 2026
e13bd71
Fix random vocab mode to independently sample tokens per component
Feb 25, 2026
23925a1
Address PR #172 review feedback
Feb 26, 2026
f089ed1
Apply ruff formatting
Feb 26, 2026
a7c5b3d
Fix pyright errors: widen generator state types for NonErgodicState
Feb 26, 2026
4f50965
Revert generator signature widening, use type: ignore in tests
Feb 26, 2026
2520382
Remove misleading auto-infer comment from YAML config
Feb 26, 2026
3a8e550
Merge remote-tracking branch 'origin' into kyle/nonergodic
casperlchristensen Mar 3, 2026
e67ab10
re-delete visualization
casperlchristensen Mar 4, 2026
18b3662
Merge branch 'main' into kyle/nonergodic
casperlchristensen Mar 4, 2026
368a1f4
re-simplify docstrings
casperlchristensen Mar 4, 2026
2c97431
Merge branch 'kyle/nonergodic' of https://github.com/Astera-org/simpl…
casperlchristensen Mar 4, 2026
d5413a1
remove dependency
casperlchristensen Mar 4, 2026
70140d7
Merge branch 'main' into kyle/nonergodic
casperlchristensen Mar 11, 2026
80ebd19
Refactor vocab map handling in NonErgodicGenerativeProcess and improv…
ealt Mar 16, 2026
2d6a356
ruff format
ealt Mar 16, 2026
447d03d
Add error handling for random mode in build_nonergodic_partial_overla…
ealt Mar 16, 2026
f4f1e23
Merge pull request #178 from Astera-org/review/nonergodic-pr172
kylejray Mar 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
319 changes: 318 additions & 1 deletion simplexity/generative_processes/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# the problematic imports checker that would crash during AST traversal.

import inspect
import random
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal

Expand All @@ -18,7 +19,11 @@

from simplexity.generative_processes.factored_generative_process import ComponentType, FactoredGenerativeProcess
from simplexity.generative_processes.generalized_hidden_markov_model import GeneralizedHiddenMarkovModel
from simplexity.generative_processes.generative_process import GenerativeProcess
from simplexity.generative_processes.hidden_markov_model import HiddenMarkovModel
from simplexity.generative_processes.independent_factored_generative_process import IndependentFactoredGenerativeProcess
from simplexity.generative_processes.inflated_vocabulary_process import InflatedVocabularyProcess
from simplexity.generative_processes.nonergodic_generative_process import NonErgodicGenerativeProcess
from simplexity.generative_processes.structures import (
ConditionalTransitions,
FullyConditional,
Expand Down Expand Up @@ -192,7 +197,16 @@ def build_factored_process(

if structure_type == "independent":
structure = IndependentStructure()
elif structure_type == "chain":
return IndependentFactoredGenerativeProcess(
component_types=component_types,
transition_matrices=transition_matrices,
normalizing_eigenvectors=normalizing_eigenvectors,
initial_states=initial_states,
structure=structure,
noise_epsilon=noise_epsilon,
)

if structure_type == "chain":
if "control_maps" not in structure_kwargs:
raise ValueError("Missing required argument 'control_maps' for chain structure")
structure = SequentialConditional(control_maps=tuple(structure_kwargs["control_maps"]), vocab_sizes=vocab_sizes)
Expand Down Expand Up @@ -634,3 +648,306 @@ def build_transition_coupled_from_spec(
emission_variant_indices_array,
emission_control_maps_arrays,
)


def _build_components_from_spec(
components: Sequence[dict[str, Any]],
device: str | None = None,
) -> list[GenerativeProcess]:
"""Build component GenerativeProcess instances from specifications.

Args:
components: List of component specs. Each spec has:
- component_type: "hmm", "ghmm", or "factored"
- For hmm/ghmm: process_name, process_params
- For factored: structure_type, spec, and structure-specific params
device: Device placement.

Returns:
List of built GenerativeProcess instances.

Raises:
ValueError: If component_type is unknown.
"""
built_components = []

for comp_spec in components:
comp_type = comp_spec.get("component_type", "hmm")

if comp_type == "hmm":
process: GenerativeProcess = build_hidden_markov_model(
process_name=comp_spec["process_name"],
process_params=comp_spec.get("process_params", {}),
device=device,
)
elif comp_type == "ghmm":
process = build_generalized_hidden_markov_model(
process_name=comp_spec["process_name"],
process_params=comp_spec.get("process_params", {}),
device=device,
)
elif comp_type == "factored":
factored_kwargs = {k: v for k, v in comp_spec.items() if k not in ("component_type", "vocab_map")}
process = build_factored_process_from_spec(**factored_kwargs)
else:
raise ValueError(f"Unknown component_type: {comp_type}")

built_components.append(process)

return built_components


def build_nonergodic_process_from_spec(
    components: Sequence[dict[str, Any]],
    component_weights: Sequence[float],
    vocab_maps: Sequence[Sequence[int]] | None = None,
    device: str | None = None,
) -> NonErgodicGenerativeProcess:
    """Build a NonErgodicGenerativeProcess from component specifications.

    Each spec is built into a GenerativeProcess and the results are composed into a
    single NonErgodicGenerativeProcess mixture.

    Args:
        components: List of component specs. Each spec has:
            - component_type: "hmm", "ghmm", or "factored"
            - For hmm/ghmm: process_name, process_params
            - For factored: structure_type, spec, and structure-specific params
            - vocab_map: Optional per-component vocab mapping. Consulted only when
              ``vocab_maps`` is None; a component without one gets the identity map
              ``[0, ..., vocab_size - 1]``.
        component_weights: Mixture weights for the components. Forwarded unchanged to
            NonErgodicGenerativeProcess (which is expected to normalize them).
        vocab_maps: Optional global vocab maps. When given, they are used as-is and
            any per-component ``vocab_map`` entries are ignored.
        device: Device placement.

    Returns:
        NonErgodicGenerativeProcess instance.

    Example:
        ```yaml
        instance:
          _target_: simplexity.generative_processes.builder.build_nonergodic_process_from_spec
          components:
            - component_type: hmm
              process_name: mess3
              process_params: {x: 0.15, a: 0.6}
            - component_type: ghmm
              process_name: tom_quantum
              process_params: {alpha: 1.0, beta: 4.0}
            - component_type: factored
              structure_type: independent
              spec:
                - component_type: hmm
                  variants:
                    - process_name: coin
                      process_params: {p: 0.5}
          component_weights: [0.5, 0.3, 0.2]
          vocab_maps:
            - [0, 1, 2]
            - [0, 1, 2]
            - [0, 1]
        ```

    Raises:
        ValueError: If component_type is unknown.
    """
    processes = _build_components_from_spec(components, device=device)

    resolved_maps: Sequence[Sequence[int]]
    if vocab_maps is not None:
        # Explicit global maps win over anything declared on individual specs.
        resolved_maps = vocab_maps
    else:
        resolved_maps = [
            spec.get("vocab_map", list(range(proc.vocab_size)))
            for spec, proc in zip(components, processes, strict=True)
        ]

    return NonErgodicGenerativeProcess(
        components=processes,
        component_weights=component_weights,
        vocab_maps=resolved_maps,
        device=device,
    )


def build_nonergodic_disjoint_vocab(
    components: Sequence[dict[str, Any]],
    component_weights: Sequence[float],
    device: str | None = None,
) -> NonErgodicGenerativeProcess:
    """Build a nonergodic process whose components use non-overlapping token ranges.

    Every component is built first so its vocab_size is known. Consecutive,
    non-overlapping ranges are then handed out in component order:
    C0 -> [0..V0-1], C1 -> [V0..V0+V1-1], and so on.

    Args:
        components: List of component specs (same format as build_nonergodic_process_from_spec).
        component_weights: Mixture weights for components.
        device: Device placement.

    Returns:
        NonErgodicGenerativeProcess with disjoint per-component vocabularies.
    """
    processes = _build_components_from_spec(components, device=device)

    disjoint_maps: list[list[int]] = []
    start = 0
    for proc in processes:
        # Each component's range begins where the previous one ended.
        end = start + proc.vocab_size
        disjoint_maps.append(list(range(start, end)))
        start = end

    return NonErgodicGenerativeProcess(
        components=processes,
        component_weights=component_weights,
        vocab_maps=disjoint_maps,
        device=device,
    )


def _build_prefix_vocab_maps(n_components: int, v: int, n_shared: int, n_unique: int) -> list[list[int]]:
    """Build vocab maps using the prefix strategy.

    The first component gets [0..v-1]. Every later component gets the shared prefix
    [0..n_shared-1] followed by its own run of n_unique tokens; those runs start at v
    and are laid out back to back in component order.
    """
    shared_prefix = list(range(n_shared))
    vocab_maps = [list(range(v))]
    for idx in range(1, n_components):
        unique_start = v + (idx - 1) * n_unique
        # Concatenation creates a fresh list, so the maps never alias shared_prefix.
        vocab_maps.append(shared_prefix + list(range(unique_start, unique_start + n_unique)))
    return vocab_maps


def _build_sliding_vocab_maps(n_components: int, v: int, n_unique: int) -> list[list[int]]:
    """Build vocab maps using the sliding/offset strategy.

    Component i gets the window [i*step..i*step+v-1], where step is n_unique
    clamped to a minimum of 1 so that consecutive windows always move forward.
    """
    step = n_unique if n_unique > 1 else 1
    windows: list[list[int]] = []
    for start in range(0, n_components * step, step):
        windows.append(list(range(start, start + v)))
    return windows


def _build_random_vocab_maps(n_components: int, v: int, n_unique: int, seed: int) -> list[list[int]]:
    """Build vocab maps by sampling v distinct tokens per component from a global pool.

    The pool is range(v + (n_components - 1) * n_unique), i.e. the same global
    vocab size the prefix strategy produces. A single random.Random(seed) drives
    all draws, so the result is reproducible for a given seed. Each component's
    map is returned sorted in ascending order.
    """
    pool = range(v + (n_components - 1) * n_unique)
    rng = random.Random(seed)

    vocab_maps: list[list[int]] = []
    for _ in range(n_components):
        drawn = rng.sample(pool, v)
        drawn.sort()
        vocab_maps.append(drawn)
    return vocab_maps


def build_nonergodic_partial_overlap(
components: Sequence[dict[str, Any]],
component_weights: Sequence[float],
overlap_frac: float = 0.7,
mode: Literal["prefix", "sliding", "random"] = "prefix",
seed: int | None = None,
device: str | None = None,
) -> NonErgodicGenerativeProcess:
"""Build a nonergodic process with partially overlapping alphabets.

Args:
components: List of component specs (same format as build_nonergodic_process_from_spec).
component_weights: Mixture weights for components.
overlap_frac: Fraction of tokens shared between components (0.0 = disjoint, 1.0 = full overlap).
mode: Strategy for assigning vocab maps:
- "prefix": C0 gets [0..V-1], Ci>0 gets shared prefix + unique suffix above V.
- "sliding": Each component's vocab is offset by V * (1 - overlap_frac) from the previous.
- "random": Each component independently samples V tokens from the global pool.
Global pool size matches prefix mode. Requires the ``seed`` parameter.
seed: Random seed for reproducibility. Required when mode="random".
device: Device placement.

Returns:
NonErgodicGenerativeProcess with partially overlapping vocabularies.

Raises:
ValueError: If mode is unknown or seed is missing for random mode.
"""
if mode == "random" and seed is None:
raise ValueError("seed is required when mode='random'")

built_components = _build_components_from_spec(components, device=device)
comp_vocab_sizes = [c.vocab_size for c in built_components]
if len(set(comp_vocab_sizes)) != 1:
raise ValueError(f"All components must have equal vocab_size for partial_overlap, got {comp_vocab_sizes}")
v = comp_vocab_sizes[0]
n_shared = int(v * overlap_frac)
n_unique = v - n_shared
n_components = len(components)

if mode == "prefix":
vocab_maps = _build_prefix_vocab_maps(n_components, v, n_shared, n_unique)
elif mode == "sliding":
vocab_maps = _build_sliding_vocab_maps(n_components, v, n_unique)
elif mode == "random":
if seed is None:
raise ValueError("seed is required when mode='random'")
vocab_maps = _build_random_vocab_maps(n_components, v, n_unique, seed)
else:
raise ValueError(f"Unknown mode '{mode}'. Must be 'prefix', 'sliding', or 'random'.")

return NonErgodicGenerativeProcess(
components=built_components,
component_weights=component_weights,
vocab_maps=vocab_maps,
device=device,
)


def build_inflated_process(
    base_process: GenerativeProcess,
    inflation_factor: int,
) -> InflatedVocabularyProcess:
    """Wrap an already-built process in an InflatedVocabularyProcess.

    Both arguments are handed to the InflatedVocabularyProcess constructor
    positionally and unchanged; no validation happens in this function.

    Args:
        base_process: Any GenerativeProcess to wrap.
        inflation_factor: Number of noise variants per base token, K.
            NOTE(review): the wrapped class is expected to require K >= 2 and to
            expose vocab_size = K * base_process.vocab_size - confirm against
            InflatedVocabularyProcess.

    Returns:
        The InflatedVocabularyProcess wrapping ``base_process``.
    """
    inflated = InflatedVocabularyProcess(base_process, inflation_factor)
    return inflated


def build_inflated_process_from_spec(
base_spec: dict[str, Any],
inflation_factor: int,
device: str | None = None,
) -> InflatedVocabularyProcess:
"""Build an inflated vocabulary process from a base process specification.

Args:
base_spec: Specification for the base process. Must include:
- component_type: "hmm", "ghmm", or "factored"
- For hmm/ghmm: process_name, process_params
- For factored: structure_type, spec, and structure-specific params
inflation_factor: Number of noise variants per base token (K >= 2).
device: Device placement.

Returns:
InflatedVocabularyProcess wrapping the built base process.

Raises:
ValueError: If component_type is unknown.
"""
comp_type = base_spec.get("component_type", "hmm")

if comp_type == "hmm":
base_process: GenerativeProcess = build_hidden_markov_model(
process_name=base_spec["process_name"],
process_params=base_spec.get("process_params", {}),
device=device,
noise_epsilon=base_spec.get("noise_epsilon", 0.0),
)
elif comp_type == "ghmm":
base_process = build_generalized_hidden_markov_model(
process_name=base_spec["process_name"],
process_params=base_spec.get("process_params", {}),
device=device,
noise_epsilon=base_spec.get("noise_epsilon", 0.0),
)
elif comp_type == "factored":
factored_kwargs = {k: v for k, v in base_spec.items() if k not in ("component_type",)}
base_process = build_factored_process_from_spec(**factored_kwargs)
else:
raise ValueError(f"Unknown base component_type: {comp_type}")

return InflatedVocabularyProcess(base_process, inflation_factor)
Loading
Loading