Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
dda3538
Add SNN work and ignore local notebook/test artifacts
bmdillon Apr 27, 2026
65c817a
jupyter notebook demo of snn functionality
bmdillon Apr 27, 2026
f779515
updates to readme
bmdillon Apr 27, 2026
4c6b30a
Update README.md
bmdillon Apr 27, 2026
4151231
add learnable beta and threhold functionality
bmdillon May 4, 2026
9afeb79
update docs
bmdillon May 4, 2026
2385b90
readout and reset updated
bmdillon May 5, 2026
f10193f
update branch for pr
bmdillon May 5, 2026
eeae0de
docs update
bmdillon May 5, 2026
84dcf54
added comments
bmdillon May 5, 2026
e80ea92
notebook update
bmdillon May 5, 2026
2490bd8
pr bug fix
bmdillon May 8, 2026
e6a47ed
[pre-commit.ci] auto fixes from pre-commit hooks
pre-commit-ci[bot] May 8, 2026
7e8e1bc
remove notebook
bmdillon May 17, 2026
10a52bb
refactor readout
bmdillon May 17, 2026
21c9c93
update to SNNReadout attributes
bmdillon May 18, 2026
6aaf6a0
Added backend support to docs
bmdillon May 18, 2026
18def3c
updated vivado->vitis in tests
bmdillon May 18, 2026
fbbc6f8
update docs on RF support for spiking neurons
bmdillon May 18, 2026
eee087d
snn streaming moved to nnet_snn_stream.h
bmdillon May 18, 2026
55dcbbc
updated snn readout layer defaults
bmdillon May 18, 2026
a493457
moved gettattr
bmdillon May 18, 2026
3640948
snn window parsing moved to optimizer pass
bmdillon May 18, 2026
d1d21fd
reverted changes
bmdillon May 18, 2026
542d6db
snnreadout moved from utils to contrib
bmdillon May 18, 2026
f69c429
updated mentions of tutorial notebook
bmdillon May 18, 2026
17a4523
[pre-commit.ci] auto fixes from pre-commit hooks
pre-commit-ci[bot] May 18, 2026
98020e4
add __contains__ to handle scalar spiking neuron attributes (threshol…
bmdillon May 19, 2026
4d98241
test updates for snn
bmdillon May 19, 2026
c4683c0
updated snnreadout defaults
bmdillon May 19, 2026
879ceaa
update config conversion
bmdillon May 19, 2026
1f6f498
Merge branch 'main' into hls4ml-pr
bmdillon May 21, 2026
49e0e31
layer attribute updates for keras and onnx
bmdillon May 26, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ docs/_build
docs/autodoc/*
hls4mlprj_*
*~
*.ipynb
*.ipynb_checkpoints/
*.bak
201 changes: 201 additions & 0 deletions docs/advanced/snn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
========================================
Spiking Neural Networks (PyTorch/SNN)
========================================

This page describes the initial SNN support in the PyTorch frontend.
Comment thread
bmdillon marked this conversation as resolved.

Install the SNN frontend dependencies with:

.. code-block:: bash

pip install hls4ml[snn]

Backend support
===============

The SNN flow currently supports only the ``Vitis`` backend.

Execution model
===============

Current hls4ml SNN implementations are synchronous (clock-driven). Neuron state
updates and layer computations run in standard HLS pipelines/streams each cycle
according to interface handshakes. The generated design is not a native
asynchronous/event-routed neuromorphic architecture (yet!).
Comment thread
bmdillon marked this conversation as resolved.

Reuse factor support
====================

Standard hls4ml layers used inside an SNN, such as ``Dense``/linear layers,
retain their normal ``ReuseFactor`` support. ``ReuseFactor`` can still be set at
the model, layer type, or layer name level for these layers, and each dense layer
uses its own configured value independently of the surrounding spiking neuron
layers. The spiking neuron kernels themselves, ``IFNeuron`` and ``LIFNeuron``, do not
currently expose ``ReuseFactor``. They process one timestep at a time, keep
internal membrane state across timesteps, and unroll the per-neuron update loop
across ``n_out`` channels.

Supported PyTorch modules and readout wrappers
==============================================

The frontend currently supports direct parsing of:

* ``Leaky`` -> ``LIFNeuron`` (or ``IFNeuron`` when ``beta`` is effectively 1)

``SNNReadout`` is an hls4ml layer, not a ``snntorch`` module. To use the
built-in hls4ml readout from a PyTorch model, instantiate the provided PyTorch
marker module:

.. code-block:: python

from hls4ml.contrib.snntorch import SNNReadout

The marker is an identity in PyTorch and is converted to the hls4ml
``SNNReadout`` layer by the PyTorch frontend.

`snntorch` tracing
==================

``snntorch`` modules are treated as leaf modules by the hls4ml PyTorch FX tracer.
This allows conversion models to use ``snntorch.Leaky`` directly without defining
conversion-only wrapper classes.

For ``Leaky``, the supported reset mechanisms are:

* ``subtract``
* ``zero``

``threshold`` supports scalar or per-neuron vectors (length ``n_out``) for both ``IFNeuron`` and ``LIFNeuron``.
``beta`` supports scalar or per-neuron vectors for ``LIFNeuron``.

Conversion selects the most memory-efficient representation automatically:

* scalar values are emitted as compile-time constants
* per-neuron values are emitted as parameter vectors

For trainable snntorch parameters, conversion uses the current parameter values from the model
at conversion time.

Readout and Decision Rules
==========================

The hls4ml ``SNNReadout`` layer implements programmable per-model decision policies.
By default, ``output_mode="spike"`` preserves the original spike-count behavior:

* ``argmax_spike_count``
* ``first_to_threshold``
* ``threshold_then_argmax``
* ``binary_logit`` (for binary classifiers with ``n_classes == 2``)

The layer accumulates class spikes over a window. For most decision rules it emits
a class ID. For ``binary_logit``, it emits a score equal to
``count(class_1) - count(class_0)``.

For non-spiking readout heads, set ``output_mode="membrane"`` and connect
``SNNReadout`` directly after the final dense/linear layer instead of after a
final spiking neuron. In this mode the readout owns the final membrane state:

.. code-block:: python

x = self.fc2(x)
return self.readout(x)

At each timestep, the generated readout computes:

.. code-block:: cpp

mem[i] = beta * mem[i] + input[i];

No threshold or reset-on-spike is applied in membrane mode. The supported
membrane decision policies are:

* ``argmax_membrane``
* ``binary_logit`` (emits ``mem(class_1) - mem(class_0)`` for binary classifiers)

This will be explained in a tutorial in the hls4ml-tutorials repo.

Do not place a final spiking neuron before ``SNNReadout(output_mode="membrane")``
unless you intentionally want the readout to consume that neuron's spike output.
The membrane mode does not recover or expose the internal membrane state of a
preceding ``Leaky``/``IFNeuron``/``LIFNeuron`` layer. If a final output neuron
has a learnable ``beta``, that learnable neuron membrane is not the same state
as the readout-owned membrane. The readout uses its own scalar ``beta``.

When using the default PyTorch parser, the wrapper module should expose these
attributes as needed:

* ``n_classes`` (defaults to the input feature count if omitted)
* ``window_size`` or ``stream_length`` (defaults to ``1``)
* ``class_threshold`` (defaults to ``1``)
* ``output_mode`` (defaults to ``spike``; use ``membrane`` for readout-owned membrane accumulation)
* ``beta`` (defaults to ``1.0`` for membrane readout)
* ``decision_rule`` (defaults to ``argmax_spike_count``)
* ``reset_policy`` or ``state_reset_policy`` (defaults to ``fixed_window``)

Window Boundary Semantics
=========================

The current implementation uses ``window_size`` timesteps as the sequence boundary
for generated HLS. During PyTorch conversion, the first fixed-window
``SNNReadout``'s ``window_size`` is propagated to all converted ``IFNeuron`` and
``LIFNeuron`` layers in the graph.

At each boundary:

* the class decision is emitted
* internal readout counters or readout membrane state are reset for the next sequence
* internal ``IFNeuron``/``LIFNeuron`` membrane state is reset for the next sequence

The reset happens after the final timestep has been processed and has contributed
to the output. This behavior is compatible with fixed-length time windows.

Only fixed-window reset is implemented in generated layer kernels today.
``state_reset_policy`` accepts future-facing values such as ``tlast``,
``host_pulse``, and ``never``, but the current layer kernels still use fixed
``window_size`` reset behavior.

Running ``hls_model.predict()``
==============================

Check warning on line 158 in docs/advanced/snn.rst

View workflow job for this annotation

GitHub Actions / build

Title underline too short.

Check warning on line 158 in docs/advanced/snn.rst

View workflow job for this annotation

GitHub Actions / build

Title underline too short.

Compiled SNN models are stateful across top-function calls. For fixed-window
SNN inference, call the compiled model once per timestep and pass exactly
``window_size`` timesteps for each independent sequence:

.. code-block:: python

last = None
for step in range(timesteps):
x_step = x_sequence[step].astype("float32")[None, :]
last = hls_model.predict(x_step)

After the last call in the window, generated HLS resets the neuron and readout
state for the next sequence. Avoid making stray single-timestep ``predict``
calls before evaluating a sequence, because those calls advance the state.

For membrane readout, the PyTorch reference should match the generated readout
accumulation:

.. code-block:: python

mem = torch.zeros_like(currents[:, 0, :])
for step in range(currents.shape[1]):
mem = beta * mem + currents[:, step, :]
pred = mem.argmax(dim=1)

Using only the final dense current, or using spike-count reduction for a
membrane readout, does not match generated HLS behavior.

Precision note
==============

Membrane readout accumulates dense currents over the full window, so very narrow
fixed-point types can reduce accuracy even when the floating-point PyTorch model
looks good.

``TLAST`` note
==============

True AXI sideband ``TLAST`` boundary handling requires top-level writer/interface support for packetized AXI stream types.
The current implementation does not yet expose ``TLAST`` to layer kernels directly.

For variable-length windows, a practical workaround is to keep the hls4ml core unchanged and perform ``TLAST`` to boundary conversion in a thin wrapper IP around the generated project.
3 changes: 3 additions & 0 deletions docs/frontend/pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ of the model. If the ``io_parallel`` I/O type (see :ref:`Concepts`) is used, a t
Outputs are not transposed back by default, but in ``io_parallel`` case, a transpose node can be added. If not needed, these adjustments can also be switched off. See :py:class:`~hls4ml.utils.config.config_from_pytorch_model` for details.

The equivalent of Keras extension API is not yet available for PyTorch parser, and will be provided in the future.

.. note::
Experimental spiking layer support is available for selected modules. See :doc:`../advanced/snn` for details.
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
advanced/precision
advanced/fifo_depth
advanced/extension
advanced/snn
advanced/model_optimization
advanced/bramfactor
advanced/plugins
Expand Down
128 changes: 128 additions & 0 deletions hls4ml/backends/vivado/passes/snn_templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import IFNeuron, LIFNeuron, SNNReadout

if_config_template = """struct config{index} : nnet::if_neuron_config {{
static const unsigned n_in = {n_in};
static const unsigned n_out = {n_out};
static const unsigned io_type = nnet::{iotype};
static const unsigned window_size = {window_size};
static const bool threshold_is_vector = {threshold_is_vector};
static constexpr float threshold = {threshold};
static const nnet::snn_reset_mode reset_mode = nnet::snn_reset_mode::{reset_mechanism};
typedef {threshold_t.name} threshold_t;
typedef {membrane_t.name} membrane_t;
}};\n"""

if_function_template = 'nnet::if_neuron<{input_t}, {output_t}, {config}>({input}, {output}, {threshold});'
snn_include_list = ['nnet_utils/nnet_snn.h', 'nnet_utils/nnet_snn_stream.h']


class IFNeuronConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__(IFNeuron)
self.template = if_config_template

def format(self, node):
params = self._default_config_params(node)
params['threshold_is_vector'] = 'true' if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'false'
return self.template.format(**params)


class IFNeuronFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(IFNeuron, include_header=snn_include_list)
self.template = if_function_template

def format(self, node):
params = self._default_function_params(node)
params['threshold'] = (
node.get_weights('threshold_vec').name if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'nullptr'
)
return self.template.format(**params)


lif_config_template = """struct config{index} : nnet::lif_neuron_config {{
static const unsigned n_in = {n_in};
static const unsigned n_out = {n_out};
static const unsigned io_type = nnet::{iotype};
static const unsigned window_size = {window_size};
static const bool beta_is_vector = {beta_is_vector};
static const bool threshold_is_vector = {threshold_is_vector};
static constexpr float threshold = {threshold};
static constexpr float beta = {beta};
static const nnet::snn_reset_mode reset_mode = nnet::snn_reset_mode::{reset_mechanism};
typedef {beta_t.name} beta_t;
typedef {threshold_t.name} threshold_t;
typedef {membrane_t.name} membrane_t;
}};\n"""

lif_function_template = 'nnet::lif_neuron<{input_t}, {output_t}, {config}>({input}, {output}, {beta}, {threshold});'


class LIFNeuronConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__(LIFNeuron)
self.template = lif_config_template

def format(self, node):
params = self._default_config_params(node)
params['beta_is_vector'] = 'true' if node.get_attr('beta_mode', 'scalar') == 'vector' else 'false'
params['threshold_is_vector'] = 'true' if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'false'
return self.template.format(**params)


class LIFNeuronFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(LIFNeuron, include_header=snn_include_list)
self.template = lif_function_template

def format(self, node):
params = self._default_function_params(node)
params['beta'] = node.get_weights('beta_vec').name if node.get_attr('beta_mode', 'scalar') == 'vector' else 'nullptr'
params['threshold'] = (
node.get_weights('threshold_vec').name if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'nullptr'
)
return self.template.format(**params)


readout_config_template = """struct config{index} : nnet::snn_readout_config {{
static const unsigned n_classes = {n_classes};
static const unsigned io_type = nnet::{iotype};
static const unsigned window_size = {window_size};
static const unsigned class_threshold = {class_threshold};
static constexpr float beta = {beta};
static const nnet::snn_readout_mode output_mode = nnet::snn_readout_mode::{output_mode};
static const nnet::snn_decision_rule decision_rule = nnet::snn_decision_rule::{decision_rule};
typedef {membrane_t.name} membrane_t;
}};\n"""

readout_function_template = 'nnet::snn_readout<{input_t}, {output_t}, {config}>({input}, {output});'


class SNNReadoutConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__(SNNReadout)
self.template = readout_config_template

def format(self, node):
params = self._default_config_params(node)
return self.template.format(**params)


class SNNReadoutFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(SNNReadout, include_header=snn_include_list)
self.template = readout_function_template

def format(self, node):
params = self._default_function_params(node)
return self.template.format(**params)


def register_snn_templates(backend):
backend.register_template(IFNeuronConfigTemplate)
backend.register_template(IFNeuronFunctionTemplate)
backend.register_template(LIFNeuronConfigTemplate)
backend.register_template(LIFNeuronFunctionTemplate)
backend.register_template(SNNReadoutConfigTemplate)
backend.register_template(SNNReadoutFunctionTemplate)
Loading
Loading