-
Notifications
You must be signed in to change notification settings - Fork 554
Basic SNN functionality in hls4ml (mapped from snntorch) #1470
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
dda3538
Add SNN work and ignore local notebook/test artifacts
bmdillon 65c817a
jupyter notebook demo of snn functionality
bmdillon f779515
updates to readme
bmdillon 4c6b30a
Update README.md
bmdillon 4151231
add learnable beta and threhold functionality
bmdillon 9afeb79
update docs
bmdillon 2385b90
readout and reset updated
bmdillon f10193f
update branch for pr
bmdillon eeae0de
docs update
bmdillon 84dcf54
added comments
bmdillon e80ea92
notebook update
bmdillon 2490bd8
pr bug fix
bmdillon e6a47ed
[pre-commit.ci] auto fixes from pre-commit hooks
pre-commit-ci[bot] 7e8e1bc
remove notebook
bmdillon 10a52bb
refactor readout
bmdillon 21c9c93
update to SNNReadout attributes
bmdillon 6aaf6a0
Added backend support to docs
bmdillon 18def3c
updated vivado->vitis in tests
bmdillon fbbc6f8
update docs on RF support for spiking neurons
bmdillon eee087d
snn streaming moved to nnet_snn_stream.h
bmdillon 55dcbbc
updated snn readout layer defaults
bmdillon a493457
moved gettattr
bmdillon 3640948
snn window parsing moved to optimizer pass
bmdillon d1d21fd
reverted changes
bmdillon 542d6db
snnreadout moved from utils to contrib
bmdillon f69c429
updated mentions of tutorial notebook
bmdillon 17a4523
[pre-commit.ci] auto fixes from pre-commit hooks
pre-commit-ci[bot] 98020e4
add __contains__ to handle scalar spiking neuron attributes (threshol…
bmdillon 4d98241
test updates for snn
bmdillon c4683c0
updated snnreadout defaults
bmdillon 879ceaa
update config conversion
bmdillon 1f6f498
Merge branch 'main' into hls4ml-pr
bmdillon 49e0e31
layer attribute updates for keras and onnx
bmdillon File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,5 +13,6 @@ docs/_build | |
| docs/autodoc/* | ||
| hls4mlprj_* | ||
| *~ | ||
| *.ipynb | ||
| *.ipynb_checkpoints/ | ||
| *.bak | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,201 @@ | ||
| ======================================== | ||
| Spiking Neural Networks (PyTorch/SNN) | ||
| ======================================== | ||
|
|
||
| This page describes the initial SNN support in the PyTorch frontend. | ||
|
|
||
| Install the SNN frontend dependencies with: | ||
|
|
||
| .. code-block:: bash | ||
|
|
||
| pip install hls4ml[snn] | ||
|
|
||
| Backend support | ||
| =============== | ||
|
|
||
| The SNN flow currently supports only the ``Vitis`` backend. | ||
|
|
||
| Execution model | ||
| =============== | ||
|
|
||
| Current hls4ml SNN implementations are synchronous (clock-driven). Neuron state | ||
| updates and layer computations run in standard HLS pipelines/streams each cycle | ||
| according to interface handshakes. The generated design is not a native | ||
| asynchronous/event-routed neuromorphic architecture (yet!). | ||
|
bmdillon marked this conversation as resolved.
|
||
|
|
||
| Reuse factor support | ||
| ==================== | ||
|
|
||
| Standard hls4ml layers used inside an SNN, such as ``Dense``/linear layers, | ||
| retain their normal ``ReuseFactor`` support. ``ReuseFactor`` can still be set at | ||
| the model, layer type, or layer name level for these layers, and each dense layer | ||
| uses its own configured value independently of the surrounding spiking neuron | ||
| layers. The spiking neuron kernels themselves, ``IFNeuron`` and ``LIFNeuron``, do not | ||
| currently expose ``ReuseFactor``. They process one timestep at a time, keep | ||
| internal membrane state across timesteps, and unroll the per-neuron update loop | ||
| across ``n_out`` channels. | ||
|
|
||
| Supported PyTorch modules and readout wrappers | ||
| ============================================== | ||
|
|
||
| The frontend currently supports direct parsing of: | ||
|
|
||
| * ``Leaky`` -> ``LIFNeuron`` (or ``IFNeuron`` when ``beta`` is effectively 1) | ||
|
|
||
| ``SNNReadout`` is an hls4ml layer, not a ``snntorch`` module. To use the | ||
| built-in hls4ml readout from a PyTorch model, instantiate the provided PyTorch | ||
| marker module: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| from hls4ml.contrib.snntorch import SNNReadout | ||
|
|
||
| The marker is an identity in PyTorch and is converted to the hls4ml | ||
| ``SNNReadout`` layer by the PyTorch frontend. | ||
|
|
||
| `snntorch` tracing | ||
| ================== | ||
|
|
||
| ``snntorch`` modules are treated as leaf modules by the hls4ml PyTorch FX tracer. | ||
| This allows conversion models to use ``snntorch.Leaky`` directly without defining | ||
| conversion-only wrapper classes. | ||
|
|
||
| For ``Leaky``, the supported reset mechanisms are: | ||
|
|
||
| * ``subtract`` | ||
| * ``zero`` | ||
|
|
||
| ``threshold`` supports scalar or per-neuron vectors (length ``n_out``) for both ``IFNeuron`` and ``LIFNeuron``. | ||
| ``beta`` supports scalar or per-neuron vectors for ``LIFNeuron``. | ||
|
|
||
| Conversion selects the most memory-efficient representation automatically: | ||
|
|
||
| * scalar values are emitted as compile-time constants | ||
| * per-neuron values are emitted as parameter vectors | ||
|
|
||
| For trainable snntorch parameters, conversion uses the current parameter values from the model | ||
| at conversion time. | ||
|
|
||
| Readout and Decision Rules | ||
| ========================== | ||
|
|
||
| The hls4ml ``SNNReadout`` layer implements programmable per-model decision policies. | ||
| By default, ``output_mode="spike"`` preserves the original spike-count behavior: | ||
|
|
||
| * ``argmax_spike_count`` | ||
| * ``first_to_threshold`` | ||
| * ``threshold_then_argmax`` | ||
| * ``binary_logit`` (for binary classifiers with ``n_classes == 2``) | ||
|
|
||
| The layer accumulates class spikes over a window. For most decision rules it emits | ||
| a class ID. For ``binary_logit``, it emits a score equal to | ||
| ``count(class_1) - count(class_0)``. | ||
|
|
||
| For non-spiking readout heads, set ``output_mode="membrane"`` and connect | ||
| ``SNNReadout`` directly after the final dense/linear layer instead of after a | ||
| final spiking neuron. In this mode the readout owns the final membrane state: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| x = self.fc2(x) | ||
| return self.readout(x) | ||
|
|
||
| At each timestep, the generated readout computes: | ||
|
|
||
| .. code-block:: cpp | ||
|
|
||
| mem[i] = beta * mem[i] + input[i]; | ||
|
|
||
| No threshold or reset-on-spike is applied in membrane mode. The supported | ||
| membrane decision policies are: | ||
|
|
||
| * ``argmax_membrane`` | ||
| * ``binary_logit`` (emits ``mem(class_1) - mem(class_0)`` for binary classifiers) | ||
|
|
||
| This will be explained in a tutorial in the hls4ml-tutorials repo. | ||
|
|
||
| Do not place a final spiking neuron before ``SNNReadout(output_mode="membrane")`` | ||
| unless you intentionally want the readout to consume that neuron's spike output. | ||
| The membrane mode does not recover or expose the internal membrane state of a | ||
| preceding ``Leaky``/``IFNeuron``/``LIFNeuron`` layer. If a final output neuron | ||
| has a learnable ``beta``, that learnable neuron membrane is not the same state | ||
| as the readout-owned membrane. The readout uses its own scalar ``beta``. | ||
|
|
||
| When using the default PyTorch parser, the wrapper module should expose these | ||
| attributes as needed: | ||
|
|
||
| * ``n_classes`` (defaults to the input feature count if omitted) | ||
| * ``window_size`` or ``stream_length`` (defaults to ``1``) | ||
| * ``class_threshold`` (defaults to ``1``) | ||
| * ``output_mode`` (defaults to ``spike``; use ``membrane`` for readout-owned membrane accumulation) | ||
| * ``beta`` (defaults to ``1.0`` for membrane readout) | ||
| * ``decision_rule`` (defaults to ``argmax_spike_count``) | ||
| * ``reset_policy`` or ``state_reset_policy`` (defaults to ``fixed_window``) | ||
|
|
||
| Window Boundary Semantics | ||
| ========================= | ||
|
|
||
| The current implementation uses ``window_size`` timesteps as the sequence boundary | ||
| for generated HLS. During PyTorch conversion, the first fixed-window | ||
| ``SNNReadout``'s ``window_size`` is propagated to all converted ``IFNeuron`` and | ||
| ``LIFNeuron`` layers in the graph. | ||
|
|
||
| At each boundary: | ||
|
|
||
| * the class decision is emitted | ||
| * internal readout counters or readout membrane state are reset for the next sequence | ||
| * internal ``IFNeuron``/``LIFNeuron`` membrane state is reset for the next sequence | ||
|
|
||
| The reset happens after the final timestep has been processed and has contributed | ||
| to the output. This behavior is compatible with fixed-length time windows. | ||
|
|
||
| Only fixed-window reset is implemented in generated layer kernels today. | ||
| ``state_reset_policy`` accepts future-facing values such as ``tlast``, | ||
| ``host_pulse``, and ``never``, but the current layer kernels still use fixed | ||
| ``window_size`` reset behavior. | ||
|
|
||
| Running ``hls_model.predict()`` | ||
| ============================== | ||
|
Check warning on line 158 in docs/advanced/snn.rst
|
||
|
|
||
| Compiled SNN models are stateful across top-function calls. For fixed-window | ||
| SNN inference, call the compiled model once per timestep and pass exactly | ||
| ``window_size`` timesteps for each independent sequence: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| last = None | ||
| for step in range(timesteps): | ||
| x_step = x_sequence[step].astype("float32")[None, :] | ||
| last = hls_model.predict(x_step) | ||
|
|
||
| After the last call in the window, generated HLS resets the neuron and readout | ||
| state for the next sequence. Avoid making stray single-timestep ``predict`` | ||
| calls before evaluating a sequence, because those calls advance the state. | ||
|
|
||
| For membrane readout, the PyTorch reference should match the generated readout | ||
| accumulation: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| mem = torch.zeros_like(currents[:, 0, :]) | ||
| for step in range(currents.shape[1]): | ||
| mem = beta * mem + currents[:, step, :] | ||
| pred = mem.argmax(dim=1) | ||
|
|
||
| Using only the final dense current, or using spike-count reduction for a | ||
| membrane readout, does not match generated HLS behavior. | ||
|
|
||
| Precision note | ||
| ============== | ||
|
|
||
| Membrane readout accumulates dense currents over the full window, so very narrow | ||
| fixed-point types can reduce accuracy even when the floating-point PyTorch model | ||
| looks good. | ||
|
|
||
| ``TLAST`` note | ||
| ============== | ||
|
|
||
| True AXI sideband ``TLAST`` boundary handling requires top-level writer/interface support for packetized AXI stream types. | ||
| The current implementation does not yet expose ``TLAST`` to layer kernels directly. | ||
|
|
||
| For variable-length windows, a practical workaround is to keep the hls4ml core unchanged and perform ``TLAST`` to boundary conversion in a thin wrapper IP around the generated project. | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,128 @@ | ||
| from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate | ||
| from hls4ml.model.layers import IFNeuron, LIFNeuron, SNNReadout | ||
|
|
||
| if_config_template = """struct config{index} : nnet::if_neuron_config {{ | ||
| static const unsigned n_in = {n_in}; | ||
| static const unsigned n_out = {n_out}; | ||
| static const unsigned io_type = nnet::{iotype}; | ||
| static const unsigned window_size = {window_size}; | ||
| static const bool threshold_is_vector = {threshold_is_vector}; | ||
| static constexpr float threshold = {threshold}; | ||
| static const nnet::snn_reset_mode reset_mode = nnet::snn_reset_mode::{reset_mechanism}; | ||
| typedef {threshold_t.name} threshold_t; | ||
| typedef {membrane_t.name} membrane_t; | ||
| }};\n""" | ||
|
|
||
| if_function_template = 'nnet::if_neuron<{input_t}, {output_t}, {config}>({input}, {output}, {threshold});' | ||
| snn_include_list = ['nnet_utils/nnet_snn.h', 'nnet_utils/nnet_snn_stream.h'] | ||
|
|
||
|
|
||
| class IFNeuronConfigTemplate(LayerConfigTemplate): | ||
| def __init__(self): | ||
| super().__init__(IFNeuron) | ||
| self.template = if_config_template | ||
|
|
||
| def format(self, node): | ||
| params = self._default_config_params(node) | ||
| params['threshold_is_vector'] = 'true' if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'false' | ||
| return self.template.format(**params) | ||
|
|
||
|
|
||
| class IFNeuronFunctionTemplate(FunctionCallTemplate): | ||
| def __init__(self): | ||
| super().__init__(IFNeuron, include_header=snn_include_list) | ||
| self.template = if_function_template | ||
|
|
||
| def format(self, node): | ||
| params = self._default_function_params(node) | ||
| params['threshold'] = ( | ||
| node.get_weights('threshold_vec').name if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'nullptr' | ||
| ) | ||
| return self.template.format(**params) | ||
|
|
||
|
|
||
| lif_config_template = """struct config{index} : nnet::lif_neuron_config {{ | ||
| static const unsigned n_in = {n_in}; | ||
| static const unsigned n_out = {n_out}; | ||
| static const unsigned io_type = nnet::{iotype}; | ||
| static const unsigned window_size = {window_size}; | ||
| static const bool beta_is_vector = {beta_is_vector}; | ||
| static const bool threshold_is_vector = {threshold_is_vector}; | ||
| static constexpr float threshold = {threshold}; | ||
| static constexpr float beta = {beta}; | ||
| static const nnet::snn_reset_mode reset_mode = nnet::snn_reset_mode::{reset_mechanism}; | ||
| typedef {beta_t.name} beta_t; | ||
| typedef {threshold_t.name} threshold_t; | ||
| typedef {membrane_t.name} membrane_t; | ||
| }};\n""" | ||
|
|
||
| lif_function_template = 'nnet::lif_neuron<{input_t}, {output_t}, {config}>({input}, {output}, {beta}, {threshold});' | ||
|
|
||
|
|
||
| class LIFNeuronConfigTemplate(LayerConfigTemplate): | ||
| def __init__(self): | ||
| super().__init__(LIFNeuron) | ||
| self.template = lif_config_template | ||
|
|
||
| def format(self, node): | ||
| params = self._default_config_params(node) | ||
| params['beta_is_vector'] = 'true' if node.get_attr('beta_mode', 'scalar') == 'vector' else 'false' | ||
| params['threshold_is_vector'] = 'true' if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'false' | ||
| return self.template.format(**params) | ||
|
|
||
|
|
||
| class LIFNeuronFunctionTemplate(FunctionCallTemplate): | ||
| def __init__(self): | ||
| super().__init__(LIFNeuron, include_header=snn_include_list) | ||
| self.template = lif_function_template | ||
|
|
||
| def format(self, node): | ||
| params = self._default_function_params(node) | ||
| params['beta'] = node.get_weights('beta_vec').name if node.get_attr('beta_mode', 'scalar') == 'vector' else 'nullptr' | ||
| params['threshold'] = ( | ||
| node.get_weights('threshold_vec').name if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'nullptr' | ||
| ) | ||
| return self.template.format(**params) | ||
|
|
||
|
|
||
| readout_config_template = """struct config{index} : nnet::snn_readout_config {{ | ||
| static const unsigned n_classes = {n_classes}; | ||
| static const unsigned io_type = nnet::{iotype}; | ||
| static const unsigned window_size = {window_size}; | ||
| static const unsigned class_threshold = {class_threshold}; | ||
| static constexpr float beta = {beta}; | ||
| static const nnet::snn_readout_mode output_mode = nnet::snn_readout_mode::{output_mode}; | ||
| static const nnet::snn_decision_rule decision_rule = nnet::snn_decision_rule::{decision_rule}; | ||
| typedef {membrane_t.name} membrane_t; | ||
| }};\n""" | ||
|
|
||
| readout_function_template = 'nnet::snn_readout<{input_t}, {output_t}, {config}>({input}, {output});' | ||
|
|
||
|
|
||
| class SNNReadoutConfigTemplate(LayerConfigTemplate): | ||
| def __init__(self): | ||
| super().__init__(SNNReadout) | ||
| self.template = readout_config_template | ||
|
|
||
| def format(self, node): | ||
| params = self._default_config_params(node) | ||
| return self.template.format(**params) | ||
|
|
||
|
|
||
| class SNNReadoutFunctionTemplate(FunctionCallTemplate): | ||
| def __init__(self): | ||
| super().__init__(SNNReadout, include_header=snn_include_list) | ||
| self.template = readout_function_template | ||
|
|
||
| def format(self, node): | ||
| params = self._default_function_params(node) | ||
| return self.template.format(**params) | ||
|
|
||
|
|
||
| def register_snn_templates(backend): | ||
| backend.register_template(IFNeuronConfigTemplate) | ||
| backend.register_template(IFNeuronFunctionTemplate) | ||
| backend.register_template(LIFNeuronConfigTemplate) | ||
| backend.register_template(LIFNeuronFunctionTemplate) | ||
| backend.register_template(SNNReadoutConfigTemplate) | ||
| backend.register_template(SNNReadoutFunctionTemplate) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.