Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions NAM/activations.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <cassert>
#include <cmath> // expf
#include <iostream> // std::cerr (kept for potential debug use)
#include <stdexcept> // std::invalid_argument
#include <functional>
#include <memory>
#include <optional>
Expand Down Expand Up @@ -150,7 +152,7 @@ class Activation
{
apply(block.data(), block.rows() * block.cols());
}
virtual void apply(float* data, long size) {}
virtual void apply(float* data, long size) = 0;

static Ptr get_activation(const std::string name);
static Ptr get_activation(const ActivationConfig& config);
Expand All @@ -165,13 +167,13 @@ class Activation
static std::unordered_map<std::string, Ptr> _activations;
};

// identity function activation
// identity function activation--"do nothing"
class ActivationIdentity : public nam::activations::Activation
{
public:
ActivationIdentity() = default;
~ActivationIdentity() = default;
// Inherit the default apply methods which do nothing
virtual void apply(float* data, long size) override {};
};

class ActivationTanh : public Activation
Expand Down Expand Up @@ -276,6 +278,24 @@ class ActivationPReLU : public Activation
}
ActivationPReLU(std::vector<float> ns) { negative_slopes = ns; }

void apply(float* data, long size) override
{
// Assume column-major (this is brittle)
#ifndef NDEBUG
if (size % negative_slopes.size() != 0)
{
throw std::invalid_argument("PReLU.apply(*data, size) was given an array of size " + std::to_string(size)
+ " but the activation has " + std::to_string(negative_slopes.size())
+ " channels, which doesn't divide evenly.");
}
#endif
for (long pos = 0; pos < size; pos++)
{
const float negative_slope = negative_slopes[pos % negative_slopes.size()];
data[pos] = leaky_relu(data[pos], negative_slope);
}
}

void apply(Eigen::MatrixXf& matrix) override
{
// Matrix is organized as (channels, time_steps)
Expand All @@ -285,7 +305,14 @@ class ActivationPReLU : public Activation
std::vector<float> slopes_for_channels = negative_slopes;

// Fail loudly if input has more channels than activation
assert(actual_channels == negative_slopes.size());
#ifndef NDEBUG
if (actual_channels != negative_slopes.size())
{
throw std::invalid_argument("PReLU: Received " + std::to_string(actual_channels)
+ " channels, but activation has " + std::to_string(negative_slopes.size())
+ " channels");
}
#endif

// Apply each negative slope to its corresponding channel
for (unsigned long channel = 0; channel < actual_channels; channel++)
Expand Down
5 changes: 3 additions & 2 deletions tools/run_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ int main()

test_activations::TestPReLU::test_core_function();
test_activations::TestPReLU::test_per_channel_behavior();
// This is enforced by an assert so it doesn't need to be tested
// test_activations::TestPReLU::test_wrong_number_of_channels();
test_activations::TestPReLU::test_wrong_number_of_channels_matrix();
test_activations::TestPReLU::test_wrong_size_array();
test_activations::TestPReLU::test_valid_array_size();

// Typed ActivationConfig tests
test_activations::TestTypedActivationConfig::test_simple_config();
Expand Down
61 changes: 55 additions & 6 deletions tools/test/test_activations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,10 @@ class TestPReLU
assert(fabs(data(1, 2) - 0.0f) < 1e-6); // 0.0 (unchanged)
}

static void test_wrong_number_of_channels()
static void test_wrong_number_of_channels_matrix()
{
// Test that we fail when we have more channels than slopes
// Test that we fail when matrix has more channels than slopes
// Note: This validation only runs in debug builds (#ifndef NDEBUG)
Eigen::MatrixXf data(3, 2); // 3 channels, 2 time steps

// Initialize with test data
Expand All @@ -232,21 +233,69 @@ class TestPReLU
std::vector<float> slopes = {0.01f, 0.05f};
nam::activations::ActivationPReLU prelu(slopes);

// Apply the activation
#ifndef NDEBUG
// In debug mode, this should throw std::invalid_argument
bool caught = false;
try
{
prelu.apply(data);
}
catch (const std::runtime_error& e)
catch (const std::invalid_argument& e)
{
caught = true;
}
catch (...)
assert(caught && "Expected std::invalid_argument for channel count mismatch");
#endif
}

static void test_wrong_size_array()
{
// Test that we fail when array size doesn't divide evenly by channel count
// Note: This validation only runs in debug builds (#ifndef NDEBUG)

// Create PReLU with 2 channels
std::vector<float> slopes = {0.01f, 0.05f};
nam::activations::ActivationPReLU prelu(slopes);

// Array of size 5 doesn't divide evenly by 2 channels
std::vector<float> data = {-1.0f, -2.0f, 0.5f, 1.0f, -0.5f};

#ifndef NDEBUG
// In debug mode, this should throw std::invalid_argument
bool caught = false;
try
{
prelu.apply(data.data(), (long)data.size());
}
catch (const std::invalid_argument& e)
{
caught = true;
}
assert(caught && "Expected std::invalid_argument for array size mismatch");
#endif
}

static void test_valid_array_size()
{
// Test that valid array sizes work correctly

// Create PReLU with 2 channels
std::vector<float> slopes = {0.1f, 0.2f};
nam::activations::ActivationPReLU prelu(slopes);

// Array of size 6 divides evenly by 2 channels (3 time steps per channel)
std::vector<float> data = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};

// Should not throw
prelu.apply(data.data(), (long)data.size());

assert(caught);
// Verify results: alternating between slope 0.1 and 0.2
assert(fabs(data[0] - (-0.1f)) < 1e-6); // channel 0, slope 0.1
assert(fabs(data[1] - (-0.2f)) < 1e-6); // channel 1, slope 0.2
assert(fabs(data[2] - (-0.1f)) < 1e-6); // channel 0, slope 0.1
assert(fabs(data[3] - (-0.2f)) < 1e-6); // channel 1, slope 0.2
assert(fabs(data[4] - (-0.1f)) < 1e-6); // channel 0, slope 0.1
assert(fabs(data[5] - (-0.2f)) < 1e-6); // channel 1, slope 0.2
}
};

Expand Down