diff --git a/.gitignore b/.gitignore index 8fb87927ce..476728453b 100644 --- a/.gitignore +++ b/.gitignore @@ -13,5 +13,6 @@ docs/_build docs/autodoc/* hls4mlprj_* *~ +*.ipynb *.ipynb_checkpoints/ *.bak diff --git a/docs/advanced/snn.rst b/docs/advanced/snn.rst new file mode 100644 index 0000000000..f5f7ccc208 --- /dev/null +++ b/docs/advanced/snn.rst @@ -0,0 +1,201 @@ +======================================== +Spiking Neural Networks (PyTorch/SNN) +======================================== + +This page describes the initial SNN support in the PyTorch frontend. + +Install the SNN frontend dependencies with: + +.. code-block:: bash + + pip install hls4ml[snn] + +Backend support +=============== + +The SNN flow currently supports only the ``Vitis`` backend. + +Execution model +=============== + +Current hls4ml SNN implementations are synchronous (clock-driven). Neuron state +updates and layer computations run in standard HLS pipelines/streams each cycle +according to interface handshakes. The generated design is not a native +asynchronous/event-routed neuromorphic architecture (yet!). + +Reuse factor support +==================== + +Standard hls4ml layers used inside an SNN, such as ``Dense``/linear layers, +retain their normal ``ReuseFactor`` support. ``ReuseFactor`` can still be set at +the model, layer type, or layer name level for these layers, and each dense layer +uses its own configured value independently of the surrounding spiking neuron +layers. The spiking neuron kernels themselves, ``IFNeuron`` and ``LIFNeuron``, do not +currently expose ``ReuseFactor``. They process one timestep at a time, keep +internal membrane state across timesteps, and unroll the per-neuron update loop +across ``n_out`` channels. + +Supported PyTorch modules and readout wrappers +============================================== + +The frontend currently supports direct parsing of: + +* ``Leaky`` -> ``LIFNeuron`` (or ``IFNeuron`` when ``beta`` is effectively 1) + +``SNNReadout`` is an hls4ml layer, not a ``snntorch`` module. To use the +built-in hls4ml readout from a PyTorch model, instantiate the provided PyTorch +marker module: + +.. code-block:: python + + from hls4ml.contrib.snntorch import SNNReadout + +The marker is an identity in PyTorch and is converted to the hls4ml +``SNNReadout`` layer by the PyTorch frontend. + +`snntorch` tracing +================== + +``snntorch`` modules are treated as leaf modules by the hls4ml PyTorch FX tracer. +This allows conversion models to use ``snntorch.Leaky`` directly without defining +conversion-only wrapper classes. + +For ``Leaky``, the supported reset mechanisms are: + +* ``subtract`` +* ``zero`` + +``threshold`` supports scalar or per-neuron vectors (length ``n_out``) for both ``IFNeuron`` and ``LIFNeuron``. +``beta`` supports scalar or per-neuron vectors for ``LIFNeuron``. + +Conversion selects the most memory-efficient representation automatically: + +* scalar values are emitted as compile-time constants +* per-neuron values are emitted as parameter vectors + +For trainable snntorch parameters, conversion uses the current parameter values from the model +at conversion time. + +Readout and Decision Rules +========================== + +The hls4ml ``SNNReadout`` layer implements programmable per-model decision policies. +By default, ``output_mode="spike"`` preserves the original spike-count behavior: + +* ``argmax_spike_count`` +* ``first_to_threshold`` +* ``threshold_then_argmax`` +* ``binary_logit`` (for binary classifiers with ``n_classes == 2``) + +The layer accumulates class spikes over a window. For most decision rules it emits +a class ID. For ``binary_logit``, it emits a score equal to +``count(class_1) - count(class_0)``. + +For non-spiking readout heads, set ``output_mode="membrane"`` and connect +``SNNReadout`` directly after the final dense/linear layer instead of after a +final spiking neuron. In this mode the readout owns the final membrane state: + +.. code-block:: python + + x = self.fc2(x) + return self.readout(x) + +At each timestep, the generated readout computes: + +.. code-block:: cpp + + mem[i] = beta * mem[i] + input[i]; + +No threshold or reset-on-spike is applied in membrane mode. The supported +membrane decision policies are: + +* ``argmax_membrane`` +* ``binary_logit`` (emits ``mem(class_1) - mem(class_0)`` for binary classifiers) + +This will be explained in a tutorial in the hls4ml-tutorials repo. + +Do not place a final spiking neuron before ``SNNReadout(output_mode="membrane")`` +unless you intentionally want the readout to consume that neuron's spike output. +The membrane mode does not recover or expose the internal membrane state of a +preceding ``Leaky``/``IFNeuron``/``LIFNeuron`` layer. If a final output neuron +has a learnable ``beta``, that learnable neuron membrane is not the same state +as the readout-owned membrane. The readout uses its own scalar ``beta``. + +When using the default PyTorch parser, the wrapper module should expose these +attributes as needed: + +* ``n_classes`` (defaults to the input feature count if omitted) +* ``window_size`` or ``stream_length`` (defaults to ``1``) +* ``class_threshold`` (defaults to ``1``) +* ``output_mode`` (defaults to ``spike``; use ``membrane`` for readout-owned membrane accumulation) +* ``beta`` (defaults to ``1.0`` for membrane readout) +* ``decision_rule`` (defaults to ``argmax_spike_count``) +* ``reset_policy`` or ``state_reset_policy`` (defaults to ``fixed_window``) + +Window Boundary Semantics +========================= + +The current implementation uses ``window_size`` timesteps as the sequence boundary +for generated HLS. During PyTorch conversion, the first fixed-window +``SNNReadout``'s ``window_size`` is propagated to all converted ``IFNeuron`` and +``LIFNeuron`` layers in the graph. + +At each boundary: + +* the class decision is emitted +* internal readout counters or readout membrane state are reset for the next sequence +* internal ``IFNeuron``/``LIFNeuron`` membrane state is reset for the next sequence + +The reset happens after the final timestep has been processed and has contributed +to the output. This behavior is compatible with fixed-length time windows. + +Only fixed-window reset is implemented in generated layer kernels today. +``state_reset_policy`` accepts future-facing values such as ``tlast``, +``host_pulse``, and ``never``, but the current layer kernels still use fixed +``window_size`` reset behavior. + +Running ``hls_model.predict()`` +============================== + +Compiled SNN models are stateful across top-function calls. For fixed-window +SNN inference, call the compiled model once per timestep and pass exactly +``window_size`` timesteps for each independent sequence: + +.. code-block:: python + + last = None + for step in range(timesteps): + x_step = x_sequence[step].astype("float32")[None, :] + last = hls_model.predict(x_step) + +After the last call in the window, generated HLS resets the neuron and readout +state for the next sequence. Avoid making stray single-timestep ``predict`` +calls before evaluating a sequence, because those calls advance the state. + +For membrane readout, the PyTorch reference should match the generated readout +accumulation: + +.. code-block:: python + + mem = torch.zeros_like(currents[:, 0, :]) + for step in range(currents.shape[1]): + mem = beta * mem + currents[:, step, :] + pred = mem.argmax(dim=1) + +Using only the final dense current, or using spike-count reduction for a +membrane readout, does not match generated HLS behavior. + +Precision note +============== + +Membrane readout accumulates dense currents over the full window, so very narrow +fixed-point types can reduce accuracy even when the floating-point PyTorch model +looks good. + +``TLAST`` note +============== + +True AXI sideband ``TLAST`` boundary handling requires top-level writer/interface support for packetized AXI stream types. +The current implementation does not yet expose ``TLAST`` to layer kernels directly. + +For variable-length windows, a practical workaround is to keep the hls4ml core unchanged and perform ``TLAST`` to boundary conversion in a thin wrapper IP around the generated project. diff --git a/docs/frontend/pytorch.rst b/docs/frontend/pytorch.rst index 6e91d0c44e..3d1681bab3 100644 --- a/docs/frontend/pytorch.rst +++ b/docs/frontend/pytorch.rst @@ -18,3 +18,6 @@ of the model. If the ``io_parallel`` I/O type (see :ref:`Concepts`) is used, a t Outputs are not transposed back by default, but in ``io_parallel`` case, a transpose node can be added. If not needed, these adjustments can also be switched off. See :py:class:`~hls4ml.utils.config.config_from_pytorch_model` for details. The equivalent of Keras extension API is not yet available for PyTorch parser, and will be provided in the future. + +.. note:: + Experimental spiking layer support is available for selected modules. See :doc:`../advanced/snn` for details. diff --git a/docs/index.rst b/docs/index.rst index f170ca6858..35e79f75bb 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -51,6 +51,7 @@ advanced/precision advanced/fifo_depth advanced/extension + advanced/snn advanced/model_optimization advanced/bramfactor advanced/plugins diff --git a/hls4ml/backends/vivado/passes/snn_templates.py b/hls4ml/backends/vivado/passes/snn_templates.py new file mode 100644 index 0000000000..8eb5c1b69f --- /dev/null +++ b/hls4ml/backends/vivado/passes/snn_templates.py @@ -0,0 +1,128 @@ +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.layers import IFNeuron, LIFNeuron, SNNReadout + +if_config_template = """struct config{index} : nnet::if_neuron_config {{ + static const unsigned n_in = {n_in}; + static const unsigned n_out = {n_out}; + static const unsigned io_type = nnet::{iotype}; + static const unsigned window_size = {window_size}; + static const bool threshold_is_vector = {threshold_is_vector}; + static constexpr float threshold = {threshold}; + static const nnet::snn_reset_mode reset_mode = nnet::snn_reset_mode::{reset_mechanism}; + typedef {threshold_t.name} threshold_t; + typedef {membrane_t.name} membrane_t; +}};\n""" + +if_function_template = 'nnet::if_neuron<{input_t}, {output_t}, {config}>({input}, {output}, {threshold});' +snn_include_list = ['nnet_utils/nnet_snn.h', 'nnet_utils/nnet_snn_stream.h'] + + +class IFNeuronConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(IFNeuron) + self.template = if_config_template + + def format(self, node): + params = self._default_config_params(node) + params['threshold_is_vector'] = 'true' if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'false' + return self.template.format(**params) + + +class IFNeuronFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(IFNeuron, include_header=snn_include_list) + self.template = if_function_template + + def format(self, node): + params = self._default_function_params(node) + params['threshold'] = ( + node.get_weights('threshold_vec').name if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'nullptr' + ) + return self.template.format(**params) + + +lif_config_template = """struct config{index} : nnet::lif_neuron_config {{ + static const unsigned n_in = {n_in}; + static const unsigned n_out = {n_out}; + static const unsigned io_type = nnet::{iotype}; + static const unsigned window_size = {window_size}; + static const bool beta_is_vector = {beta_is_vector}; + static const bool threshold_is_vector = {threshold_is_vector}; + static constexpr float threshold = {threshold}; + static constexpr float beta = {beta}; + static const nnet::snn_reset_mode reset_mode = nnet::snn_reset_mode::{reset_mechanism}; + typedef {beta_t.name} beta_t; + typedef {threshold_t.name} threshold_t; + typedef {membrane_t.name} membrane_t; +}};\n""" + +lif_function_template = 'nnet::lif_neuron<{input_t}, {output_t}, {config}>({input}, {output}, {beta}, {threshold});' + + +class LIFNeuronConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(LIFNeuron) + self.template = lif_config_template + + def format(self, node): + params = self._default_config_params(node) + params['beta_is_vector'] = 'true' if node.get_attr('beta_mode', 'scalar') == 'vector' else 'false' + params['threshold_is_vector'] = 'true' if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'false' + return self.template.format(**params) + + +class LIFNeuronFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(LIFNeuron, include_header=snn_include_list) + self.template = lif_function_template + + def format(self, node): + params = self._default_function_params(node) + params['beta'] = node.get_weights('beta_vec').name if node.get_attr('beta_mode', 'scalar') == 'vector' else 'nullptr' + params['threshold'] = ( + node.get_weights('threshold_vec').name if node.get_attr('threshold_mode', 'scalar') == 'vector' else 'nullptr' + ) + return self.template.format(**params) + + +readout_config_template = """struct config{index} : nnet::snn_readout_config {{ + static const unsigned n_classes = {n_classes}; + static const unsigned io_type = nnet::{iotype}; + static const unsigned window_size = {window_size}; + static const unsigned class_threshold = {class_threshold}; + static constexpr float beta = {beta}; + static const nnet::snn_readout_mode output_mode = nnet::snn_readout_mode::{output_mode}; + static const nnet::snn_decision_rule decision_rule = nnet::snn_decision_rule::{decision_rule}; + typedef {membrane_t.name} membrane_t; +}};\n""" + +readout_function_template = 'nnet::snn_readout<{input_t}, {output_t}, {config}>({input}, {output});' + + +class SNNReadoutConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(SNNReadout) + self.template = readout_config_template + + def format(self, node): + params = self._default_config_params(node) + return self.template.format(**params) + + +class SNNReadoutFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(SNNReadout, include_header=snn_include_list) + self.template = readout_function_template + + def format(self, node): + params = self._default_function_params(node) + return self.template.format(**params) + + +def register_snn_templates(backend): + backend.register_template(IFNeuronConfigTemplate) + backend.register_template(IFNeuronFunctionTemplate) + backend.register_template(LIFNeuronConfigTemplate) + backend.register_template(LIFNeuronFunctionTemplate) + backend.register_template(SNNReadoutConfigTemplate) + backend.register_template(SNNReadoutFunctionTemplate) diff --git a/hls4ml/contrib/snntorch.py b/hls4ml/contrib/snntorch.py new file mode 100644 index 0000000000..e52797cafb --- /dev/null +++ b/hls4ml/contrib/snntorch.py @@ -0,0 +1,63 @@ +import torch + +from hls4ml.utils.torch import HLS4MLModule + + +class SNNReadout(HLS4MLModule): + """PyTorch marker module for the hls4ml SNNReadout layer. + + In PyTorch this module is an identity. During conversion, hls4ml lowers it + to the built-in SNNReadout IR layer and backend implementation. + """ + + VALID_OUTPUT_MODES = ('spike', 'membrane') + VALID_DECISION_RULES = ( + 'argmax_spike_count', + 'first_to_threshold', + 'threshold_then_argmax', + 'binary_logit', + 'argmax_membrane', + ) + + def __init__( + self, + n_classes=None, + window_size=1, + stream_length=None, + decision_rule=None, + class_threshold=1, + output_mode='spike', + beta=1.0, + reset_policy='fixed_window', + ): + super().__init__() + + output_mode = str(output_mode).lower() + if output_mode not in self.VALID_OUTPUT_MODES: + raise ValueError(f'Unsupported SNNReadout output_mode "{output_mode}". Supported: spike, membrane.') + + if decision_rule is None: + decision_rule = 'argmax_membrane' if output_mode == 'membrane' else 'argmax_spike_count' + decision_rule = str(decision_rule) + if decision_rule not in self.VALID_DECISION_RULES: + raise ValueError( + f'Unsupported SNNReadout decision_rule "{decision_rule}". Supported: {", ".join(self.VALID_DECISION_RULES)}.' + ) + if output_mode == 'membrane' and decision_rule not in ('argmax_membrane', 'binary_logit'): + raise ValueError('SNNReadout membrane mode supports decision_rule "argmax_membrane" or "binary_logit".') + if output_mode == 'spike' and decision_rule == 'argmax_membrane': + raise ValueError('SNNReadout decision_rule "argmax_membrane" requires output_mode "membrane".') + + self.n_classes = n_classes + if stream_length is None: + self.window_size = int(window_size) + else: + self.stream_length = int(stream_length) + self.decision_rule = decision_rule + self.class_threshold = int(class_threshold) + self.output_mode = output_mode + self.beta = torch.tensor(float(beta)) + self.reset_policy = str(reset_policy) + + def forward(self, x): + return x diff --git a/hls4ml/converters/pytorch/snn.py b/hls4ml/converters/pytorch/snn.py new file mode 100644 index 0000000000..b3a09b0f25 --- /dev/null +++ b/hls4ml/converters/pytorch/snn.py @@ -0,0 +1,147 @@ +import numpy as np + +from hls4ml.converters.pytorch_to_hls import pytorch_handler + +# Treat numerically unit leak as IF behavior. +BETA_TO_IF_TOL = 1e-6 + + +def _to_numpy(value, name): + if value is None: + raise Exception(f'Missing SNN parameter: {name}') + + if hasattr(value, 'detach'): + value = value.detach().cpu().numpy() + + try: + return np.asarray(value, dtype=np.float32) + except (TypeError, ValueError) as err: + raise Exception(f'Could not parse "{name}" as numpy array: {value}') from err + + +def _parse_scalar_or_vector(class_object, name, n_out): + value = getattr(class_object, name, None) + arr = _to_numpy(value, name) + flat = arr.reshape(-1) + + if flat.size == 1: + scalar = float(flat[0]) + return {'mode': 'scalar', 'scalar': scalar, 'vector': None} + + if flat.size == n_out: + if np.allclose(flat, flat[0], rtol=0.0, atol=BETA_TO_IF_TOL): + scalar = float(flat[0]) + return {'mode': 'scalar', 'scalar': scalar, 'vector': None} + return {'mode': 'vector', 'scalar': None, 'vector': flat.astype(np.float32)} + + raise Exception(f'Only scalar or length-{n_out} "{name}" is supported for SNN conversion, got shape {arr.shape}') + + +def _parse_reset_mechanism(class_object): + reset = getattr(class_object, 'reset_mechanism', 'subtract') + reset = str(reset).lower() + if reset not in ['subtract', 'zero']: + raise Exception(f'Unsupported reset mechanism "{reset}". Supported: "subtract", "zero".') + return reset + + +def _parse_state_reset_policy(class_object): + policy = getattr(class_object, 'state_reset_policy', None) + if policy is None: + policy = getattr(class_object, 'reset_policy', 'fixed_window') + policy = str(policy).lower() + if policy not in ['fixed_window', 'tlast', 'host_pulse', 'never']: + raise Exception( + f'Unsupported state reset policy "{policy}". Supported: "fixed_window", "tlast", "host_pulse", "never".' + ) + return policy + + +def _parse_readout_beta(class_object): + beta = getattr(class_object, 'beta', 1.0) + arr = _to_numpy(beta, 'beta').reshape(-1) + if arr.size != 1: + raise Exception(f'Only scalar "beta" is supported for SNNReadout membrane mode, got shape {arr.shape}') + return float(arr[0]) + + +@pytorch_handler('Leaky') +def parse_lif_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): + assert operation == 'Leaky' + + n_out = input_shapes[0][-1] + beta = _parse_scalar_or_vector(class_object, 'beta', n_out) + threshold = _parse_scalar_or_vector(class_object, 'threshold', n_out) + + layer = {} + use_if = beta['mode'] == 'scalar' and np.isclose(beta['scalar'], 1.0, rtol=0.0, atol=BETA_TO_IF_TOL) + layer['class_name'] = 'IFNeuron' if use_if else 'LIFNeuron' + layer['name'] = layer_name + layer['inputs'] = input_names + layer['n_in'] = n_out + layer['n_out'] = n_out + layer['threshold_mode'] = threshold['mode'] + if threshold['mode'] == 'scalar': + layer['threshold'] = threshold['scalar'] + else: + layer['threshold'] = 0.0 + layer['threshold_data'] = threshold['vector'] + + if layer['class_name'] == 'LIFNeuron': + layer['beta_mode'] = beta['mode'] + if beta['mode'] == 'scalar': + layer['beta'] = beta['scalar'] + else: + layer['beta'] = 0.0 + layer['beta_data'] = beta['vector'] + layer['reset_mechanism'] = _parse_reset_mechanism(class_object) + + return layer, [shape for shape in input_shapes[0]] + + +@pytorch_handler('SNNReadout') +def parse_snn_readout_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): + assert operation == 'SNNReadout' + + layer = {} + layer['class_name'] = 'SNNReadout' + layer['name'] = layer_name + layer['inputs'] = input_names + + n_classes = getattr(class_object, 'n_classes', None) + if n_classes is None: + n_classes = input_shapes[0][-1] + layer['n_classes'] = int(n_classes) + if hasattr(class_object, 'stream_length'): + layer['window_size'] = int(class_object.stream_length) + else: + layer['window_size'] = int(getattr(class_object, 'window_size', 1)) + layer['class_threshold'] = int(getattr(class_object, 'class_threshold', 1)) + layer['output_mode'] = str(getattr(class_object, 'output_mode', 'spike')).lower() + if layer['output_mode'] not in ['spike', 'membrane']: + raise Exception(f'Unsupported SNNReadout output mode "{layer["output_mode"]}". Supported: spike, membrane.') + layer['beta'] = _parse_readout_beta(class_object) + default_decision_rule = 'argmax_membrane' if layer['output_mode'] == 'membrane' else 'argmax_spike_count' + layer['decision_rule'] = str(getattr(class_object, 'decision_rule', default_decision_rule)) + layer['state_reset_policy'] = _parse_state_reset_policy(class_object) + if layer['decision_rule'] not in [ + 'argmax_spike_count', + 'first_to_threshold', + 'threshold_then_argmax', + 'binary_logit', + 'argmax_membrane', + ]: + raise Exception( + f'Unsupported SNN decision rule "{layer["decision_rule"]}". ' + 'Supported: argmax_spike_count, first_to_threshold, threshold_then_argmax, binary_logit, argmax_membrane.' + ) + if layer['decision_rule'] == 'binary_logit' and layer['n_classes'] != 2: + raise Exception('binary_logit decision rule requires n_classes == 2') + if layer['output_mode'] == 'membrane' and layer['decision_rule'] not in ['argmax_membrane', 'binary_logit']: + raise Exception('SNNReadout membrane mode supports decision_rule "argmax_membrane" or "binary_logit".') + if layer['output_mode'] == 'spike' and layer['decision_rule'] == 'argmax_membrane': + raise Exception('SNNReadout decision_rule "argmax_membrane" requires output_mode "membrane".') + + output_shape = input_shapes[0][:] + output_shape[-1] = 1 + return layer, output_shape diff --git a/hls4ml/model/attributes.py b/hls4ml/model/attributes.py index 01e399e9ba..858766fd28 100644 --- a/hls4ml/model/attributes.py +++ b/hls4ml/model/attributes.py @@ -204,6 +204,9 @@ def __iter__(self): precision_keys = [k for k, v in self.attributes.items() if isinstance(v, self.clazz)] yield from precision_keys + def __contains__(self, key): + return key in self.attributes and isinstance(self.attributes[key], self.clazz) + def __setitem__(self, key, value): self.attributes[key] = value diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index 8bd8cd8a11..42afbabd3c 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -1037,6 +1037,115 @@ def initialize(self): super().initialize() +class IFNeuron(Layer): + _expected_attributes = [ + Attribute('n_in'), + Attribute('n_out'), + Attribute('window_size', value_type=int, default=0, configurable=False), + Attribute('threshold', value_type=float), + Attribute('threshold_mode', value_type=str, default='scalar'), + ChoiceAttribute('reset_mechanism', choices=['subtract', 'zero'], default='subtract', configurable=False), + TypeAttribute('threshold'), + TypeAttribute('membrane'), + ] + + def initialize(self): + shape = list(self.get_input_variable().shape) + shape[-1] = self.attributes['n_out'] + self.add_output_variable(shape) + self._set_type_t('threshold') + self._set_type_t('membrane') + if self.get_attr('threshold_mode', 'scalar') == 'vector': + threshold = self.get_attr('threshold_data') + if threshold is None or len(threshold) != self.attributes['n_out']: + raise Exception('IFNeuron threshold vector must be present and have length n_out') + threshold_t = self.get_attr('threshold_t') + self.add_weights_variable( + name='threshold_vec', + var_name='th{index}', + type_name=threshold_t.name, + precision=threshold_t.precision, + data=threshold, + ) + + +class LIFNeuron(Layer): + _expected_attributes = [ + Attribute('n_in'), + Attribute('n_out'), + Attribute('window_size', value_type=int, default=0, configurable=False), + Attribute('threshold', value_type=float), + Attribute('threshold_mode', value_type=str, default='scalar'), + Attribute('beta', value_type=float), + Attribute('beta_mode', value_type=str, default='scalar'), + ChoiceAttribute('reset_mechanism', choices=['subtract', 'zero'], default='subtract', configurable=False), + TypeAttribute('beta'), + TypeAttribute('threshold'), + TypeAttribute('membrane'), + ] + + def initialize(self): + shape = list(self.get_input_variable().shape) + shape[-1] = self.attributes['n_out'] + self.add_output_variable(shape) + self._set_type_t('beta') + self._set_type_t('threshold') + self._set_type_t('membrane') + if self.get_attr('threshold_mode', 'scalar') == 'vector': + threshold = self.get_attr('threshold_data') + if threshold is None or len(threshold) != self.attributes['n_out']: + raise Exception('LIFNeuron threshold vector must be present and have length n_out') + threshold_t = self.get_attr('threshold_t') + self.add_weights_variable( + name='threshold_vec', + var_name='th{index}', + type_name=threshold_t.name, + precision=threshold_t.precision, + data=threshold, + ) + if self.get_attr('beta_mode', 'scalar') == 'vector': + beta = self.get_attr('beta_data') + if beta is None or len(beta) != self.attributes['n_out']: + raise Exception('LIFNeuron beta vector must be present and have length n_out') + beta_t = self.get_attr('beta_t') + self.add_weights_variable( + name='beta_vec', + var_name='be{index}', + type_name=beta_t.name, + precision=beta_t.precision, + data=beta, + ) + + +class SNNReadout(Layer): + _expected_attributes = [ + Attribute('n_classes'), + Attribute('window_size', value_type=int, default=1), + Attribute('class_threshold', value_type=int, default=1), + Attribute('beta', value_type=float, default=1.0), + ChoiceAttribute('output_mode', choices=['spike', 'membrane'], default='spike'), + ChoiceAttribute( + 'state_reset_policy', + choices=['fixed_window', 'tlast', 'host_pulse', 'never'], + default='fixed_window', + configurable=False, + ), + ChoiceAttribute( + 'decision_rule', + choices=['argmax_spike_count', 'first_to_threshold', 'threshold_then_argmax', 'binary_logit', 'argmax_membrane'], + default='argmax_membrane', + configurable=False, + ), + TypeAttribute('membrane'), + ] + + def initialize(self): + shape = list(self.get_input_variable().shape) + shape[-1] = 1 + self.add_output_variable(shape) + self._set_type_t('membrane') + + class BatchNormOnnx(Layer): """ A transient layer formed from ONNX BatchNormalization that gets converted to @@ -1793,6 +1902,9 @@ def initialize(self): 'ELU': ParametrizedActivation, 'PReLU': PReLU, 'Softmax': Softmax, + 'IFNeuron': IFNeuron, + 'LIFNeuron': LIFNeuron, + 'SNNReadout': SNNReadout, 'TernaryTanh': TernaryTanh, 'HardActivation': HardActivation, 'Reshape': Reshape, diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index 9b3a5dce90..76d9d0ebae 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -82,6 +82,7 @@ 'bit_exact', 'fuse_fixed_point_quantizer', 'fix_input_precision', + 'propagate_snn_readout_window_size', 'eliminate_linear_activation', 'merge_linear_activation', # many of the above optimzers need to be done before this diff --git a/hls4ml/model/optimizer/passes/snn.py b/hls4ml/model/optimizer/passes/snn.py new file mode 100644 index 0000000000..f432bbb73d --- /dev/null +++ b/hls4ml/model/optimizer/passes/snn.py @@ -0,0 +1,31 @@ +from hls4ml.model.layers import IFNeuron, LIFNeuron, SNNReadout +from hls4ml.model.optimizer import ModelOptimizerPass + + +class PropagateSNNReadoutWindowSize(ModelOptimizerPass): + """ + Propagate fixed-window SNN readout length to upstream neuron layers. + """ + + name = 'propagate_snn_readout_window_size' + + def __init__(self): + pass + + def transform(self, model): + readouts = [ + node + for node in model.graph.values() + if isinstance(node, SNNReadout) and node.get_attr('state_reset_policy', 'fixed_window') == 'fixed_window' + ] + if len(readouts) == 0: + return False + + window_size = readouts[0].get_attr('window_size', 0) + changed = False + for node in model.graph.values(): + if isinstance(node, (IFNeuron, LIFNeuron)) and node.get_attr('window_size') != window_size: + node.set_attr('window_size', window_size) + changed = True + + return changed diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_snn.h b/hls4ml/templates/vivado/nnet_utils/nnet_snn.h new file mode 100644 index 0000000000..f7ef2cfc0c --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_snn.h @@ -0,0 +1,300 @@ +#ifndef NNET_SNN_H_ +#define NNET_SNN_H_ + +#include "nnet_common.h" + +namespace nnet { + +enum class snn_reset_mode { subtract, zero }; +enum class snn_decision_rule { + argmax_spike_count, + first_to_threshold, + threshold_then_argmax, + binary_logit, + argmax_membrane +}; +enum class snn_readout_mode { spike, membrane }; + +struct if_neuron_config { + static const unsigned n_in = 1; + static const unsigned n_out = 1; + static const unsigned io_type = io_parallel; + static const unsigned window_size = 0; + static const bool threshold_is_vector = false; + static constexpr float threshold = 1.0; + static const snn_reset_mode reset_mode = snn_reset_mode::subtract; + typedef float threshold_t; + typedef float membrane_t; +}; + +struct lif_neuron_config { + static const unsigned n_in = 1; + static const unsigned n_out = 1; + static const unsigned io_type = io_parallel; + static const unsigned window_size = 0; + static const bool beta_is_vector = false; + static const bool threshold_is_vector = false; + static constexpr float threshold = 1.0; + static constexpr float beta = 0.9; + static const snn_reset_mode reset_mode = snn_reset_mode::subtract; + typedef float beta_t; + typedef float threshold_t; + typedef float membrane_t; +}; + +struct snn_readout_config { + static const unsigned n_classes = 2; + static const unsigned io_type = io_parallel; + static const unsigned window_size = 1; + static const unsigned class_threshold = 1; + static constexpr float beta = 1.0; + static const snn_readout_mode output_mode = snn_readout_mode::spike; + static const snn_decision_rule decision_rule = snn_decision_rule::argmax_spike_count; + typedef float membrane_t; +}; + +template +void if_neuron(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + const typename CONFIG_T::threshold_t threshold_vec[CONFIG_T::n_out]) { + #pragma HLS PIPELINE II=1 + + // Static state persists across calls until the configured time window ends. + static typename CONFIG_T::membrane_t mem[CONFIG_T::n_out]; + #pragma HLS ARRAY_PARTITION variable=mem complete + static unsigned ts = 0; + + for (unsigned i = 0; i < CONFIG_T::n_out; i++) { + #pragma HLS UNROLL + typename CONFIG_T::threshold_t threshold = + CONFIG_T::threshold_is_vector ? threshold_vec[i] : (typename CONFIG_T::threshold_t)CONFIG_T::threshold; + typename CONFIG_T::membrane_t v = mem[i] + (typename CONFIG_T::membrane_t)data[i]; + bool spike = (v >= threshold); + if (spike) { + if (CONFIG_T::reset_mode == snn_reset_mode::subtract) { + v = v - threshold; + } else { + v = 0; + } + res[i] = 1; + } else { + res[i] = 0; + } + mem[i] = v; + } + + if (CONFIG_T::window_size > 0) { + ts++; + if (ts >= CONFIG_T::window_size) { + ts = 0; + for (unsigned i = 0; i < CONFIG_T::n_out; i++) { + #pragma HLS UNROLL + mem[i] = 0; + } + } + } +} + +template +void lif_neuron(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + const typename CONFIG_T::beta_t beta_vec[CONFIG_T::n_out], + const typename CONFIG_T::threshold_t threshold_vec[CONFIG_T::n_out]) { + #pragma HLS PIPELINE II=1 + + // Static state persists across calls until the configured time window ends. + static typename CONFIG_T::membrane_t mem[CONFIG_T::n_out]; + #pragma HLS ARRAY_PARTITION variable=mem complete + static unsigned ts = 0; + + for (unsigned i = 0; i < CONFIG_T::n_out; i++) { + #pragma HLS UNROLL + typename CONFIG_T::beta_t beta = CONFIG_T::beta_is_vector ? beta_vec[i] : (typename CONFIG_T::beta_t)CONFIG_T::beta; + typename CONFIG_T::threshold_t threshold = + CONFIG_T::threshold_is_vector ? threshold_vec[i] : (typename CONFIG_T::threshold_t)CONFIG_T::threshold; + // LIF update: v[t] = beta * v[t-1] + input. + typename CONFIG_T::membrane_t v = + (typename CONFIG_T::membrane_t)(beta * mem[i]) + (typename CONFIG_T::membrane_t)data[i]; + bool spike = (v >= threshold); + if (spike) { + if (CONFIG_T::reset_mode == snn_reset_mode::subtract) { + v = v - threshold; + } else { + v = 0; + } + res[i] = 1; + } else { + res[i] = 0; + } + mem[i] = v; + } + + if (CONFIG_T::window_size > 0) { + ts++; + if (ts >= CONFIG_T::window_size) { + ts = 0; + for (unsigned i = 0; i < CONFIG_T::n_out; i++) { + #pragma HLS UNROLL + mem[i] = 0; + } + } + } +} + +template void reset_snn_counts(unsigned counts[CONFIG_T::n_classes]) { + for (unsigned i = 0; i < CONFIG_T::n_classes; i++) { + #pragma HLS UNROLL + counts[i] = 0; + } +} + +template void reset_snn_membrane(typename CONFIG_T::membrane_t mem[CONFIG_T::n_classes]) { + for (unsigned i = 0; i < CONFIG_T::n_classes; i++) { + #pragma HLS UNROLL + mem[i] = 0; + } +} + +template void advance_snn_count_window(unsigned &ts, unsigned counts[CONFIG_T::n_classes]) { + ts++; + if (ts >= CONFIG_T::window_size) { + ts = 0; + reset_snn_counts(counts); + } +} + +template +void advance_snn_membrane_window(unsigned &ts, typename CONFIG_T::membrane_t mem[CONFIG_T::n_classes]) { + ts++; + if (ts >= CONFIG_T::window_size) { + ts = 0; + reset_snn_membrane(mem); + } +} + +template +void update_snn_counts_array(data_T data[CONFIG_T::n_classes], unsigned counts[CONFIG_T::n_classes]) { + for (unsigned i = 0; i < CONFIG_T::n_classes; i++) { + #pragma HLS UNROLL + counts[i] += (data[i] != 0) ? 1 : 0; + } +} + +template unsigned argmax_snn_counts(unsigned counts[CONFIG_T::n_classes]) { + unsigned best = 0; + for (unsigned i = 1; i < CONFIG_T::n_classes; i++) { + #pragma HLS UNROLL + if (counts[i] > counts[best]) { + best = i; + } + } + return best; +} + +template unsigned first_snn_threshold(unsigned counts[CONFIG_T::n_classes], unsigned fallback) { + unsigned best = fallback; + for (unsigned i = 0; i < CONFIG_T::n_classes; i++) { + #pragma HLS UNROLL + if (counts[i] >= CONFIG_T::class_threshold) { + best = i; + break; + } + } + return best; +} + +template unsigned threshold_then_snn_argmax(unsigned counts[CONFIG_T::n_classes], unsigned fallback) { + bool any_reached = false; + for (unsigned i = 0; i < CONFIG_T::n_classes; i++) { + #pragma HLS UNROLL + any_reached |= (counts[i] >= CONFIG_T::class_threshold); + } + + if (any_reached) { + return first_snn_threshold(counts, fallback); + } + + return fallback; +} + +template +unsigned update_snn_membrane_array(data_T data[CONFIG_T::n_classes], + typename CONFIG_T::membrane_t mem[CONFIG_T::n_classes]) { + unsigned best = 0; + for (unsigned i = 0; i < CONFIG_T::n_classes; i++) { + #pragma HLS UNROLL + typename CONFIG_T::membrane_t v = + (typename CONFIG_T::membrane_t)((typename CONFIG_T::membrane_t)CONFIG_T::beta * mem[i]) + + (typename CONFIG_T::membrane_t)data[i]; + mem[i] = v; + if (i == 0 || v > mem[best]) { + best = i; + } + } + return best; +} + +template +typename res_T::value_type snn_membrane_readout_value(data_T data[CONFIG_T::n_classes], + typename CONFIG_T::membrane_t mem[CONFIG_T::n_classes]) { + unsigned best = update_snn_membrane_array(data, mem); + if (CONFIG_T::decision_rule == snn_decision_rule::binary_logit) { + return (typename res_T::value_type)(mem[1] - mem[0]); + } + return (typename res_T::value_type)best; +} + +template +typename res_T::value_type snn_spike_readout_value(data_T data[CONFIG_T::n_classes], unsigned counts[CONFIG_T::n_classes]) { + if (CONFIG_T::decision_rule == snn_decision_rule::argmax_spike_count) { + unsigned best = 0; + unsigned best_count = 0; + for (unsigned i = 0; i < CONFIG_T::n_classes; i++) { + #pragma HLS UNROLL + unsigned c = counts[i] + ((data[i] != 0) ? 1 : 0); + counts[i] = c; + if (i == 0 || c > best_count) { + best = i; + best_count = c; + } + } + return (typename res_T::value_type)best; + } + + update_snn_counts_array(data, counts); + + if (CONFIG_T::decision_rule == snn_decision_rule::binary_logit) { + return (typename res_T::value_type)((int)counts[1] - (int)counts[0]); + } + + unsigned best = argmax_snn_counts(counts); + if (CONFIG_T::decision_rule == snn_decision_rule::first_to_threshold) { + best = first_snn_threshold(counts, best); + } else if (CONFIG_T::decision_rule == snn_decision_rule::threshold_then_argmax) { + best = threshold_then_snn_argmax(counts, best); + } + + return (typename res_T::value_type)best; +} + +template void snn_readout(data_T data[CONFIG_T::n_classes], res_T res[1]) { + #pragma HLS PIPELINE II=1 + + // Counts and membrane values persist across calls within one readout window. + static unsigned counts[CONFIG_T::n_classes]; + #pragma HLS ARRAY_PARTITION variable=counts complete + static typename CONFIG_T::membrane_t mem[CONFIG_T::n_classes]; + #pragma HLS ARRAY_PARTITION variable=mem complete + static unsigned ts = 0; + + if (CONFIG_T::output_mode == snn_readout_mode::membrane) { + res[0] = snn_membrane_readout_value(data, mem); + advance_snn_membrane_window(ts, mem); + return; + } + + res[0] = snn_spike_readout_value(data, counts); + advance_snn_count_window(ts, counts); +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_snn_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_snn_stream.h new file mode 100644 index 0000000000..1647f5df65 --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_snn_stream.h @@ -0,0 +1,215 @@ +#ifndef NNET_SNN_STREAM_H_ +#define NNET_SNN_STREAM_H_ + +#include "hls_stream.h" +#include "nnet_common.h" +#include "nnet_snn.h" + +namespace nnet { + +template void zero_snn_pack(res_T &out_pack) { + for (unsigned i = 0; i < res_T::size; i++) { + #pragma HLS UNROLL + out_pack[i] = 0; + } +} + +template +void update_snn_counts_pack(data_T in_pack, unsigned counts[CONFIG_T::n_classes]) { + for (unsigned i = 0; i < CONFIG_T::n_classes; i++) { + #pragma HLS UNROLL + counts[i] += (in_pack[i] != 0) ? 1 : 0; + } +} + +template +unsigned update_snn_membrane_pack(data_T in_pack, typename CONFIG_T::membrane_t mem[CONFIG_T::n_classes]) { + unsigned best = 0; + for (unsigned i = 0; i < CONFIG_T::n_classes; i++) { + #pragma HLS UNROLL + typename CONFIG_T::membrane_t v = + (typename CONFIG_T::membrane_t)((typename CONFIG_T::membrane_t)CONFIG_T::beta * mem[i]) + + (typename CONFIG_T::membrane_t)in_pack[i]; + mem[i] = v; + if (i == 0 || v > mem[best]) { + best = i; + } + } + return best; +} + +template +typename res_T::value_type snn_membrane_readout_value_pack(data_T in_pack, + typename CONFIG_T::membrane_t mem[CONFIG_T::n_classes]) { + unsigned best = update_snn_membrane_pack(in_pack, mem); + if (CONFIG_T::decision_rule == snn_decision_rule::binary_logit) { + return (typename res_T::value_type)(mem[1] - mem[0]); + } + return (typename res_T::value_type)best; +} + +template +typename res_T::value_type snn_spike_readout_value_pack(data_T in_pack, unsigned counts[CONFIG_T::n_classes]) { + if (CONFIG_T::decision_rule == snn_decision_rule::argmax_spike_count) { + unsigned best = 0; + unsigned best_count = 0; + for (unsigned i = 0; i < CONFIG_T::n_classes; i++) { + #pragma HLS UNROLL + unsigned c = counts[i] + ((in_pack[i] != 0) ? 1 : 0); + counts[i] = c; + if (i == 0 || c > best_count) { + best = i; + best_count = c; + } + } + return (typename res_T::value_type)best; + } + + update_snn_counts_pack(in_pack, counts); + + if (CONFIG_T::decision_rule == snn_decision_rule::binary_logit) { + return (typename res_T::value_type)((int)counts[1] - (int)counts[0]); + } + + unsigned best = argmax_snn_counts(counts); + if (CONFIG_T::decision_rule == snn_decision_rule::first_to_threshold) { + best = first_snn_threshold(counts, best); + } else if (CONFIG_T::decision_rule == snn_decision_rule::threshold_then_argmax) { + best = threshold_then_snn_argmax(counts, best); + } + + return (typename res_T::value_type)best; +} + +template +void if_neuron(hls::stream &data_stream, hls::stream &res_stream, + const typename CONFIG_T::threshold_t threshold_vec[CONFIG_T::n_out]) { + #pragma HLS PIPELINE II=1 + + // Static state persists across calls until the configured time window ends. + static typename CONFIG_T::membrane_t mem[CONFIG_T::n_out]; + #pragma HLS ARRAY_PARTITION variable=mem complete + static unsigned ts = 0; + + data_T in_pack = data_stream.read(); + res_T out_pack; + PRAGMA_DATA_PACK(out_pack) + + for (unsigned i = 0; i < CONFIG_T::n_out; i++) { + #pragma HLS UNROLL + typename CONFIG_T::threshold_t threshold = + CONFIG_T::threshold_is_vector ? threshold_vec[i] : (typename CONFIG_T::threshold_t)CONFIG_T::threshold; + typename CONFIG_T::membrane_t v = mem[i] + (typename CONFIG_T::membrane_t)in_pack[i]; + bool spike = (v >= threshold); + if (spike) { + if (CONFIG_T::reset_mode == snn_reset_mode::subtract) { + v = v - threshold; + } else { + v = 0; + } + out_pack[i] = 1; + } else { + out_pack[i] = 0; + } + mem[i] = v; + } + + res_stream.write(out_pack); + + if (CONFIG_T::window_size > 0) { + ts++; + if (ts >= CONFIG_T::window_size) { + ts = 0; + for (unsigned i = 0; i < CONFIG_T::n_out; i++) { + #pragma HLS UNROLL + mem[i] = 0; + } + } + } +} + +template +void lif_neuron(hls::stream &data_stream, hls::stream &res_stream, + const typename CONFIG_T::beta_t beta_vec[CONFIG_T::n_out], + const typename CONFIG_T::threshold_t threshold_vec[CONFIG_T::n_out]) { + #pragma HLS PIPELINE II=1 + + // Static state persists across calls until the configured time window ends. + static typename CONFIG_T::membrane_t mem[CONFIG_T::n_out]; + #pragma HLS ARRAY_PARTITION variable=mem complete + static unsigned ts = 0; + + data_T in_pack = data_stream.read(); + res_T out_pack; + PRAGMA_DATA_PACK(out_pack) + + for (unsigned i = 0; i < CONFIG_T::n_out; i++) { + #pragma HLS UNROLL + typename CONFIG_T::beta_t beta = CONFIG_T::beta_is_vector ? beta_vec[i] : (typename CONFIG_T::beta_t)CONFIG_T::beta; + typename CONFIG_T::threshold_t threshold = + CONFIG_T::threshold_is_vector ? threshold_vec[i] : (typename CONFIG_T::threshold_t)CONFIG_T::threshold; + // LIF update: v[t] = beta * v[t-1] + input. + typename CONFIG_T::membrane_t v = + (typename CONFIG_T::membrane_t)(beta * mem[i]) + (typename CONFIG_T::membrane_t)in_pack[i]; + bool spike = (v >= threshold); + if (spike) { + if (CONFIG_T::reset_mode == snn_reset_mode::subtract) { + v = v - threshold; + } else { + v = 0; + } + out_pack[i] = 1; + } else { + out_pack[i] = 0; + } + mem[i] = v; + } + + res_stream.write(out_pack); + + if (CONFIG_T::window_size > 0) { + ts++; + if (ts >= CONFIG_T::window_size) { + ts = 0; + for (unsigned i = 0; i < CONFIG_T::n_out; i++) { + #pragma HLS UNROLL + mem[i] = 0; + } + } + } +} + +template +void snn_readout(hls::stream &data_stream, hls::stream &res_stream) { + #pragma HLS PIPELINE II=1 + + // Counts and membrane values persist across calls within one readout window. + static unsigned counts[CONFIG_T::n_classes]; + #pragma HLS ARRAY_PARTITION variable=counts complete + static typename CONFIG_T::membrane_t mem[CONFIG_T::n_classes]; + #pragma HLS ARRAY_PARTITION variable=mem complete + static unsigned ts = 0; + + data_T in_pack = data_stream.read(); + + if (CONFIG_T::output_mode == snn_readout_mode::membrane) { + res_T out_pack; + PRAGMA_DATA_PACK(out_pack) + zero_snn_pack(out_pack); + out_pack[0] = snn_membrane_readout_value_pack(in_pack, mem); + res_stream.write(out_pack); + advance_snn_membrane_window(ts, mem); + return; + } + + res_T out_pack; + PRAGMA_DATA_PACK(out_pack) + zero_snn_pack(out_pack); + out_pack[0] = snn_spike_readout_value_pack(in_pack, counts); + res_stream.write(out_pack); + advance_snn_count_window(ts, counts); +} + +} // namespace nnet + +#endif diff --git a/hls4ml/utils/config.py b/hls4ml/utils/config.py index f77992faf3..36c81aaeac 100644 --- a/hls4ml/utils/config.py +++ b/hls4ml/utils/config.py @@ -199,7 +199,9 @@ def make_layer_config(layer): elif attr.name == 'reuse_factor': layer_config[attr.config_name] = default_reuse_factor else: - if attr.default is not None: + if attr.name in layer: + layer_config[attr.config_name] = layer[attr.name] + elif attr.default is not None: layer_config[attr.config_name] = attr.default quantizers = {qname: qclass for qname, qclass in layer.items() if 'quantizer' in qname and qclass is not None} @@ -406,7 +408,9 @@ def make_layer_config(layer): elif attr.name == 'reuse_factor': layer_config[attr.config_name] = default_reuse_factor else: - if attr.default is not None: + if attr.name in layer: + layer_config[attr.config_name] = layer[attr.name] + elif attr.default is not None: layer_config[attr.config_name] = attr.default if layer['class_name'] == 'Input': @@ -526,7 +530,9 @@ def make_layer_config(layer): elif attr.name == 'reuse_factor': layer_config[attr.config_name] = default_reuse_factor else: - if attr.default is not None: + if attr.name in layer: + layer_config[attr.config_name] = layer[attr.name] + elif attr.default is not None: layer_config[attr.config_name] = attr.default return layer_config diff --git a/hls4ml/utils/torch.py b/hls4ml/utils/torch.py index 25d2754b1f..21b0306ca8 100644 --- a/hls4ml/utils/torch.py +++ b/hls4ml/utils/torch.py @@ -22,4 +22,5 @@ def is_leaf_module(self, m, module_qualified_name: str) -> bool: or m.__module__.startswith('torch.nn') or m.__module__.startswith('torch.ao.nn') or m.__module__.startswith('brevitas.nn') + or m.__module__.startswith('snntorch') ) and not isinstance(m, torch.nn.Sequential) diff --git a/pyproject.toml b/pyproject.toml index a39c7cb362..cead8d81e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ optional-dependencies.qkeras = [ "tensorflow-model-optimization<=0.7.5", ] optional-dependencies.quartus-report = [ "calmjs-parse", "tabulate" ] +optional-dependencies.snn = [ "snntorch", "torch" ] optional-dependencies.sr = [ "sympy>=1.13.1" ] optional-dependencies.testing = [ "calmjs-parse", diff --git a/test/pytest/test_extensions_pytorch_snn.py b/test/pytest/test_extensions_pytorch_snn.py new file mode 100644 index 0000000000..e7fa4a59fa --- /dev/null +++ b/test/pytest/test_extensions_pytorch_snn.py @@ -0,0 +1,76 @@ +from pathlib import Path + +import pytest +import torch + +import hls4ml +import hls4ml.utils.torch + +test_root_path = Path(__file__).parent + + +class TSNNWindowReadout(hls4ml.utils.torch.HLS4MLModule): + """Example custom PyTorch module mapped to builtin SNNReadout.""" + + def __init__(self, n_classes=4, window_size=8, decision_rule='argmax_spike_count', class_threshold=2): + super().__init__() + self.n_classes = n_classes + self.window_size = window_size + self.decision_rule = decision_rule + self.class_threshold = class_threshold + + def forward(self, x): + return x + + +def parse_custom_snn_readout(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): + assert operation == 'TSNNWindowReadout' + + layer = {} + layer['class_name'] = 'SNNReadout' + layer['name'] = layer_name + layer['inputs'] = input_names + layer['n_classes'] = int(class_object.n_classes) + layer['window_size'] = int(class_object.window_size) + layer['class_threshold'] = int(class_object.class_threshold) + layer['decision_rule'] = str(class_object.decision_rule) + + output_shape = input_shapes[0][:] + output_shape[-1] = 1 + return layer, output_shape + + +@pytest.fixture(scope='session', autouse=True) +def register_custom_snn_extension(): + hls4ml.converters.register_pytorch_layer_handler('TSNNWindowReadout', parse_custom_snn_readout) + + +def test_extensions_pytorch_snn_readout_parser(test_case_id): + class PyTorchModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 4) + self.readout = TSNNWindowReadout( + n_classes=4, window_size=10, decision_rule='threshold_then_argmax', class_threshold=3 + ) + + def forward(self, x): + x = self.fc(x) + return self.readout(x) + + pmodel = PyTorchModel() + config = hls4ml.utils.config_from_pytorch_model( + pmodel, (4,), default_precision='ap_fixed<16,6>', granularity='name', backend='Vitis' + ) + hmodel = hls4ml.converters.convert_from_pytorch_model( + pmodel, + output_dir=str(test_root_path / test_case_id), + backend='Vitis', + io_type='io_parallel', + hls_config=config, + ) + + readouts = [layer for layer in hmodel.get_layers() if layer.class_name == 'SNNReadout'] + assert len(readouts) == 1 + assert readouts[0].get_attr('decision_rule') == 'threshold_then_argmax' + assert readouts[0].get_attr('window_size') == 10 diff --git a/test/pytest/test_snn_pytorch.py b/test/pytest/test_snn_pytorch.py new file mode 100644 index 0000000000..d130cc9b32 --- /dev/null +++ b/test/pytest/test_snn_pytorch.py @@ -0,0 +1,329 @@ +from pathlib import Path + +import pytest +import torch + +import hls4ml +import hls4ml.contrib.snntorch +import hls4ml.utils.torch + +test_root_path = Path(__file__).parent + + +class Leaky(hls4ml.utils.torch.HLS4MLModule): + def __init__(self, beta=0.9, threshold=1.0, reset_mechanism='subtract'): + super().__init__() + self.beta = torch.tensor(beta) + self.threshold = torch.tensor(threshold) + self.reset_mechanism = reset_mechanism + + def forward(self, x): + return x + + +class LIFNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 4) + self.neuron = Leaky(beta=0.95, threshold=1.2, reset_mechanism='subtract') + + def forward(self, x): + x = self.fc(x) + return self.neuron(x) + + +class IFNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 4) + self.neuron = Leaky(beta=1.0, threshold=0.8, reset_mechanism='zero') + + def forward(self, x): + x = self.fc(x) + return self.neuron(x) + + +class SNNClassifier(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 4) + self.neuron = Leaky(beta=0.95, threshold=1.2, reset_mechanism='subtract') + self.readout = hls4ml.contrib.snntorch.SNNReadout( + n_classes=4, window_size=12, decision_rule='threshold_then_argmax', class_threshold=3 + ) + + def forward(self, x): + x = self.fc(x) + x = self.neuron(x) + return self.readout(x) + + +class SNNMembraneReadoutClassifier(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 4) + self.readout = hls4ml.contrib.snntorch.SNNReadout( + n_classes=4, + window_size=12, + decision_rule='argmax_membrane', + output_mode='membrane', + beta=0.9, + ) + + def forward(self, x): + x = self.fc(x) + return self.readout(x) + + +class SNNClassifierWithResetPolicy(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 4) + self.neuron = Leaky(beta=0.95, threshold=1.2, reset_mechanism='subtract') + self.readout = hls4ml.contrib.snntorch.SNNReadout( + n_classes=4, stream_length=7, decision_rule='first_to_threshold', class_threshold=2, reset_policy='host_pulse' + ) + + def forward(self, x): + x = self.fc(x) + x = self.neuron(x) + return self.readout(x) + + +class IFNetTol(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 4) + self.neuron = Leaky(beta=1.0 - 5e-7, threshold=0.8, reset_mechanism='subtract') + + def forward(self, x): + x = self.fc(x) + return self.neuron(x) + + +class LIFVectorNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 4) + self.neuron = Leaky(beta=[0.8, 0.9, 0.85, 0.95], threshold=[1.1, 1.0, 0.9, 1.2], reset_mechanism='subtract') + + def forward(self, x): + x = self.fc(x) + return self.neuron(x) + + +class IFVectorThresholdNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 4) + self.neuron = Leaky(beta=1.0, threshold=[0.8, 0.9, 1.0, 1.1], reset_mechanism='zero') + + def forward(self, x): + x = self.fc(x) + return self.neuron(x) + + +class LearnedVectorParamsNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 4) + self.neuron = Leaky(beta=[0.2, 0.3, 0.4, 0.5], threshold=[0.2, 0.3, 0.4, 0.5], reset_mechanism='subtract') + # Simulate learned values being different from initialization at conversion time. + self.neuron.beta = torch.nn.Parameter(torch.tensor([0.72, 0.81, 0.63, 0.94])) + self.neuron.threshold = torch.nn.Parameter(torch.tensor([1.25, 0.95, 1.05, 0.85])) + + def forward(self, x): + x = self.fc(x) + return self.neuron(x) + + +@pytest.mark.parametrize( + 'model_class,expected_layer', + [ + (LIFNet, 'LIFNeuron'), + (IFNet, 'IFNeuron'), + (IFNetTol, 'IFNeuron'), + (SNNClassifier, 'SNNReadout'), + (SNNMembraneReadoutClassifier, 'SNNReadout'), + ], +) +def test_pytorch_snn_layers_are_parsed(test_case_id, model_class, expected_layer): + model = model_class() + config = hls4ml.utils.config_from_pytorch_model( + model, (4,), default_precision='ap_fixed<16,6>', granularity='name', backend='Vitis' + ) + + hmodel = hls4ml.converters.convert_from_pytorch_model( + model, + output_dir=str(test_root_path / test_case_id), + backend='Vitis', + io_type='io_parallel', + hls_config=config, + ) + + layer_names = [layer.class_name for layer in hmodel.get_layers()] + assert expected_layer in layer_names + + if expected_layer == 'SNNReadout': + readout = [layer for layer in hmodel.get_layers() if layer.class_name == 'SNNReadout'][0] + if model_class == SNNMembraneReadoutClassifier: + assert readout.get_attr('output_mode') == 'membrane' + assert readout.get_attr('decision_rule') == 'argmax_membrane' + assert readout.get_attr('beta') == pytest.approx(0.9) + else: + assert readout.get_attr('output_mode') == 'spike' + assert readout.get_attr('decision_rule') == 'threshold_then_argmax' + assert readout.get_attr('window_size') == 12 + assert readout.get_attr('class_threshold') == 3 + neuron = [layer for layer in hmodel.get_layers() if layer.class_name in ['IFNeuron', 'LIFNeuron']][0] + assert neuron.get_attr('window_size') == 12 + + +@pytest.mark.parametrize( + 'beta,expected_layer', + [ + (1.0, 'IFNeuron'), + (1.0 - 5e-7, 'IFNeuron'), + (0.999, 'LIFNeuron'), + (0.95, 'LIFNeuron'), + ], +) +def test_leaky_beta_maps_to_if_or_lif(test_case_id, beta, expected_layer): + class BetaNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 4) + self.neuron = Leaky(beta=beta, threshold=1.0, reset_mechanism='subtract') + + def forward(self, x): + x = self.fc(x) + return self.neuron(x) + + model = BetaNet() + config = hls4ml.utils.config_from_pytorch_model( + model, (4,), default_precision='ap_fixed<16,6>', granularity='name', backend='Vitis' + ) + + hmodel = hls4ml.converters.convert_from_pytorch_model( + model, + output_dir=str(test_root_path / test_case_id), + backend='Vitis', + io_type='io_parallel', + hls_config=config, + ) + + layer_names = [layer.class_name for layer in hmodel.get_layers()] + assert expected_layer in layer_names + + +@pytest.mark.parametrize( + 'model_class,expected_layer,beta_mode,threshold_mode', + [ + (LIFVectorNet, 'LIFNeuron', 'vector', 'vector'), + (IFVectorThresholdNet, 'IFNeuron', None, 'vector'), + (LIFNet, 'LIFNeuron', 'scalar', 'scalar'), + (IFNet, 'IFNeuron', None, 'scalar'), + ], +) +def test_snn_scalar_vs_vector_modes(test_case_id, model_class, expected_layer, beta_mode, threshold_mode): + model = model_class() + config = hls4ml.utils.config_from_pytorch_model( + model, (4,), default_precision='ap_fixed<16,6>', granularity='name', backend='Vitis' + ) + hmodel = hls4ml.converters.convert_from_pytorch_model( + model, + output_dir=str(test_root_path / test_case_id), + backend='Vitis', + io_type='io_parallel', + hls_config=config, + ) + + layer = [layer for layer in hmodel.get_layers() if layer.class_name == expected_layer][0] + assert layer.get_attr('threshold_mode') == threshold_mode + if threshold_mode == 'vector': + assert layer.get_weights('threshold_vec').data.shape[0] == layer.get_attr('n_out') + if expected_layer == 'LIFNeuron': + assert layer.get_attr('beta_mode') == beta_mode + if beta_mode == 'vector': + assert layer.get_weights('beta_vec').data.shape[0] == layer.get_attr('n_out') + + +def test_snn_uses_current_parameter_values_for_vector_params(test_case_id): + model = LearnedVectorParamsNet() + config = hls4ml.utils.config_from_pytorch_model( + model, (4,), default_precision='ap_fixed<16,6>', granularity='name', backend='Vitis' + ) + hmodel = hls4ml.converters.convert_from_pytorch_model( + model, + output_dir=str(test_root_path / test_case_id), + backend='Vitis', + io_type='io_parallel', + hls_config=config, + ) + + layer = [layer for layer in hmodel.get_layers() if layer.class_name == 'LIFNeuron'][0] + assert layer.get_attr('beta_mode') == 'vector' + assert layer.get_attr('threshold_mode') == 'vector' + assert list(layer.get_weights('beta_vec').data) == pytest.approx([0.72, 0.81, 0.63, 0.94]) + assert list(layer.get_weights('threshold_vec').data) == pytest.approx([1.25, 0.95, 1.05, 0.85]) + + +def test_snn_readout_stream_length_alias_and_reset_policy(test_case_id): + model = SNNClassifierWithResetPolicy() + config = hls4ml.utils.config_from_pytorch_model( + model, (4,), default_precision='ap_fixed<16,6>', granularity='name', backend='Vitis' + ) + hmodel = hls4ml.converters.convert_from_pytorch_model( + model, + output_dir=str(test_root_path / test_case_id), + backend='Vitis', + io_type='io_parallel', + hls_config=config, + ) + + readout = [layer for layer in hmodel.get_layers() if layer.class_name == 'SNNReadout'][0] + assert readout.get_attr('window_size') == 7 + assert readout.get_attr('decision_rule') == 'first_to_threshold' + assert readout.get_attr('state_reset_policy') == 'host_pulse' + + +def test_snn_layer_type_config_is_exposed_for_quantization(test_case_id): + model = LIFVectorNet() + config = hls4ml.utils.config_from_pytorch_model( + model, (4,), default_precision='ap_fixed<16,6>', granularity='name', backend='Vitis' + ) + config['LayerName']['neuron']['Precision']['beta'] = 'ap_fixed<12,2>' + config['LayerName']['neuron']['Precision']['threshold'] = 'ap_fixed<10,3>' + config['LayerName']['neuron']['Precision']['membrane'] = 'ap_fixed<14,4>' + + hmodel = hls4ml.converters.convert_from_pytorch_model( + model, + output_dir=str(test_root_path / test_case_id), + backend='Vitis', + io_type='io_parallel', + hls_config=config, + ) + + layer = [layer for layer in hmodel.get_layers() if layer.class_name == 'LIFNeuron'][0] + assert layer.get_attr('beta_t').precision.definition_cpp() == 'ap_fixed<12,2>' + assert layer.get_attr('threshold_t').precision.definition_cpp() == 'ap_fixed<10,3>' + assert layer.get_attr('membrane_t').precision.definition_cpp() == 'ap_fixed<14,4>' + + +def test_snn_membrane_readout_type_config_is_exposed(test_case_id): + model = SNNMembraneReadoutClassifier() + config = hls4ml.utils.config_from_pytorch_model( + model, (4,), default_precision='ap_fixed<16,6>', granularity='name', backend='Vitis' + ) + config['LayerName']['readout']['Precision']['membrane'] = 'ap_fixed<18,6>' + + hmodel = hls4ml.converters.convert_from_pytorch_model( + model, + output_dir=str(test_root_path / test_case_id), + backend='Vitis', + io_type='io_parallel', + hls_config=config, + ) + + readout = [layer for layer in hmodel.get_layers() if layer.class_name == 'SNNReadout'][0] + assert readout.get_attr('membrane_t').precision.definition_cpp() == 'ap_fixed<18,6>'