- Background
- Quickstart
- Evaluating Reasoning on Different Benchmarks with Local RetoMaton
- Results
- Acknowledgement
- Citation
Retrieval-augmented inference methods such as kNN-LM (Khandelwal et al., 2019) improve language-model generalization by interpolating the base LM's next-token distribution with one derived from a non-parametric datastore of training examples. At test time, the model retrieves the k nearest neighbors of the current context embedding and converts their distances into a probability distribution over the neighbors' target tokens using a softmax with a temperature. This mechanism improves factual grounding and adaptability, but lookups become more expensive and retrievals noisier as the datastore grows. For more details, see the paper by Khandelwal et al., ICLR 2020.
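The interpolation described above can be sketched in a few lines of Python. This is a minimal illustration, not the repository's code: it assumes the k neighbors have already been retrieved as (distance, target token) pairs, gives each neighbor the weight exp(-distance / temperature), and, as in Khandelwal et al., lets λ (`lmbda`) weight the kNN distribution. The function names are hypothetical.

```python
import math
from collections import defaultdict


def knn_distribution(neighbors, temperature=1.0):
    """Turn retrieved neighbors into a next-token distribution.

    neighbors: list of (distance, target_token) pairs, e.g. from a FAISS search.
    Each neighbor gets weight exp(-distance / temperature); neighbors that share
    a target token pool their probability mass.
    """
    logits = [-d / temperature for d, _ in neighbors]
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    weights = [math.exp(x - m) for x in logits]
    z = sum(weights)
    probs = defaultdict(float)
    for w, (_, token) in zip(weights, neighbors):
        probs[token] += w / z
    return dict(probs)


def interpolate(p_lm, p_knn, lmbda):
    """p(y | x) = lmbda * p_kNN(y | x) + (1 - lmbda) * p_LM(y | x)."""
    vocab = set(p_lm) | set(p_knn)
    return {t: lmbda * p_knn.get(t, 0.0) + (1 - lmbda) * p_lm.get(t, 0.0) for t in vocab}


# Usage: three neighbors, two of which continue the context with "4"
p_knn = knn_distribution([(1.0, "4"), (2.0, "4"), (3.0, "5")], temperature=1.0)
p = interpolate({"4": 0.2, "5": 0.5, "6": 0.3}, p_knn, lmbda=0.25)
```

A higher temperature flattens the kNN distribution toward uniform over the retrieved targets; a lower one concentrates it on the closest neighbor.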
Building on this idea, RetoMaton introduces a lightweight, architecture-agnostic neuro-symbolic extension of kNN-LM. It organizes the datastore as a Weighted Finite Automaton (WFA): each entry points to the entry for the next token in the original text, so transitions follow token-level sentence flow, and entries with semantically similar embeddings are clustered into shared states. This symbolic structure enables context-aware traversal across decoding steps, reducing redundant lookups and enforcing coherent memory access paths. For more details, see the paper by Alon et al., ICML 2022.
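A simplified sketch of RetoMaton's pointer-following step, under stated assumptions: the datastore is a toy in-memory list (not a FAISS index), entry i + 1 is the successor of entry i, a cluster id stands for an automaton state, and `min_knns` is a hypothetical name for the threshold below which a full kNN search is triggered. Converting the surviving candidates into a next-token distribution is omitted.

```python
from dataclasses import dataclass


@dataclass
class ToyDatastore:
    """Hypothetical in-memory datastore. Entry i maps a context embedding (keys[i])
    to the token that followed it (values[i]); entry i + 1 is the entry for the
    next corpus position, i.e. the pointer RetoMaton follows."""
    keys: list      # context embeddings, as tuples of floats
    values: list    # next tokens
    clusters: list  # cluster id of each entry; one cluster = one automaton state

    def members(self, cluster_id):
        return [i for i, c in enumerate(self.clusters) if c == cluster_id]

    def knn(self, query, k):
        # brute-force squared-L2 search, standing in for a FAISS index lookup
        scored = sorted(
            (sum((a - b) ** 2 for a, b in zip(key, query)), i)
            for i, key in enumerate(self.keys)
        )
        return [i for _, i in scored[:k]]


def next_candidates(ds, candidates, generated_token, query, k, min_knns):
    """One simplified RetoMaton step; returns (new_candidates, did_full_search).

    Follow the pointer of every candidate whose value matches the token just
    generated, expand each target entry to its whole cluster (state), and fall
    back to a full kNN search when fewer than min_knns candidates survive.
    """
    followed = set()
    for i in candidates:
        if ds.values[i] == generated_token and i + 1 < len(ds.values):
            followed.update(ds.members(ds.clusters[i + 1]))
    if len(followed) < min_knns:
        return ds.knn(query, k), True
    return sorted(followed), False


ds = ToyDatastore(
    keys=[(float(i),) for i in range(6)],
    values=["b", "c", "a", "b", "d", "e"],
    clusters=[0, 1, 2, 0, 1, 3],
)
step_hit = next_candidates(ds, [0, 3], "b", (4.2,), k=2, min_knns=2)   # pointers survive
step_miss = next_candidates(ds, [0, 3], "x", (4.2,), k=2, min_knns=1)  # falls back to search
```

Whenever enough pointers survive, the step skips the nearest-neighbor search entirely, which is where the savings over plain kNN-LM come from.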
Local RetoMaton extends the original framework by exposing its inherent traceability: each decoding step corresponds to a verifiable transition within a finite-state structure, so reasoning paths can be explicitly reconstructed. By restricting retrieval to locally reachable transitions, it enforces symbolic coherence and bounded memory behavior, which makes the model's reasoning process open to inspection. This structured retrieval not only improves efficiency and generalization but also enhances interpretability and reliability, offering state-level explainability rarely achievable in black-box LLMs. For more details, see the paper by Mamidala et al., NeSy 2025.
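To make the traceability claim concrete, here is a toy illustration, not the Local RetoMaton implementation: a hand-written WFA stored as a dict over hypothetical states, where each decoding step may only take a transition leaving the current state and every step taken is recorded. The recorded path is the reconstructable trace; a token with no local transition ends it.

```python
import math

# Toy WFA (hypothetical): state -> {token: (next_state, weight)}
WFA = {
    "q0": {"2": ("q1", 0.9), "3": ("q2", 0.1)},
    "q1": {"+": ("q3", 1.0)},
    "q3": {"2": ("q4", 0.8), "5": ("q4", 0.2)},
}


def trace(wfa, start, tokens):
    """Replay tokens through the WFA; returns (path, log_weight, completed)."""
    state, path, log_w = start, [], 0.0
    for tok in tokens:
        local = wfa.get(state, {})  # only transitions leaving the current state are reachable
        if tok not in local:
            return path, log_w, False  # not locally reachable: the trace stops here
        nxt, w = local[tok]
        path.append((state, tok, nxt, w))  # each step is an inspectable transition
        log_w += math.log(w)
        state = nxt
    return path, log_w, True


path, log_w, ok = trace(WFA, "q0", ["2", "+", "2"])
broken_path, _, broken_ok = trace(WFA, "q0", ["3", "+"])
```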
```bash
git clone git@github.com:TKAI-LAB-Mali/NeuroSymbolicLM.git
cd NeuroSymbolicLM
```
Then create and activate the Conda environment:
```bash
conda env create -f neurocuda.yml -n neurocuda
conda activate neurocuda
```
Models:
- Llama-2-7b: meta-llama/Llama-2-7b-hf

Datasets:
- GSM8K: openai/gsm8k
- TriviaQA: mandarjoshi/trivia_qa
- MMLU: cais/mmlu
- wiki: wentingzhao/knn-prompt-datastore
- math: wentingzhao/math-textbooks
To save a datastore and build the FAISS index, run:
```bash
MODEL=meta-llama/Llama-3.2-1B
DSNAME=openai/gsm8k
DSCONFIG=main
python -u run_dstore.py --model_name_or_path ${MODEL} --dataset_name ${DSNAME} \
  --dataset_config_name ${DSCONFIG} --do_eval --eval_subset validation \
  --output_dir checkpoints/${MODEL} --dstore_dir checkpoints/${MODEL} --build_index
```
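What saving a datastore means can be sketched as follows, assuming one context vector per token position (which LM layer the repository reads keys from is not specified here). Each key represents a prefix and each value is the token that came next; an exact flat index then returns the stored positions nearest to a query. The pure-Python search only stands in for FAISS and is practical for toy sizes only; the function names are hypothetical.

```python
def build_datastore(hidden_states, token_ids):
    """Pair each context representation h_t with the following token w_{t+1}.

    hidden_states: one vector (tuple of floats) per position
    token_ids:     the token at each of those positions
    """
    if len(hidden_states) != len(token_ids):
        raise ValueError("need exactly one hidden state per token")
    keys = hidden_states[:-1]  # the last position has no next token to store
    values = token_ids[1:]
    return keys, values


def exact_knn(keys, query, k):
    """Exact k-nearest-neighbor search by squared L2 distance, as a flat index computes it."""
    scored = sorted(
        (sum((a - b) ** 2 for a, b in zip(key, query)), i) for i, key in enumerate(keys)
    )
    return scored[:k]  # list of (distance, datastore index)


keys, values = build_datastore([(0.0,), (1.0,), (2.0,)], [5, 6, 7])
nearest = exact_knn(keys, (0.9,), k=1)
```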
To cluster the saved datastore into WFA states, run:
```bash
MODEL=meta-llama/Llama-3.2-1B
DSNAME=openai/gsm8k
DSCONFIG=main
python -u run_dstore.py --model_name_or_path ${MODEL} --dataset_name ${DSNAME} \
  --dataset_config_name ${DSCONFIG} --output_dir checkpoints/${MODEL} \
  --dstore_dir checkpoints/${MODEL}/${DSNAME} --cluster_dstore \
  --dstore_size 1549636 --num_clusters 15000 --sample_size 150000
```
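--dstore_size should match the number of entries in the saved datastore. The optional clustering hyperparameters are --num_clusters (typically 1/100 to 1/200 of the datastore size, hence 15000 for the 1549636-entry datastore above) and --sample_size, the number of keys sampled for clustering (ideally as large as possible, but larger values use more memory and take longer to run).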
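The clustering step can be illustrated with a pure-Python sketch. It assumes, without confirming it against run_dstore.py, that clusters come from k-means fitted on a random sample of --sample_size keys, after which every key is assigned to its nearest centroid and each cluster becomes one automaton state. All function names are hypothetical; the real pipeline works on millions of vectors and would use an optimized library instead.

```python
import random


def suggested_num_clusters(dstore_size, ratio=100):
    """Rule of thumb from this README: about 1/100 to 1/200 of the datastore size."""
    return max(1, dstore_size // ratio)


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest(point, centroids):
    return min(range(len(centroids)), key=lambda c: sq_dist(point, centroids[c]))


def kmeans_on_sample(keys, num_clusters, sample_size, iters=10, seed=0):
    """Fit k-means centroids on a random sample of keys (cheaper than using all keys)."""
    rng = random.Random(seed)
    sample = rng.sample(keys, min(sample_size, len(keys)))
    centroids = rng.sample(sample, num_clusters)
    for _ in range(iters):
        groups = [[] for _ in centroids]
        for x in sample:
            groups[nearest(x, centroids)].append(x)
        # new centroid = mean of its group; keep the old centroid if the group is empty
        centroids = [
            tuple(sum(col) / len(g) for col in zip(*g)) if g else centroids[j]
            for j, g in enumerate(groups)
        ]
    return centroids


def assign_clusters(keys, centroids):
    """Map every key, not just the sampled ones, to its nearest centroid (its state)."""
    return [nearest(key, centroids) for key in keys]


keys = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (10.0, 10.0), (10.0, 11.0), (11.0, 10.0)]
centroids = kmeans_on_sample(keys, num_clusters=2, sample_size=6)
labels = assign_clusters(keys, centroids)
```

Fitting on a sample and then assigning all keys is what makes --sample_size a memory/quality trade-off: the fit cost grows with the sample, while the assignment pass is a single sweep over the datastore.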
Prebuilt datastores and WFA clusters are also available on Hugging Face: https://huggingface.co/datasets/Ritu27/LLamaDatastores
Downloading datastores
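For example, to download only the math datastore, clone the repository without LFS content, then pull the math files:
```bash
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/datasets/Ritu27/LLamaDatastores
cd LLamaDatastores
git lfs install
git lfs pull --include="math/*"
```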
RetoMaton relies on the beam indices that beam search computes internally (how beams are reordered at each step) to traverse its weighted automaton states accurately.
To expose these indices during generation, we modified transformers/generation/utils.py (Transformers v4.48.3, Python 3.10): we added lightweight getter and setter functions immediately after the import statements, and inserted a single hook line in the _beam_search() method after beam_scorer.process() returns the sorted beams.
Add this block after the import section:
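```python
# RetoMaton hook: beam reordering indices from the most recent beam-search step
retoMaton_beam_idx = None


def get_retoMaton_beam_idx():
    return retoMaton_beam_idx


def set_retoMaton_beam_idx(val):
    global retoMaton_beam_idx
    # keep a detached copy so later in-place ops on beam_idx cannot change it
    retoMaton_beam_idx = val.clone().detach()
```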
Then add the following line to the _beam_search() method, right after the sorted beams are returned by beam_scorer.process():
```python
set_retoMaton_beam_idx(beam_idx)
```
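Why the indices matter can be shown with a small sketch. Our assumption, not confirmed by the code shown here, is that the retrieval side keeps one piece of automaton state per beam (for example, the datastore entries that beam's pointers currently cover) and must realign it whenever beam search drops or duplicates beams. `reorder_beam_states` is a hypothetical helper; in the patched library the indices would come from `get_retoMaton_beam_idx()` as a tensor and could be converted with `.tolist()`.

```python
def reorder_beam_states(beam_states, beam_idx):
    """Realign per-beam retrieval state after one beam-search step.

    beam_states: one entry per beam (here: a set of datastore entry ids).
    beam_idx:    for each new beam j, the index of the old beam it extends,
                 e.g. get_retoMaton_beam_idx().tolist() with a batch size of 1.
    """
    # copy each set so that duplicated beams can diverge independently afterwards
    return [set(beam_states[i]) for i in beam_idx]


# Usage: old beam 1 filled both slots and old beam 0 was dropped
states = [{3, 8}, {5}]
new_states = reorder_beam_states(states, [1, 1])
new_states[0].add(9)  # only the first new beam advances its pointers
```

Without this realignment, a beam would keep following pointers that belong to a different hypothesis, which is the kind of mismatch the hook avoids.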
To evaluate Local RetoMaton on the test set, run:
```bash
MODEL=meta-llama/Llama-3.2-1B
DSNAME=openai/gsm8k
DSCONFIG=main
python -u run_pipeline.py --model_name_or_path ${MODEL} --dataset_name ${DSNAME} --prompt \
  --dataset_config_name ${DSCONFIG} --eval_subset test --output_dir results/${MODEL}/${DSNAME} \
  --dstore_dir checkpoints/${MODEL}/${DSNAME} --retomaton
```
Additional test-time tunable hyperparameters include --lmbda1, which controls the interpolation between the datastore and the base language model; --k, specifying the number of retrieved nearest neighbors; and --knn_temp, the softmax temperature used to convert the retrieved distances into a probability distribution.
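A small sweep over these three hyperparameters can be scripted as below. This is a convenience sketch: it uses only the flags shown in this README, the grid values are placeholders rather than recommended settings, and the per-run output subfolder naming is our own assumption.

```python
import itertools
import shlex

MODEL = "meta-llama/Llama-3.2-1B"
DSNAME = "openai/gsm8k"
DSCONFIG = "main"


def build_commands(lmbdas, ks, temps):
    """Return one run_pipeline.py argv per (lmbda1, k, knn_temp) combination."""
    cmds = []
    for lmbda, k, temp in itertools.product(lmbdas, ks, temps):
        tag = f"lmbda{lmbda}_k{k}_temp{temp}"  # hypothetical per-run subfolder name
        cmds.append([
            "python", "-u", "run_pipeline.py",
            "--model_name_or_path", MODEL,
            "--dataset_name", DSNAME,
            "--prompt",
            "--dataset_config_name", DSCONFIG,
            "--eval_subset", "test",
            "--output_dir", f"results/{MODEL}/{DSNAME}/{tag}",
            "--dstore_dir", f"checkpoints/{MODEL}/{DSNAME}",
            "--retomaton",
            "--lmbda1", str(lmbda),
            "--k", str(k),
            "--knn_temp", str(temp),
        ])
    return cmds


if __name__ == "__main__":
    for cmd in build_commands([0.25], [16], [1.0]):
        # print the shell command; use subprocess.run(cmd, check=True) to launch it
        print(shlex.join(cmd))
```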
- kNN-LM implementation: our kNN-LM implementation builds on the code released with Neuro-Symbolic Language Modeling with Automaton-augmented Retrieval (Alon et al., ICML 2022).
If you find our work helpful, please cite:
```bibtex
@inproceedings{mamidala2025rethinking,
  title={Rethinking Reasoning in {LLM}s: Neuro-Symbolic Local RetoMaton Beyond CoT and {ICL}},
  author={Rushitha Santhoshi Mamidala and Anshuman Chhabra and Ankur Mali},
  booktitle={19th International Conference on Neurosymbolic Learning and Reasoning},
  year={2025},
  url={https://openreview.net/forum?id=ySTqCi3nqi}
}
```


