# mcts-diffusion

Monte Carlo Tree Search for Masked Diffusion Language Models.

Open-source implementation of the UnMaskFork algorithm from arXiv:2602.04344.

UnMaskFork applies MCTS to masked diffusion language models (LLaDA, Dream) for test-time scaling — instead of generating a single sample, it explores a tree of unmasking trajectories using deterministic action branching, achieving state-of-the-art results on code generation benchmarks.
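The intuition behind branching over unmasking trajectories can be shown with a small self-contained toy (an exhaustive enumeration standing in for the paper's MCTS, with a fake "model" whose predictions improve once neighbouring positions are revealed). All names here are illustrative, not the library's API:

```python
import random
from itertools import permutations

MASK, TARGET = "?", "abc"

def predict(state, pos, rng):
    # Toy "model": predicts correctly when an adjacent position is already
    # revealed, otherwise guesses. Real models condition on all revealed tokens.
    neighbors = [i for i in (pos - 1, pos + 1) if 0 <= i < len(state)]
    if any(state[i] != MASK for i in neighbors):
        return TARGET[pos]
    return rng.choice("abc")

def reward(state):
    # Fraction of positions matching the target.
    return sum(a == b for a, b in zip(state, TARGET)) / len(TARGET)

def search_orders(seed=0):
    # Branch over every unmasking order (a tiny trajectory tree) and keep the
    # best terminal state by reward, instead of sampling a single trajectory.
    rng = random.Random(seed)
    best, best_r = None, -1.0
    for order in permutations(range(len(TARGET))):
        state = MASK * len(TARGET)
        for pos in order:
            state = state[:pos] + predict(state, pos, rng) + state[pos + 1:]
        r = reward(state)
        if r > best_r:
            best, best_r = state, r
    return best, best_r
```

In the actual algorithm, UCT decides which branches to expand under an NFE budget rather than enumerating every order.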

## Installation

```bash
pip install mcts-diffusion

# With model support (quote the extras so shells like zsh don't glob the brackets)
pip install "mcts-diffusion[dream]"   # Dream / Dream-Coder
pip install "mcts-diffusion[llada]"   # LLaDA
pip install "mcts-diffusion[all]"     # All models
```

## Quick Start

```python
from mcts_diffusion import search, MCTSConfig
from mcts_diffusion.models.dream import DreamAdapter
from mcts_diffusion.reward.code_execution import CodeExecutionReward

model = DreamAdapter("Dream-org/Dream-Coder-v0-Instruct-7B")
reward = CodeExecutionReward(test_cases=["assert add(1, 2) == 3"])

result = search(
    models=model,
    prompt="def add(a, b):",
    reward_fn=reward,
    config=MCTSConfig(nfe_budget=1024),
)
print(result.best_text)
print(f"Reward: {result.reward}, NFE used: {result.nfe_used}")
```

## Multi-Model Search

Use multiple models in the same search tree (heterogeneous search):

```python
from mcts_diffusion import search, MCTSConfig, ActionConfig
from mcts_diffusion.models.dream import DreamAdapter
from mcts_diffusion.models.llada import LLaDAAdapter

models = {
    "dream": DreamAdapter("Dream-org/Dream-Coder-v0-Instruct-7B"),
    "llada": LLaDAAdapter("GSAI-ML/LLaDA-8B-Instruct"),
}

actions = [
    ActionConfig(name="dream_entropy", model_name="dream", temperature=0.1, strategy="entropy"),
    ActionConfig(name="llada_lowconf", model_name="llada", temperature=0.0, strategy="low_confidence"),
]

result = search(models=models, prompt="def solve():", reward_fn=reward, actions=actions)
```

## Custom Reward Functions

Implement the reward protocol (any object with an `evaluate` method) for any evaluation:

```python
class MyReward:
    def evaluate(self, generated_text: str) -> float:
        # Return score in [0, 1]
        return 1.0 if "correct" in generated_text else 0.0
```
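A more realistic reward in the same protocol runs generated code against test cases and returns the fraction that pass. This is a sketch in the spirit of a code-execution reward, not the library's `CodeExecutionReward` implementation (which would sandbox execution):

```python
class AssertReward:
    """Sketch of a code-execution reward: fraction of test asserts that pass.

    Illustrative only; a production version must sandbox `exec` and enforce
    timeouts before running untrusted generated code.
    """

    def __init__(self, test_cases):
        self.test_cases = test_cases

    def evaluate(self, generated_text: str) -> float:
        namespace = {}
        try:
            exec(generated_text, namespace)  # define functions from the generation
        except Exception:
            return 0.0                       # code that doesn't run scores zero
        passed = 0
        for case in self.test_cases:
            try:
                exec(case, namespace)        # each test case is an assert statement
                passed += 1
            except Exception:
                pass
        return passed / len(self.test_cases)
```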

## Custom Model Adapters

Wrap any masked diffusion model:

```python
class MyModelAdapter:
    @property
    def mask_token_id(self) -> int:
        return 99999

    @property
    def eos_token_id(self) -> int:
        return 0

    def forward_logits(self, input_ids):
        # Single forward pass -> (1, seq_len, vocab_size)
        return self.model(input_ids).logits

    def encode(self, text: str):
        return self.tokenizer(text, return_tensors="pt").input_ids

    def decode(self, token_ids):
        return self.tokenizer.decode(token_ids[0], skip_special_tokens=True)
```
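Since the adapter interface is structural (any object with these members works), one way to check conformance before a long run is a `typing.Protocol`. The protocol name and the toy adapter below are hypothetical, defined here only to demonstrate the shape of the interface:

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class DiffusionModel(Protocol):
    # Hypothetical structural type mirroring the adapter members above;
    # the library may define its own protocol under a different name.
    mask_token_id: int
    eos_token_id: int

    def forward_logits(self, input_ids): ...
    def encode(self, text: str): ...
    def decode(self, token_ids): ...

class ToyAdapter:
    """Minimal stand-in adapter over a 4-token toy vocabulary."""
    vocab = ["<eos>", "a", "b", "<mask>"]

    @property
    def mask_token_id(self) -> int:
        return 3

    @property
    def eos_token_id(self) -> int:
        return 0

    def forward_logits(self, input_ids):
        # Uniform toy "logits" shaped (1, seq_len, vocab_size).
        return [[[0.0] * len(self.vocab) for _ in ids] for ids in input_ids]

    def encode(self, text: str):
        return [[self.vocab.index(t) for t in text.split()]]

    def decode(self, token_ids):
        return " ".join(self.vocab[i] for i in token_ids[0] if i != 0)
```

`isinstance(adapter, DiffusionModel)` then checks member presence at runtime (not signatures), which catches a mis-named method early.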

## Configuration

```python
MCTSConfig(
    nfe_budget=12288,          # Total forward-evaluation (NFE) budget
    c_exp=1.0,                 # UCT exploration coefficient
    mask_ratio_schedule=[      # Mask ratios defining tree depth
        0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.2
    ],
    gen_length=768,            # Generation region length in tokens
    cache_max_mb=None,         # Max cache size in MB (None = unlimited)
    device="cuda",             # Device for tensors
)
```

## Benchmarks

```bash
# HumanEval+
python benchmarks/run_humaneval.py --model Dream-org/Dream-Coder-v0-Instruct-7B --nfe-budget 12288

# MBPP+
python benchmarks/run_mbpp.py --model Dream-org/Dream-Coder-v0-Instruct-7B --nfe-budget 12288

# LiveCodeBench
python benchmarks/run_livecodebench.py --model Dream-org/Dream-Coder-v0-Instruct-7B --nfe-budget 12288
```

## Architecture

```
search()
  ├── UCTSelector.select()     # Walk tree via UCT scores
  ├── Expander.expand()        # Rollout + evaluate
  │   ├── Unmasker             # Forward pass + strategy-based position selection
  │   ├── RemaskStrategy       # LowConfidence / Entropy / Random
  │   └── SearchCache          # Two-level rollout + score cache
  ├── backup()                 # Backpropagate reward to root
  └── SelectBest()             # Return highest-reward terminal
```
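The selection step typically scores each child with the textbook UCT formula, trading off average reward against visit count. This is a generic illustration using the `c_exp` coefficient from the config; the library's `UCTSelector` may differ in detail:

```python
import math

def uct_score(total_reward, visits, parent_visits, c_exp=1.0):
    """Textbook UCT: exploit average reward, explore rarely visited children."""
    if visits == 0:
        return float("inf")  # unvisited children are always tried first
    exploit = total_reward / visits
    explore = c_exp * math.sqrt(math.log(parent_visits) / visits)
    return exploit + explore
```

During `select()`, the walk repeatedly descends to the child with the highest score; a larger `c_exp` biases the search toward less-visited branches.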

## Citation

```bibtex
@article{unmaskfork2026,
  title={UnMaskFork: Test-Time Scaling for Masked Diffusion via Deterministic Action Branching},
  author={Thomas, Rahul and Kitanovski, Teo and Goldblum, Micah and Pal, Arka},
  journal={arXiv preprint arXiv:2602.04344},
  year={2026}
}
```

## License

Apache-2.0
