
Add inline GEMM optimizations and general performance improvements #226

Open
jfsantos wants to merge 5 commits into sdatkinson:main from jfsantos:feature/inline-gemm

Conversation

@jfsantos
Contributor

Hand-optimized GEMM kernels for small matrices common in NAM models, gated by #ifdef NAM_USE_INLINE_GEMM with Eigen fallback. Includes:

  • Specialized Conv1D kernels: fused 4x4 and 2x2 kernel_size=3, plus fully-unrolled paths for 2x2 through 8x8 channel configurations
  • Conv1x1 inline specializations for all common size combinations
  • FiLM inline path with 4-element loop unrolling
  • GatingActivation/BlendingActivation inline paths
  • Branchless hardswish, 4-element loop unrolling for all activations
  • SiLU added to LUT enable/disable
  • Ring buffer refactored to Eigen block operations
  • memcpy replacements for pure copy operations in wavenet
  • Optimized single-channel output path in WaveNet::process
  • Buffer size benchmark tool (benchmodel_bufsize)

Developed with support and sponsorship from TONE3000
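
For context, a minimal sketch of the gating pattern described above: a compile-time #ifdef selects a hand-unrolled kernel for a known small size and falls back to Eigen otherwise. The function name and the 4x4 specialization are illustrative, not the PR's actual symbols.

#include <Eigen/Dense>

// Illustrative only: a 1x1 convolution with an inline fast path for a 4x4
// weight matrix, gated by NAM_USE_INLINE_GEMM, with Eigen as the fallback.
void conv1x1_apply(const Eigen::MatrixXf& weight, const Eigen::MatrixXf& input,
                   Eigen::MatrixXf& output)
{
#ifdef NAM_USE_INLINE_GEMM
  if (weight.rows() == 4 && weight.cols() == 4)
  {
    const Eigen::Index num_frames = input.cols();
    for (Eigen::Index f = 0; f < num_frames; f++)
    {
      // Fully-unrolled 4x4 matrix-vector product for one frame (column).
      for (int r = 0; r < 4; r++)
        output(r, f) = weight(r, 0) * input(0, f) + weight(r, 1) * input(1, f)
                       + weight(r, 2) * input(2, f) + weight(r, 3) * input(3, f);
    }
    return;
  }
#endif
  // Generic path: let Eigen do the GEMM.
  output.noalias() = weight * input;
}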

João Felipe Santos and others added 3 commits February 6, 2026 09:45
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Owner

@sdatkinson sdatkinson left a comment

[Haven't finished reviewing conv1d, film, gating_activations, wavenet.cpp, and benchmodel]

  • One crit on comments funny business.
  • Another crit: Can you add tests to ensure that the code is correct?
  • Other nits.

// hardswish(x) = x * relu6(x + 3) / 6
// = x * clamp(x + 3, 0, 6) / 6
const float t = x + 3.0f;
const float clamped = t < 0.0f ? 0.0f : (t > 6.0f ? 6.0f : t);
Owner

Interesting if this really is better? I'd be surprised if a compiler wouldn't figure out that these are the same.

Contributor Author

I was surprised by this too, but it does make a difference. I can share the microbenchmark if you are interested in having a look.

Owner

nah I trust ya. sgtm
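
For readers following along, here is a self-contained sketch of the two formulations being compared; this is illustrative, not the PR's exact code.

// Branch-based form: hardswish(x) = x * clamp(x + 3, 0, 6) / 6
inline float hardswish_branchy(float x)
{
  float t = x + 3.0f;
  if (t < 0.0f) t = 0.0f;
  if (t > 6.0f) t = 6.0f;
  return x * t / 6.0f;
}

// Ternary ("branchless") form, which compilers typically lower to
// min/max or conditional-select instructions rather than jumps.
inline float hardswish_branchless(float x)
{
  const float t = x + 3.0f;
  const float clamped = t < 0.0f ? 0.0f : (t > 6.0f ? 6.0f : t);
  return x * clamped * (1.0f / 6.0f);
}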

for (; pos + 3 < size; pos += 4)
{
// Branchless ReLU using conditional
const float v0 = data[pos], v1 = data[pos + 1];
Owner

Very interesting...I assume that this only works better on specific chips? No way some (most?) compilers don't know to do this?
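
For reference, a self-contained sketch of the kind of 4-wide, branchless ReLU pass being discussed, including the scalar tail loop; the variable and function names are assumptions, not the PR's exact code.

void relu_inplace(float* data, const int size)
{
  int pos = 0;
  for (; pos + 3 < size; pos += 4)
  {
    // Branchless ReLU on four elements via conditional selects.
    const float v0 = data[pos], v1 = data[pos + 1];
    const float v2 = data[pos + 2], v3 = data[pos + 3];
    data[pos] = v0 > 0.0f ? v0 : 0.0f;
    data[pos + 1] = v1 > 0.0f ? v1 : 0.0f;
    data[pos + 2] = v2 > 0.0f ? v2 : 0.0f;
    data[pos + 3] = v3 > 0.0f ? v3 : 0.0f;
  }
  // Scalar tail for the remaining 0-3 elements.
  for (; pos < size; pos++)
    data[pos] = data[pos] > 0.0f ? data[pos] : 0.0f;
}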

// Process 4 elements at a time: swish(x) = x * sigmoid(x) = x / (1 + exp(-x))
for (; pos + 3 < size; pos += 4)
{
const float x0 = data[pos], x1 = data[pos + 1];
Owner

Nit: occurs to me looking at this versus the swish(data[pos]) on the left that some inlined "swish4" could look a tad cleaner? Not sure though.
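
For what it's worth, one possible shape for that helper (illustrative; the name swish4 is hypothetical):

#include <cmath>

// Apply swish(x) = x / (1 + exp(-x)) to four contiguous elements.
inline void swish4(float* __restrict__ p)
{
  for (int i = 0; i < 4; i++)
    p[i] = p[i] / (1.0f + std::exp(-p[i]));
}

// The unrolled caller then collapses to a single call per group:
// for (; pos + 3 < size; pos += 4)
//   swish4(data + pos);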

float* __restrict__ output_ptr = _output.data();
const float* __restrict__ bias_ptr = this->_bias.data();

// Specialized paths for common small channel counts
Owner

Nit: is it worth doing everything from 1 to 8?

// Write the input data at the write position using Eigen block operations
// This is more efficient than element-by-element copy as it allows
// the compiler to vectorize the operation.
_storage.middleCols(_write_pos, num_frames).noalias() = input.leftCols(num_frames);
Owner

I wonder why I did that...is there a chance that this isn't real-time safe? I need to check the tests...
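
If it helps, one way to probe the real-time-safety question in a test is Eigen's EIGEN_RUNTIME_NO_MALLOC facility, which (in a build with assertions enabled) makes Eigen assert on any heap allocation while the guard is off. The macro must be defined before any Eigen header is included; the function below is illustrative.

#define EIGEN_RUNTIME_NO_MALLOC
#include <Eigen/Dense>

void check_ring_buffer_write_is_alloc_free(Eigen::MatrixXf& storage, const Eigen::MatrixXf& input,
                                           Eigen::Index write_pos, Eigen::Index num_frames)
{
  // Disallow Eigen heap allocations; any allocation now triggers an assert.
  Eigen::internal::set_is_malloc_allowed(false);
  storage.middleCols(write_pos, num_frames).noalias() = input.leftCols(num_frames);
  Eigen::internal::set_is_malloc_allowed(true);
}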

}

// Turn on fast tanh approximation
nam::activations::Activation::enable_fast_tanh();
Owner

Nit: What do you think about making this an option?

I just got annoyed with it the other day independently of this PR :)

Contributor Author

Yeah, let's add that as a flag.
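
A rough sketch of how that flag could look in the benchmark's main(), keeping fast tanh on by default; the --no-fast-tanh spelling matches the usage string that appears later in this diff, but the parsing code and header path here are illustrative.

#include <cstring>
#include "NAM/activations.h" // header path assumed

int main(int argc, char* argv[])
{
  bool useFastTanh = true;
  for (int i = 2; i < argc; i++)
  {
    if (std::strcmp(argv[i], "--no-fast-tanh") == 0)
      useFastTanh = false;
  }
  if (useFastTanh)
    nam::activations::Activation::enable_fast_tanh();
  // ... load the model and run the benchmark as before ...
  return 0;
}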

@@ -0,0 +1,96 @@
#include <iostream>
Owner

Can you tell me (docstring?) the difference between this and benchmodel.cpp?

Contributor Author

Sure, will do.
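
For example, a brief file-level comment along these lines would cover it (wording is illustrative):

// benchmodel_bufsize: like benchmodel, but runs the model across a range of
// audio buffer sizes and reports timing per size, so the effect of buffer
// size on real-time headroom can be compared directly.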

…ise ops

  ARM assembly analysis (-O2 -DNDEBUG) confirmed:
  - GCC auto-unrolls simple activation loops; manual 4-wide gives no benefit
  - expf() serializes sigmoid/SiLU; unrolling can't help
  - Eigen element-wise ops (.leftCols + .leftCols) produce identical codegen
    to raw float* loops when assertions are disabled

  Simplify 5 activation classes to use inline helpers (relu, sigmoid, etc.)
  and revert 3 wavenet element-wise operations back to Eigen expressions.

  Inline GEMM (Conv1x1/Conv1D), depthwise unrolling, FiLM unrolling,
  bias broadcast, and memcpy optimizations are retained — those show
  measurable wins on both desktop and Cortex-M7.

Also restored comments that were accidentally removed from wavenet.h.
@jfsantos jfsantos force-pushed the feature/inline-gemm branch from 704f309 to 7844a41 on February 16, 2026 18:44
Owner

@sdatkinson sdatkinson left a comment

There's a potentially huge risk with non-contiguous matrices (like when we're doing gated/blended activations).

Can you tell me if there are any "snapshot" tests that would verify that the calculations with and without this flag are the same? I'm just really concerned about something being different that I missed.

[No changes required if you can verify that the things I was concerned about are correct. Sorry; it's just too much for me to fit in my head all at once and some "simple" proof would be a big help.]
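
As one possible shape for such a check (illustrative): process a fixed input with the current build and compare it element-wise against a reference captured from a build without NAM_USE_INLINE_GEMM. The helper below, its tolerance, and any file handling around it are assumptions.

#include <cmath>
#include <cstdio>
#include <vector>

bool outputs_match(const std::vector<float>& current, const std::vector<float>& reference,
                   const float tol = 1e-6f)
{
  if (current.size() != reference.size())
    return false;
  for (size_t i = 0; i < current.size(); i++)
  {
    if (std::fabs(current[i] - reference[i]) > tol)
    {
      std::printf("Mismatch at %zu: %g vs %g\n", i, current[i], reference[i]);
      return false;
    }
  }
  return true;
}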


const int input_dim = (int)get_input_dim();
const float* __restrict__ input_ptr = input.data();
const float* __restrict__ scale_shift_ptr = scale_shift.data();
float* __restrict__ output_ptr = _output.data();
Owner

[C]

Oh, heck I just realized a potential problem with this:

What if the matrices aren't contiguous in memory?

This could happen e.g. with .topRows() since Eigen is column-major (or .leftCols() with a row-major matrix).

Is there a way to guarantee it doesn't happen? This was the problem with gating activations being wrong for a while.
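
To make the concern concrete, a small standalone check (illustrative): for column-major storage a block is packed only when its outer stride equals its row count, which .topRows() of a taller matrix violates while .leftCols() does not. A raw-pointer fast path could branch on this and fall back to the Eigen expression otherwise.

#include <Eigen/Dense>
#include <iostream>

int main()
{
  Eigen::MatrixXf z(8, 64);

  auto top = z.topRows(4);    // strided: outerStride() == 8 but rows() == 4
  auto left = z.leftCols(32); // packed:  outerStride() == 8 and rows() == 8

  std::cout << "topRows packed: " << (top.outerStride() == top.rows()) << "\n"
            << "leftCols packed: " << (left.outerStride() == left.rows()) << "\n";
  return 0;
}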

// scale = top input_dim rows, shift = bottom input_dim rows
for (int f = 0; f < num_frames; f++)
{
const float* __restrict__ in_col = input_ptr + f * input_rows;
Owner

Here's where you might land somewhere unexpected on f=1 on such a non-contiguous matrix.

// Validate input dimensions (assert for real-time performance)
const int total_channels = 2 * num_channels;
assert(input.rows() == total_channels);
assert(input.rows() == 2 * num_channels);
Owner

Nit: I'd changed my mind on these in favor of a throw that can be compiled out with #define NDEBUG
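
One shape that suggestion could take (the macro name NAM_CHECK is hypothetical): a check that throws when the condition fails and compiles away entirely when NDEBUG is defined.

#include <stdexcept>

#ifdef NDEBUG
  #define NAM_CHECK(cond, msg) ((void)0)
#else
  #define NAM_CHECK(cond, msg) \
    do { if (!(cond)) throw std::runtime_error(msg); } while (0)
#endif

// At the call site discussed above:
// NAM_CHECK(input.rows() == 2 * num_channels, "GatingActivation: unexpected input rows");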

// Use the GatingActivation class
// Extract the blocks first to avoid temporary reference issues
auto input_block = this->_z.leftCols(num_frames);
auto output_block = this->_z.topRows(bottleneck).leftCols(num_frames);
Owner

cf. the non-contiguous concern.

Either this does an allocation, which means this shouldn't be real-time safe...

...but I haven't seen those tests fail, so that must mean they address the memory as it's stored, and the activation concern I raised must actually be an issue?

I really need to get to verifying that the results match the PyTorch...

const char* modelPath = argv[1];

std::cout << "Loading model " << modelPath << "\n";
std::cerr << "Usage: benchmodel <model_path> [--no-fast-tanh]\n";
Owner

Awesome!

size_t bufferSize = AUDIO_BUFFER_SIZE;
model->Reset(model->GetExpectedSampleRate(), bufferSize);
size_t numBuffers = (48000 / bufferSize) * 2;
if (model == nullptr)
Owner

Why'd the change detector go nuts on this file? Is it a formatting thing?

