From 1cbb54a89c345d535c49d66baca203e7813ae54b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Felipe=20Santos?= Date: Wed, 11 Feb 2026 15:30:32 -0800 Subject: [PATCH 1/4] [FEATURE] Add compact binary model format (.namb) for fast embedded loading Introduces a binary serialization format that eliminates the need for JSON parsing (no nlohmann/json dependency in the loader), achieving 80-89% file size reduction. Supports all architectures including WaveNet with nested condition_dsp. New files: - NAM/namb_format.h: format constants, CRC32, binary reader/writer - NAM/get_dsp_namb.h/.cpp: binary loader (zero JSON dependency) - tools/nam2namb.cpp: JSON-to-binary converter CLI tool - tools/test/test_namb.cpp: round-trip, validation, and size tests --- NAM/get_dsp_namb.cpp | 465 ++++++++++++++++++++++++++++ NAM/get_dsp_namb.h | 27 ++ NAM/namb_format.h | 231 ++++++++++++++ tools/CMakeLists.txt | 1 + tools/loadmodel.cpp | 8 +- tools/nam2namb.cpp | 635 +++++++++++++++++++++++++++++++++++++++ tools/run_tests.cpp | 10 + tools/test/test_namb.cpp | 384 +++++++++++++++++++++++ 8 files changed, 1760 insertions(+), 1 deletion(-) create mode 100644 NAM/get_dsp_namb.cpp create mode 100644 NAM/get_dsp_namb.h create mode 100644 NAM/namb_format.h create mode 100644 tools/nam2namb.cpp create mode 100644 tools/test/test_namb.cpp diff --git a/NAM/get_dsp_namb.cpp b/NAM/get_dsp_namb.cpp new file mode 100644 index 0000000..459ee39 --- /dev/null +++ b/NAM/get_dsp_namb.cpp @@ -0,0 +1,465 @@ +// Binary .namb loader for NAM models +// No dependency on nlohmann/json + +#include +#include +#include +#include +#include +#include + +#include "get_dsp_namb.h" + +// Architecture headers (no json.hpp dependency needed - we only call constructors) +#include "activations.h" +#include "convnet.h" +#include "lstm.h" +#include "namb_format.h" +#include "wavenet.h" + +using namespace nam::namb; + +namespace +{ + +// ============================================================================= +// 
Activation config reading +// ============================================================================= + +nam::activations::ActivationConfig read_activation_config(BinaryReader& r) +{ + nam::activations::ActivationConfig config; + config.type = static_cast(r.read_u8()); + uint8_t param_count = r.read_u8(); + + switch (config.type) + { + case nam::activations::ActivationType::LeakyReLU: + if (param_count >= 1) + { + config.negative_slope = r.read_f32(); + for (uint8_t i = 1; i < param_count; i++) + r.read_f32(); // skip extra + } + break; + + case nam::activations::ActivationType::PReLU: + if (param_count == 1) + { + config.negative_slope = r.read_f32(); + } + else if (param_count > 1) + { + std::vector slopes; + slopes.reserve(param_count); + for (uint8_t i = 0; i < param_count; i++) + slopes.push_back(r.read_f32()); + config.negative_slopes = std::move(slopes); + } + break; + + case nam::activations::ActivationType::LeakyHardtanh: + if (param_count >= 4) + { + config.min_val = r.read_f32(); + config.max_val = r.read_f32(); + config.min_slope = r.read_f32(); + config.max_slope = r.read_f32(); + for (uint8_t i = 4; i < param_count; i++) + r.read_f32(); // skip extra + } + else + { + for (uint8_t i = 0; i < param_count; i++) + r.read_f32(); // skip + } + break; + + default: + // Simple activation - skip any params + for (uint8_t i = 0; i < param_count; i++) + r.read_f32(); + break; + } + + return config; +} + +// ============================================================================= +// FiLM params reading (4 bytes) +// ============================================================================= + +nam::wavenet::_FiLMParams read_film_params(BinaryReader& r) +{ + uint8_t flags = r.read_u8(); + r.read_u8(); // reserved + uint16_t groups = r.read_u16(); + + bool active = (flags & 0x01) != 0; + bool shift = (flags & 0x02) != 0; + + return nam::wavenet::_FiLMParams(active, shift, groups); +} + +// 
============================================================================= +// Metadata parsing +// ============================================================================= + +struct ParsedMetadata +{ + uint8_t version_major = 0; + uint8_t version_minor = 0; + uint8_t version_patch = 0; + uint8_t meta_flags = 0; + double sample_rate = -1.0; + double loudness = 0.0; + double input_level = 0.0; + double output_level = 0.0; +}; + +ParsedMetadata read_metadata_block(BinaryReader& r) +{ + ParsedMetadata m; + m.version_major = r.read_u8(); + m.version_minor = r.read_u8(); + m.version_patch = r.read_u8(); + m.meta_flags = r.read_u8(); + m.sample_rate = r.read_f64(); + m.loudness = r.read_f64(); + m.input_level = r.read_f64(); + m.output_level = r.read_f64(); + r.skip(12); // reserved + return m; +} + +// ============================================================================= +// Model construction (recursive for condition_dsp) +// ============================================================================= + +// Result of loading a model: the DSP object and how many weights were consumed +struct LoadResult +{ + std::unique_ptr dsp; + size_t weights_consumed; +}; + +// Forward declaration +LoadResult load_model(BinaryReader& r, const float* weights, size_t weight_count, const ParsedMetadata& meta); + +// --- Linear --- + +LoadResult load_linear(BinaryReader& r, const float* weights, size_t weight_count, double sample_rate) +{ + int32_t receptive_field = r.read_i32(); + bool bias = r.read_u8() != 0; + int in_channels = r.read_u8(); + int out_channels = r.read_u8(); + r.read_u8(); // reserved + + std::vector w(weights, weights + weight_count); + std::unique_ptr dsp = + std::make_unique(in_channels, out_channels, receptive_field, bias, w, sample_rate); + LoadResult result; + result.dsp = std::move(dsp); + result.weights_consumed = weight_count; + return result; +} + +// --- LSTM --- + +LoadResult load_lstm(BinaryReader& r, const float* weights, size_t 
weight_count, double sample_rate) +{ + uint16_t num_layers = r.read_u16(); + uint16_t input_size = r.read_u16(); + uint16_t hidden_size = r.read_u16(); + uint8_t in_channels = r.read_u8(); + uint8_t out_channels = r.read_u8(); + r.skip(2); // reserved + + std::vector w(weights, weights + weight_count); + std::unique_ptr dsp = std::make_unique(in_channels, out_channels, num_layers, input_size, + hidden_size, w, sample_rate); + LoadResult result; + result.dsp = std::move(dsp); + result.weights_consumed = weight_count; + return result; +} + +// --- ConvNet --- + +LoadResult load_convnet(BinaryReader& r, const float* weights, size_t weight_count, double sample_rate) +{ + uint16_t channels = r.read_u16(); + bool batchnorm = r.read_u8() != 0; + uint8_t num_dilations = r.read_u8(); + uint16_t groups = r.read_u16(); + uint8_t in_channels = r.read_u8(); + uint8_t out_channels = r.read_u8(); + + nam::activations::ActivationConfig activation_config = read_activation_config(r); + + std::vector dilations; + dilations.reserve(num_dilations); + for (int i = 0; i < num_dilations; i++) + dilations.push_back(r.read_i32()); + + std::vector w(weights, weights + weight_count); + std::unique_ptr dsp = std::make_unique( + in_channels, out_channels, channels, dilations, batchnorm, activation_config, w, sample_rate, groups); + LoadResult result; + result.dsp = std::move(dsp); + result.weights_consumed = weight_count; + return result; +} + +// --- WaveNet --- + +LoadResult load_wavenet(BinaryReader& r, const float* weights, size_t weight_count, const ParsedMetadata& meta) +{ + uint8_t in_channels = r.read_u8(); + uint8_t has_head = r.read_u8(); + uint8_t num_layer_arrays = r.read_u8(); + uint8_t has_condition_dsp = r.read_u8(); + + bool with_head = (has_head != 0); + double sample_rate = meta.sample_rate; + + // Condition DSP + std::unique_ptr condition_dsp; + size_t cdsp_weights_consumed = 0; + + if (has_condition_dsp) + { + uint32_t cdsp_weight_count = r.read_u32(); + + // Read condition 
DSP metadata (48 bytes) + ParsedMetadata cdsp_meta = read_metadata_block(r); + + // Load condition DSP model recursively + LoadResult cdsp_result = load_model(r, weights, cdsp_weight_count, cdsp_meta); + condition_dsp = std::move(cdsp_result.dsp); + cdsp_weights_consumed = cdsp_result.weights_consumed; + + // Apply metadata to condition DSP + if (cdsp_meta.meta_flags & META_HAS_LOUDNESS) + condition_dsp->SetLoudness(cdsp_meta.loudness); + if (cdsp_meta.meta_flags & META_HAS_INPUT_LEVEL) + condition_dsp->SetInputLevel(cdsp_meta.input_level); + if (cdsp_meta.meta_flags & META_HAS_OUTPUT_LEVEL) + condition_dsp->SetOutputLevel(cdsp_meta.output_level); + condition_dsp->prewarm(); + } + + // Parse layer array params + std::vector layer_array_params; + for (int la = 0; la < num_layer_arrays; la++) + { + uint16_t input_size = r.read_u16(); + uint16_t condition_size = r.read_u16(); + uint16_t head_size = r.read_u16(); + uint16_t la_channels = r.read_u16(); + uint16_t bottleneck = r.read_u16(); + uint16_t kernel_size = r.read_u16(); + + bool head_bias = r.read_u8() != 0; + uint8_t num_dilations = r.read_u8(); + uint16_t groups_input = r.read_u16(); + uint16_t groups_input_mixin = r.read_u16(); + + // layer1x1 (4 bytes) + bool layer1x1_active = r.read_u8() != 0; + uint16_t layer1x1_groups = r.read_u16(); + r.read_u8(); // reserved + + // head1x1 (6 bytes) + bool head1x1_active = r.read_u8() != 0; + uint16_t head1x1_out_channels = r.read_u16(); + uint16_t head1x1_groups = r.read_u16(); + r.read_u8(); // reserved + + // 8 FiLM params (32 bytes) + nam::wavenet::_FiLMParams conv_pre_film = read_film_params(r); + nam::wavenet::_FiLMParams conv_post_film = read_film_params(r); + nam::wavenet::_FiLMParams input_mixin_pre_film = read_film_params(r); + nam::wavenet::_FiLMParams input_mixin_post_film = read_film_params(r); + nam::wavenet::_FiLMParams activation_pre_film = read_film_params(r); + nam::wavenet::_FiLMParams activation_post_film = read_film_params(r); + 
nam::wavenet::_FiLMParams layer1x1_post_film = read_film_params(r); + nam::wavenet::_FiLMParams head1x1_post_film = read_film_params(r); + + // Dilations [N * int32] + std::vector dilations; + dilations.reserve(num_dilations); + for (int i = 0; i < num_dilations; i++) + dilations.push_back(r.read_i32()); + + // Activation configs [N * variable] + std::vector activation_configs; + activation_configs.reserve(num_dilations); + for (int i = 0; i < num_dilations; i++) + activation_configs.push_back(read_activation_config(r)); + + // Gating modes [N * uint8] + std::vector gating_modes; + gating_modes.reserve(num_dilations); + for (int i = 0; i < num_dilations; i++) + { + uint8_t gm = r.read_u8(); + switch (gm) + { + case GATING_GATED: gating_modes.push_back(nam::wavenet::GatingMode::GATED); break; + case GATING_BLENDED: gating_modes.push_back(nam::wavenet::GatingMode::BLENDED); break; + default: gating_modes.push_back(nam::wavenet::GatingMode::NONE); break; + } + } + + // Secondary activation configs [N * variable] + std::vector secondary_activation_configs; + secondary_activation_configs.reserve(num_dilations); + for (int i = 0; i < num_dilations; i++) + secondary_activation_configs.push_back(read_activation_config(r)); + + nam::wavenet::Layer1x1Params layer1x1_params(layer1x1_active, layer1x1_groups); + nam::wavenet::Head1x1Params head1x1_params(head1x1_active, head1x1_out_channels, head1x1_groups); + + layer_array_params.emplace_back(input_size, condition_size, head_size, la_channels, bottleneck, kernel_size, + std::move(dilations), std::move(activation_configs), std::move(gating_modes), + head_bias, groups_input, groups_input_mixin, layer1x1_params, head1x1_params, + std::move(secondary_activation_configs), conv_pre_film, conv_post_film, + input_mixin_pre_film, input_mixin_post_film, activation_pre_film, + activation_post_film, layer1x1_post_film, head1x1_post_film); + } + + // Build wavenet weights (excluding condition DSP weights which were consumed earlier) + 
const float* wavenet_weights_ptr = weights + cdsp_weights_consumed; + size_t wavenet_weight_count = weight_count - cdsp_weights_consumed; + std::vector wavenet_weights(wavenet_weights_ptr, wavenet_weights_ptr + wavenet_weight_count); + + // head_scale is the last weight value, but the constructor takes it as a param too + // (it gets overridden by set_weights_). Pass 0.0f; set_weights_ will set the correct value. + std::unique_ptr dsp = std::make_unique( + in_channels, layer_array_params, 0.0f, with_head, std::move(wavenet_weights), std::move(condition_dsp), + sample_rate); + + LoadResult result; + result.dsp = std::move(dsp); + result.weights_consumed = weight_count; + return result; +} + +// ============================================================================= +// Dispatch to architecture-specific loader +// ============================================================================= + +LoadResult load_model(BinaryReader& r, const float* weights, size_t weight_count, const ParsedMetadata& meta) +{ + uint8_t arch = r.read_u8(); + r.read_u8(); // reserved + r.read_u16(); // config_size (not needed - we parse sequentially) + + switch (arch) + { + case ARCH_LINEAR: return load_linear(r, weights, weight_count, meta.sample_rate); + case ARCH_LSTM: return load_lstm(r, weights, weight_count, meta.sample_rate); + case ARCH_CONVNET: return load_convnet(r, weights, weight_count, meta.sample_rate); + case ARCH_WAVENET: return load_wavenet(r, weights, weight_count, meta); + default: throw std::runtime_error("NAMB: unknown architecture ID " + std::to_string(arch)); + } +} + +} // anonymous namespace + +// ============================================================================= +// Public API +// ============================================================================= + +std::unique_ptr nam::get_dsp_namb(const uint8_t* data, size_t size) +{ + if (size < FILE_HEADER_SIZE + METADATA_BLOCK_SIZE) + throw std::runtime_error("NAMB: file too small"); + + 
BinaryReader header_reader(data, FILE_HEADER_SIZE); + + // Validate magic + uint32_t magic = header_reader.read_u32(); + if (magic != MAGIC) + throw std::runtime_error("NAMB: invalid magic number"); + + // Validate format version + uint16_t version = header_reader.read_u16(); + if (version != FORMAT_VERSION) + throw std::runtime_error("NAMB: unsupported format version " + std::to_string(version)); + + header_reader.read_u16(); // flags + uint32_t total_file_size = header_reader.read_u32(); + uint32_t weights_offset = header_reader.read_u32(); + uint32_t total_weight_count = header_reader.read_u32(); + header_reader.read_u32(); // model_block_size + uint32_t stored_checksum = header_reader.read_u32(); + + // Validate file size + if (size < total_file_size) + throw std::runtime_error("NAMB: file truncated (expected " + std::to_string(total_file_size) + " bytes, got " + + std::to_string(size) + ")"); + + // Validate CRC32 + uint32_t computed_checksum = compute_file_crc32(data, total_file_size); + if (computed_checksum != stored_checksum) + throw std::runtime_error("NAMB: checksum mismatch"); + + // Validate weights section + size_t expected_weights_end = weights_offset + total_weight_count * sizeof(float); + if (expected_weights_end > total_file_size) + throw std::runtime_error("NAMB: weights extend beyond file"); + + // Read metadata block (at offset 32) + BinaryReader meta_reader(data + FILE_HEADER_SIZE, METADATA_BLOCK_SIZE); + ParsedMetadata meta = read_metadata_block(meta_reader); + + // Verify config version + std::string version_str = std::to_string(meta.version_major) + "." + std::to_string(meta.version_minor) + "." 
+ + std::to_string(meta.version_patch); + nam::verify_config_version(version_str); + + // Get weight data pointer + const float* weights = reinterpret_cast(data + weights_offset); + + // Read model block (at offset 80) + size_t model_data_size = weights_offset - MODEL_BLOCK_OFFSET; + BinaryReader model_reader(data + MODEL_BLOCK_OFFSET, model_data_size); + + // Load the model + LoadResult result = load_model(model_reader, weights, total_weight_count, meta); + + // Apply metadata + if (meta.meta_flags & META_HAS_LOUDNESS) + result.dsp->SetLoudness(meta.loudness); + if (meta.meta_flags & META_HAS_INPUT_LEVEL) + result.dsp->SetInputLevel(meta.input_level); + if (meta.meta_flags & META_HAS_OUTPUT_LEVEL) + result.dsp->SetOutputLevel(meta.output_level); + + result.dsp->prewarm(); + + return std::move(result.dsp); +} + +std::unique_ptr nam::get_dsp_namb(const std::filesystem::path& filename) +{ + if (!std::filesystem::exists(filename)) + throw std::runtime_error("NAMB file doesn't exist: " + filename.string()); + + // Read entire file into memory + std::ifstream file(filename, std::ios::binary | std::ios::ate); + if (!file.is_open()) + throw std::runtime_error("Cannot open NAMB file: " + filename.string()); + + size_t file_size = file.tellg(); + file.seekg(0, std::ios::beg); + + std::vector data(file_size); + file.read(reinterpret_cast(data.data()), file_size); + file.close(); + + return get_dsp_namb(data.data(), data.size()); +} diff --git a/NAM/get_dsp_namb.h b/NAM/get_dsp_namb.h new file mode 100644 index 0000000..c8ac27f --- /dev/null +++ b/NAM/get_dsp_namb.h @@ -0,0 +1,27 @@ +#pragma once +// Binary .namb loader for NAM models +// No dependency on nlohmann/json - suitable for embedded targets + +#include +#include +#include + +#include "dsp.h" + +namespace nam +{ + +/// \brief Load a NAM model from a .namb binary file +/// \param filename Path to the .namb file +/// \return Unique pointer to a DSP object +/// \throws std::runtime_error on format errors 
namespace nam
{
namespace namb
{

// Magic number: the ASCII codes of "NAMB" packed most-significant byte first.
// BinaryReader::read_u32 decodes in host byte order, so on a little-endian
// host the first four bytes of a file are 0x42 0x4D 0x41 0x4E.
static constexpr uint32_t MAGIC = 0x4E414D42;
static constexpr uint16_t FORMAT_VERSION = 1;

// File offsets
static constexpr size_t FILE_HEADER_SIZE = 32;
static constexpr size_t METADATA_BLOCK_SIZE = 48;
static constexpr size_t MODEL_BLOCK_OFFSET = FILE_HEADER_SIZE + METADATA_BLOCK_SIZE; // 80

// Architecture IDs (must match order in binary format spec)
static constexpr uint8_t ARCH_LINEAR = 0;
static constexpr uint8_t ARCH_CONVNET = 1;
static constexpr uint8_t ARCH_LSTM = 2;
static constexpr uint8_t ARCH_WAVENET = 3;

// Metadata flags
static constexpr uint8_t META_HAS_LOUDNESS = 0x01;
static constexpr uint8_t META_HAS_INPUT_LEVEL = 0x02;
static constexpr uint8_t META_HAS_OUTPUT_LEVEL = 0x04;

// GatingMode values (matches wavenet::GatingMode enum)
static constexpr uint8_t GATING_NONE = 0;
static constexpr uint8_t GATING_GATED = 1;
static constexpr uint8_t GATING_BLENDED = 2;

// =============================================================================
// CRC32 (IEEE 802.3 polynomial, same as zlib)
// =============================================================================

/// \brief CRC remainder of a single byte, i.e. the entry a 256-element lookup table would hold at index `byte`
/// \param byte Table index
/// \return The byte reduced by the reflected polynomial 0xEDB88320 over 8 shift steps
inline uint32_t crc32_table(uint8_t byte)
{
  uint32_t crc = byte;
  for (int bit = 0; bit < 8; bit++)
    crc = (crc & 1u) ? ((crc >> 1) ^ 0xEDB88320u) : (crc >> 1);
  return crc;
}

/// \brief Fold a run of bytes into a running CRC state (no initial or final inversion applied)
/// \param crc Current state
/// \param data Bytes to fold in; not dereferenced when size is 0
/// \param size Number of bytes
/// \return Updated state
inline uint32_t crc32_update(uint32_t crc, const uint8_t* data, size_t size)
{
  for (size_t i = 0; i < size; i++)
    crc = crc32_table(static_cast<uint8_t>(crc ^ data[i])) ^ (crc >> 8);
  return crc;
}

/// \brief CRC32 of a buffer
/// \param data Pointer to the bytes
/// \param size Number of bytes
/// \return Checksum (0 for an empty buffer)
inline uint32_t crc32(const uint8_t* data, size_t size)
{
  return crc32_update(0xFFFFFFFFu, data, size) ^ 0xFFFFFFFFu;
}

/// \brief CRC32 of all bytes except the checksum field (bytes 24..27)
/// \param data Pointer to the file image
/// \param size Number of bytes to cover
/// \return Checksum over bytes 0..23 followed by bytes 28..size-1
inline uint32_t compute_file_crc32(const uint8_t* data, size_t size)
{
  constexpr size_t checksum_begin = 24;
  constexpr size_t checksum_end = 28;

  // Bytes before the checksum field (or the whole buffer if it is shorter than that)
  const size_t head_size = size < checksum_begin ? size : checksum_begin;
  uint32_t crc = crc32_update(0xFFFFFFFFu, data, head_size);

  // Bytes after the checksum field
  if (size > checksum_end)
    crc = crc32_update(crc, data + checksum_end, size - checksum_end);

  return crc ^ 0xFFFFFFFFu;
}

// =============================================================================
// BinaryReader - reads from a memory buffer with bounds checking
// =============================================================================

/// \brief Sequential reader over a caller-owned byte buffer
///
/// The buffer is not copied and must outlive the reader. Multi-byte values are
/// copied out with memcpy, so they are interpreted in host byte order and the
/// buffer needs no particular alignment. Every read is bounds-checked and
/// throws std::runtime_error instead of running past the end; a failed read
/// leaves the position unchanged.
class BinaryReader
{
public:
  BinaryReader(const uint8_t* data, size_t size)
  : _data(data)
  , _size(size)
  , _pos(0)
  {
  }

  uint8_t read_u8()
  {
    check(1);
    return _data[_pos++];
  }

  uint16_t read_u16() { return read_raw<uint16_t>(); }
  uint32_t read_u32() { return read_raw<uint32_t>(); }
  int32_t read_i32() { return read_raw<int32_t>(); }
  float read_f32() { return read_raw<float>(); }
  double read_f64() { return read_raw<double>(); }

  /// \brief Advance the position by n bytes without reading them
  void skip(size_t n)
  {
    check(n);
    _pos += n;
  }

  size_t position() const { return _pos; }
  size_t remaining() const { return _size - _pos; }

  const uint8_t* current_ptr() const { return _data + _pos; }

private:
  // Copy sizeof(T) bytes at the current position into a T and advance past them.
  template <typename T>
  T read_raw()
  {
    check(sizeof(T));
    T v;
    std::memcpy(&v, _data + _pos, sizeof(T));
    _pos += sizeof(T);
    return v;
  }

  // Throw unless at least n bytes remain. Compared against the remaining byte count rather than via
  // `_pos + n`, which wraps around for very large n and would let skip() move the position arbitrarily.
  void check(size_t n) const
  {
    if (n > _size - _pos)
      throw std::runtime_error("NAMB: unexpected end of data at offset " + std::to_string(_pos));
  }

  const uint8_t* _data;
  size_t _size;
  size_t _pos; // invariant: _pos <= _size
};

// =============================================================================
// BinaryWriter - builds a byte buffer
// =============================================================================

/// \brief Append-only byte buffer builder
///
/// Multi-byte values are appended with memcpy, i.e. in host byte order,
/// mirroring BinaryReader.
class BinaryWriter
{
public:
  void write_u8(uint8_t v) { _data.push_back(v); }
  void write_u16(uint16_t v) { append_raw(v); }
  void write_u32(uint32_t v) { append_raw(v); }
  void write_i32(int32_t v) { append_raw(v); }
  void write_f32(float v) { append_raw(v); }
  void write_f64(double v) { append_raw(v); }

  void write_zeros(size_t n) { _data.resize(_data.size() + n, 0); }

  /// \brief Backpatch a uint32 at a specific offset
  /// \throws std::runtime_error if the four bytes at `offset` are not entirely inside the data written so far
  void set_u32(size_t offset, uint32_t v)
  {
    if (offset > _data.size() || _data.size() - offset < sizeof(v))
      throw std::runtime_error("NAMB: set_u32 offset " + std::to_string(offset) + " is out of range");
    std::memcpy(_data.data() + offset, &v, sizeof(v));
  }

  size_t position() const { return _data.size(); }
  const uint8_t* data() const { return _data.data(); }
  uint8_t* data() { return _data.data(); }
  size_t size() const { return _data.size(); }

  const std::vector<uint8_t>& buffer() const { return _data; }

private:
  // Grow the buffer by sizeof(T) and copy the value's bytes into the new space.
  template <typename T>
  void append_raw(const T& v)
  {
    const size_t pos = _data.size();
    _data.resize(pos + sizeof(T));
    std::memcpy(_data.data() + pos, &v, sizeof(T));
  }

  std::vector<uint8_t> _data;
};

} // namespace namb
} // namespace nam
${NAM_DEPS_PATH}/nlohmann) add_executable(loadmodel loadmodel.cpp ${NAM_SOURCES}) add_executable(benchmodel benchmodel.cpp ${NAM_SOURCES}) +add_executable(nam2namb nam2namb.cpp ${NAM_SOURCES}) add_executable(run_tests run_tests.cpp test/allocation_tracking.cpp ${NAM_SOURCES}) # Compile run_tests without optimizations to ensure allocation tracking works correctly # Also ensure assertions are enabled (NDEBUG is not defined) so tests actually run diff --git a/tools/loadmodel.cpp b/tools/loadmodel.cpp index 265139a..722518d 100644 --- a/tools/loadmodel.cpp +++ b/tools/loadmodel.cpp @@ -2,6 +2,7 @@ #include #include "NAM/dsp.h" #include "NAM/get_dsp.h" +#include "NAM/get_dsp_namb.h" int main(int argc, char* argv[]) { @@ -11,7 +12,12 @@ int main(int argc, char* argv[]) fprintf(stderr, "Loading model [%s]\n", modelPath); - auto model = nam::get_dsp(std::filesystem::path(modelPath)); + std::filesystem::path path(modelPath); + std::unique_ptr model; + if (path.extension() == ".namb") + model = nam::get_dsp_namb(path); + else + model = nam::get_dsp(path); if (model != nullptr) { diff --git a/tools/nam2namb.cpp b/tools/nam2namb.cpp new file mode 100644 index 0000000..a5508a6 --- /dev/null +++ b/tools/nam2namb.cpp @@ -0,0 +1,635 @@ +// nam2namb: Convert .nam (JSON) models to .namb (compact binary) format +// +// Usage: nam2namb input.nam [output.namb] +// If output is not specified, replaces .nam extension with .namb + +#include +#include +#include +#include +#include +#include +#include + +#include "NAM/activations.h" +#include "NAM/namb_format.h" +#include "json.hpp" + +using json = nlohmann::json; +using namespace nam::namb; + +// ============================================================================= +// Architecture name to ID mapping +// ============================================================================= + +static uint8_t architecture_id(const std::string& name) +{ + if (name == "Linear") + return ARCH_LINEAR; + if (name == "ConvNet") + return 
ARCH_CONVNET; + if (name == "LSTM") + return ARCH_LSTM; + if (name == "WaveNet") + return ARCH_WAVENET; + throw std::runtime_error("Unknown architecture: " + name); +} + +// ============================================================================= +// Activation config serialization +// ============================================================================= + +static void write_activation_config(BinaryWriter& w, const json& activation_json) +{ + auto config = nam::activations::ActivationConfig::from_json(activation_json); + w.write_u8(static_cast(config.type)); + + // Collect parameters + std::vector params; + if (config.type == nam::activations::ActivationType::LeakyReLU) + { + if (config.negative_slope.has_value()) + params.push_back(config.negative_slope.value()); + } + else if (config.type == nam::activations::ActivationType::PReLU) + { + if (config.negative_slopes.has_value()) + { + params = config.negative_slopes.value(); + } + else if (config.negative_slope.has_value()) + { + params.push_back(config.negative_slope.value()); + } + } + else if (config.type == nam::activations::ActivationType::LeakyHardtanh) + { + params.push_back(config.min_val.value_or(-1.0f)); + params.push_back(config.max_val.value_or(1.0f)); + params.push_back(config.min_slope.value_or(0.01f)); + params.push_back(config.max_slope.value_or(0.01f)); + } + + if (params.size() > 255) + throw std::runtime_error("Activation has too many parameters (max 255)"); + + w.write_u8(static_cast(params.size())); + for (float p : params) + w.write_f32(p); +} + +// ============================================================================= +// FiLM params serialization (4 bytes) +// ============================================================================= + +static void write_film_params(BinaryWriter& w, const json& layer_config, const std::string& key) +{ + if (layer_config.find(key) == layer_config.end() || layer_config[key] == false) + { + // Inactive FiLM + w.write_u8(0); // flags: not 
active + w.write_u8(0); // reserved + w.write_u16(1); // groups (default) + return; + } + + const json& film = layer_config[key]; + bool active = film.value("active", true); + bool shift = film.value("shift", true); + int groups = film.value("groups", 1); + + uint8_t flags = 0; + if (active) + flags |= 0x01; + if (shift) + flags |= 0x02; + + w.write_u8(flags); + w.write_u8(0); // reserved + w.write_u16(static_cast(groups)); +} + +// ============================================================================= +// Gating mode parsing (from JSON, same logic as wavenet::Factory) +// ============================================================================= + +static uint8_t gating_mode_from_string(const std::string& s) +{ + if (s == "gated") + return GATING_GATED; + if (s == "blended") + return GATING_BLENDED; + if (s == "none") + return GATING_NONE; + throw std::runtime_error("Invalid gating_mode: " + s); +} + +// ============================================================================= +// Metadata block serialization (48 bytes) +// ============================================================================= + +static void write_metadata_block(BinaryWriter& w, const json& model_json) +{ + // Parse version + std::string version_str = model_json["version"].get(); + int major = 0, minor = 0, patch = 0; + sscanf(version_str.c_str(), "%d.%d.%d", &major, &minor, &patch); + + w.write_u8(static_cast(major)); + w.write_u8(static_cast(minor)); + w.write_u8(static_cast(patch)); + + // Meta flags and values + uint8_t meta_flags = 0; + double loudness = 0.0, input_level = 0.0, output_level = 0.0; + + if (model_json.find("metadata") != model_json.end() && !model_json["metadata"].is_null()) + { + const json& meta = model_json["metadata"]; + if (meta.find("loudness") != meta.end() && !meta["loudness"].is_null()) + { + meta_flags |= META_HAS_LOUDNESS; + loudness = meta["loudness"].get(); + } + if (meta.find("input_level_dbu") != meta.end() && 
!meta["input_level_dbu"].is_null()) + { + meta_flags |= META_HAS_INPUT_LEVEL; + input_level = meta["input_level_dbu"].get(); + } + if (meta.find("output_level_dbu") != meta.end() && !meta["output_level_dbu"].is_null()) + { + meta_flags |= META_HAS_OUTPUT_LEVEL; + output_level = meta["output_level_dbu"].get(); + } + } + w.write_u8(meta_flags); + + // Sample rate + double sample_rate = -1.0; + if (model_json.find("sample_rate") != model_json.end()) + sample_rate = model_json["sample_rate"].get(); + w.write_f64(sample_rate); + + w.write_f64(loudness); + w.write_f64(input_level); + w.write_f64(output_level); + + // Reserved (12 bytes) + w.write_zeros(12); +} + +// ============================================================================= +// Collect all weights recursively (condition_dsp weights first) +// ============================================================================= + +static void collect_weights(const json& model_json, std::vector& all_weights) +{ + // If this is a WaveNet with condition_dsp, collect condition_dsp weights first + const std::string arch = model_json["architecture"].get(); + if (arch == "WaveNet") + { + const json& config = model_json["config"]; + if (config.find("condition_dsp") != config.end()) + { + const json& cdsp = config["condition_dsp"]; + collect_weights(cdsp, all_weights); + } + } + + // Then add this model's weights + if (model_json.find("weights") != model_json.end()) + { + const auto& weights = model_json["weights"]; + for (const auto& w : weights) + all_weights.push_back(w.get()); + } +} + +// ============================================================================= +// Model block serialization (recursive for condition_dsp) +// ============================================================================= + +// Forward declaration +static void write_model_block(BinaryWriter& w, const json& model_json); + +static void write_linear_config(BinaryWriter& w, const json& config) +{ + int32_t receptive_field = 
config["receptive_field"].get(); + bool bias = config["bias"].get(); + int in_channels = config.value("in_channels", 1); + int out_channels = config.value("out_channels", 1); + + w.write_i32(receptive_field); + w.write_u8(bias ? 1 : 0); + w.write_u8(static_cast(in_channels)); + w.write_u8(static_cast(out_channels)); + w.write_u8(0); // reserved +} + +static void write_lstm_config(BinaryWriter& w, const json& config) +{ + w.write_u16(static_cast(config["num_layers"].get())); + w.write_u16(static_cast(config["input_size"].get())); + w.write_u16(static_cast(config["hidden_size"].get())); + w.write_u8(static_cast(config.value("in_channels", 1))); + w.write_u8(static_cast(config.value("out_channels", 1))); + w.write_u16(0); // reserved +} + +static void write_convnet_config(BinaryWriter& w, const json& config) +{ + int channels = config["channels"].get(); + bool batchnorm = config["batchnorm"].get(); + const auto& dilations = config["dilations"]; + int groups = config.value("groups", 1); + int in_channels = config.value("in_channels", 1); + int out_channels = config.value("out_channels", 1); + + w.write_u16(static_cast(channels)); + w.write_u8(batchnorm ? 1 : 0); + w.write_u8(static_cast(dilations.size())); + w.write_u16(static_cast(groups)); + w.write_u8(static_cast(in_channels)); + w.write_u8(static_cast(out_channels)); + + // Activation config + write_activation_config(w, config["activation"]); + + // Dilations + for (const auto& d : dilations) + w.write_i32(d.get()); +} + +static void write_wavenet_config(BinaryWriter& w, const json& model_json) +{ + const json& config = model_json["config"]; + + int in_channels = config.value("in_channels", 1); + bool with_head = config.find("head") != config.end() && !config["head"].is_null(); + size_t num_layer_arrays = config["layers"].size(); + bool has_condition_dsp = config.find("condition_dsp") != config.end(); + + w.write_u8(static_cast(in_channels)); + w.write_u8(with_head ? 
1 : 0); + w.write_u8(static_cast(num_layer_arrays)); + w.write_u8(has_condition_dsp ? 1 : 0); + + // Condition DSP (if present) + if (has_condition_dsp) + { + const json& cdsp_json = config["condition_dsp"]; + + // Count condition DSP weights (recursively) + std::vector cdsp_weights; + collect_weights(cdsp_json, cdsp_weights); + w.write_u32(static_cast(cdsp_weights.size())); + + // Condition DSP metadata (48 bytes) + write_metadata_block(w, cdsp_json); + + // Condition DSP model block (recursive) + write_model_block(w, cdsp_json); + } + + // Layer array params + for (size_t la = 0; la < num_layer_arrays; la++) + { + const json& layer = config["layers"][la]; + + int layer_channels = layer["channels"].get(); + int bottleneck = layer.value("bottleneck", layer_channels); + const auto& dilations = layer["dilations"]; + size_t num_dilations = dilations.size(); + + w.write_u16(static_cast(layer["input_size"].get())); + w.write_u16(static_cast(layer["condition_size"].get())); + w.write_u16(static_cast(layer["head_size"].get())); + w.write_u16(static_cast(layer_channels)); + w.write_u16(static_cast(bottleneck)); + w.write_u16(static_cast(layer["kernel_size"].get())); + + w.write_u8(layer["head_bias"].get() ? 1 : 0); + w.write_u8(static_cast(num_dilations)); + + int groups_input = layer.value("groups_input", 1); + int groups_input_mixin = layer.value("groups_input_mixin", 1); + w.write_u16(static_cast(groups_input)); + w.write_u16(static_cast(groups_input_mixin)); + + // layer1x1 params (4 bytes) + bool layer1x1_active = true; + int layer1x1_groups = 1; + if (layer.find("layer1x1") != layer.end()) + { + layer1x1_active = layer["layer1x1"]["active"].get(); + layer1x1_groups = layer["layer1x1"]["groups"].get(); + } + w.write_u8(layer1x1_active ? 
1 : 0); + w.write_u16(static_cast(layer1x1_groups)); + w.write_u8(0); // reserved + + // head1x1 params (6 bytes) + bool head1x1_active = false; + int head1x1_out_channels = layer_channels; + int head1x1_groups = 1; + if (layer.find("head1x1") != layer.end()) + { + head1x1_active = layer["head1x1"]["active"].get(); + head1x1_out_channels = layer["head1x1"]["out_channels"].get(); + head1x1_groups = layer["head1x1"]["groups"].get(); + } + w.write_u8(head1x1_active ? 1 : 0); + w.write_u16(static_cast(head1x1_out_channels)); + w.write_u16(static_cast(head1x1_groups)); + w.write_u8(0); // reserved + + // 8 FiLM params (32 bytes) + write_film_params(w, layer, "conv_pre_film"); + write_film_params(w, layer, "conv_post_film"); + write_film_params(w, layer, "input_mixin_pre_film"); + write_film_params(w, layer, "input_mixin_post_film"); + write_film_params(w, layer, "activation_pre_film"); + write_film_params(w, layer, "activation_post_film"); + write_film_params(w, layer, "layer1x1_post_film"); + write_film_params(w, layer, "head1x1_post_film"); + + // Dilations [num_dilations * int32] + for (const auto& d : dilations) + w.write_i32(d.get()); + + // Activation configs [num_dilations * variable] + if (layer["activation"].is_array()) + { + for (const auto& act : layer["activation"]) + write_activation_config(w, act); + } + else + { + // Single activation - write it N times + for (size_t i = 0; i < num_dilations; i++) + write_activation_config(w, layer["activation"]); + } + + // Gating modes [num_dilations * uint8] + if (layer.find("gating_mode") != layer.end()) + { + if (layer["gating_mode"].is_array()) + { + for (const auto& gm : layer["gating_mode"]) + w.write_u8(gating_mode_from_string(gm.get())); + } + else + { + uint8_t mode = gating_mode_from_string(layer["gating_mode"].get()); + for (size_t i = 0; i < num_dilations; i++) + w.write_u8(mode); + } + } + else if (layer.find("gated") != layer.end()) + { + // Backward compatibility + uint8_t mode = layer["gated"].get() ? 
GATING_GATED : GATING_NONE; + for (size_t i = 0; i < num_dilations; i++) + w.write_u8(mode); + } + else + { + for (size_t i = 0; i < num_dilations; i++) + w.write_u8(GATING_NONE); + } + + // Secondary activation configs [num_dilations * variable] + // Parse gating modes to determine which layers need secondary activations + std::vector gating_modes; + if (layer.find("gating_mode") != layer.end()) + { + if (layer["gating_mode"].is_array()) + { + for (const auto& gm : layer["gating_mode"]) + gating_modes.push_back(gating_mode_from_string(gm.get())); + } + else + { + uint8_t mode = gating_mode_from_string(layer["gating_mode"].get()); + gating_modes.resize(num_dilations, mode); + } + } + else if (layer.find("gated") != layer.end()) + { + uint8_t mode = layer["gated"].get() ? GATING_GATED : GATING_NONE; + gating_modes.resize(num_dilations, mode); + } + else + { + gating_modes.resize(num_dilations, GATING_NONE); + } + + for (size_t i = 0; i < num_dilations; i++) + { + if (gating_modes[i] != GATING_NONE) + { + // Need a secondary activation + if (layer.find("secondary_activation") != layer.end()) + { + if (layer["secondary_activation"].is_array()) + { + write_activation_config(w, layer["secondary_activation"][i]); + } + else + { + write_activation_config(w, layer["secondary_activation"]); + } + } + else + { + // Default: Sigmoid + w.write_u8(static_cast(nam::activations::ActivationType::Sigmoid)); + w.write_u8(0); // no params + } + } + else + { + // NONE mode - write an empty activation config (identity) + w.write_u8(static_cast(nam::activations::ActivationType::Tanh)); // type doesn't matter + w.write_u8(0); // no params + } + } + } +} + +static void write_model_block(BinaryWriter& w, const json& model_json) +{ + std::string arch_name = model_json["architecture"].get(); + uint8_t arch = architecture_id(arch_name); + + // Write model block header + w.write_u8(arch); + w.write_u8(0); // reserved + + // Placeholder for config_size (will backpatch) + size_t 
config_size_offset = w.position(); + w.write_u16(0); + + size_t config_start = w.position(); + + const json& config = model_json["config"]; + + switch (arch) + { + case ARCH_LINEAR: write_linear_config(w, config); break; + case ARCH_CONVNET: write_convnet_config(w, config); break; + case ARCH_LSTM: write_lstm_config(w, config); break; + case ARCH_WAVENET: write_wavenet_config(w, model_json); break; + default: throw std::runtime_error("Unknown architecture ID"); + } + + // Backpatch config_size (uint16 at config_size_offset) + size_t config_size = w.position() - config_start; + if (config_size > 65535) + throw std::runtime_error("Config too large for uint16"); + uint16_t cs = static_cast(config_size); + std::memcpy(w.data() + config_size_offset, &cs, 2); +} + +// ============================================================================= +// Main conversion function +// ============================================================================= + +static std::vector convert_nam_to_namb(const json& model_json) +{ + BinaryWriter w; + + // ---- File Header (32 bytes) ---- + w.write_u32(MAGIC); + w.write_u16(FORMAT_VERSION); + w.write_u16(0); // flags + + size_t total_file_size_offset = w.position(); + w.write_u32(0); // total_file_size (backpatch) + + size_t weights_offset_pos = w.position(); + w.write_u32(0); // weights_offset (backpatch) + + size_t total_weight_count_offset = w.position(); + w.write_u32(0); // total_weight_count (backpatch) + + size_t model_block_size_offset = w.position(); + w.write_u32(0); // model_block_size (backpatch) + + w.write_u32(0); // checksum (backpatch) + w.write_u32(0); // reserved + + // ---- Metadata Block (48 bytes at offset 32) ---- + write_metadata_block(w, model_json); + + // ---- Model Block (variable, at offset 80) ---- + size_t model_block_start = w.position(); + write_model_block(w, model_json); + size_t model_block_size = w.position() - model_block_start; + + // ---- Padding to align weights to 4 bytes ---- + while 
(w.position() % 4 != 0) + w.write_u8(0); + + size_t weights_offset = w.position(); + + // ---- Weight Data ---- + std::vector all_weights; + collect_weights(model_json, all_weights); + + for (float wt : all_weights) + w.write_f32(wt); + + // ---- Backpatch header fields ---- + uint32_t total_file_size = static_cast(w.size()); + w.set_u32(total_file_size_offset, total_file_size); + w.set_u32(weights_offset_pos, static_cast(weights_offset)); + w.set_u32(total_weight_count_offset, static_cast(all_weights.size())); + w.set_u32(model_block_size_offset, static_cast(model_block_size)); + + // ---- Compute and write CRC32 ---- + uint32_t checksum = compute_file_crc32(w.data(), w.size()); + w.set_u32(24, checksum); // checksum at offset 24 + + return w.buffer(); +} + +// ============================================================================= +// Entry point +// ============================================================================= + +int main(int argc, char* argv[]) +{ + if (argc < 2) + { + std::cerr << "Usage: nam2namb input.nam [output.namb]" << std::endl; + return 1; + } + + std::filesystem::path input_path(argv[1]); + std::filesystem::path output_path; + + if (argc >= 3) + { + output_path = argv[2]; + } + else + { + output_path = input_path; + output_path.replace_extension(".namb"); + } + + // Read input JSON + std::ifstream input_file(input_path); + if (!input_file.is_open()) + { + std::cerr << "Error: cannot open " << input_path << std::endl; + return 1; + } + + json model_json; + try + { + input_file >> model_json; + } + catch (const std::exception& e) + { + std::cerr << "Error parsing JSON: " << e.what() << std::endl; + return 1; + } + input_file.close(); + + // Convert + std::vector namb_data; + try + { + namb_data = convert_nam_to_namb(model_json); + } + catch (const std::exception& e) + { + std::cerr << "Error converting: " << e.what() << std::endl; + return 1; + } + + // Write output + std::ofstream output_file(output_path, std::ios::binary); + if 
(!output_file.is_open()) + { + std::cerr << "Error: cannot create " << output_path << std::endl; + return 1; + } + output_file.write(reinterpret_cast(namb_data.data()), namb_data.size()); + output_file.close(); + + // Report + size_t json_size = std::filesystem::file_size(input_path); + size_t namb_size = namb_data.size(); + double reduction = 100.0 * (1.0 - (double)namb_size / (double)json_size); + + std::cout << input_path.filename().string() << " -> " << output_path.filename().string() << std::endl; + std::cout << " JSON: " << json_size << " bytes" << std::endl; + std::cout << " NAMB: " << namb_size << " bytes" << std::endl; + std::cout << " Reduction: " << std::fixed << std::setprecision(1) << reduction << "%" << std::endl; + + return 0; +} diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index 9b3bdec..3daad9a 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -25,6 +25,7 @@ #include "test/test_input_buffer_verification.cpp" #include "test/test_lstm.cpp" #include "test/test_wavenet_configurable_gating.cpp" +#include "test/test_namb.cpp" int main() { @@ -243,6 +244,15 @@ int main() test_get_dsp::test_version_minor_one_beyond_supported(); test_get_dsp::test_version_too_early(); + // Binary format (.namb) tests + test_namb::test_crc32(); + test_namb::test_bad_magic(); + test_namb::test_truncated_file(); + test_namb::test_wrong_version(); + test_namb::test_bad_checksum(); + test_namb::test_roundtrip(); + test_namb::test_size_reduction(); + // Finally, some end-to-end tests. 
test_get_dsp::test_load_and_process_nam_files(); diff --git a/tools/test/test_namb.cpp b/tools/test/test_namb.cpp new file mode 100644 index 0000000..88c36cb --- /dev/null +++ b/tools/test/test_namb.cpp @@ -0,0 +1,384 @@ +// Tests for .namb binary format +// - Round-trip: JSON -> NAMB -> load -> process and compare outputs +// - Format validation: bad magic, truncated, wrong version +// - Size verification: NAMB < NAM for all example models + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "NAM/dsp.h" +#include "NAM/get_dsp.h" +#include "NAM/get_dsp_namb.h" +#include "NAM/namb_format.h" +#include "json.hpp" + +namespace test_namb +{ + +// ============================================================================= +// Helpers: file reading and model processing shared by the tests below +// ============================================================================= + +// The converter logic is a file-local (static) function in tools/nam2namb.cpp, so it cannot +// be called from here. Rather than including nam2namb.cpp directly, the tests shell out to +// the nam2namb tool to write a .namb file next to the .nam model, then load the JSON model +// via get_dsp() and the binary model via get_dsp_namb() and compare the two.
+ +// Helper: read file into byte vector +static std::vector read_file_bytes(const std::filesystem::path& path) +{ + std::ifstream f(path, std::ios::binary | std::ios::ate); + if (!f.is_open()) + throw std::runtime_error("Cannot open: " + path.string()); + size_t size = f.tellg(); + f.seekg(0, std::ios::beg); + std::vector data(size); + f.read(reinterpret_cast(data.data()), size); + return data; +} + +// Helper: process audio through a DSP model and return outputs +static std::vector> process_model(nam::DSP* dsp, int num_buffers, int buffer_size) +{ + const int in_channels = dsp->NumInputChannels(); + const int out_channels = dsp->NumOutputChannels(); + const double sample_rate = dsp->GetExpectedSampleRate() > 0 ? dsp->GetExpectedSampleRate() : 48000.0; + + dsp->Reset(sample_rate, buffer_size); + + std::vector> inputBuffers(in_channels); + std::vector> outputBuffers(out_channels); + std::vector inputPtrs(in_channels); + std::vector outputPtrs(out_channels); + + for (int ch = 0; ch < in_channels; ch++) + { + inputBuffers[ch].resize(buffer_size, 0.0); + inputPtrs[ch] = inputBuffers[ch].data(); + } + for (int ch = 0; ch < out_channels; ch++) + { + outputBuffers[ch].resize(buffer_size, 0.0); + outputPtrs[ch] = outputBuffers[ch].data(); + } + + // Collect all output samples + std::vector> all_outputs(out_channels); + + for (int buf = 0; buf < num_buffers; buf++) + { + // Fill with deterministic test data + for (int ch = 0; ch < in_channels; ch++) + { + for (int i = 0; i < buffer_size; i++) + { + inputBuffers[ch][i] = (NAM_SAMPLE)(0.1 * (ch + 1) * ((buf * buffer_size + i) % 100) / 100.0); + } + } + + dsp->process(inputPtrs.data(), outputPtrs.data(), buffer_size); + + for (int ch = 0; ch < out_channels; ch++) + { + for (int i = 0; i < buffer_size; i++) + { + all_outputs[ch].push_back((double)outputBuffers[ch][i]); + } + } + } + + return all_outputs; +} + +// ============================================================================= +// Round-trip test: For a .nam file, 
convert to .namb and compare outputs +// ============================================================================= + +static void test_roundtrip_for_file(const std::string& nam_path) +{ + std::filesystem::path model_path(nam_path); + if (!std::filesystem::exists(model_path)) + { + std::cerr << " Skipping (not found): " << nam_path << std::endl; + return; + } + + std::cout << " Testing round-trip: " << model_path.filename().string() << std::endl; + + // Load JSON model + std::unique_ptr json_model = nam::get_dsp(model_path); + assert(json_model != nullptr); + + // Process with JSON model + const int num_buffers = 5; + const int buffer_size = 64; + auto json_outputs = process_model(json_model.get(), num_buffers, buffer_size); + + // Convert to .namb: read JSON, use nam2namb logic + // We need to create a .namb file. Use a temp path. + std::filesystem::path namb_path = model_path; + namb_path.replace_extension(".namb"); + + // Use the nam2namb tool to create the .namb file + // Since we can't easily call the tool's function, we'll use system() or + // construct the binary ourselves. For the test, let's shell out. 
+ std::string cmd = "./build/tools/nam2namb " + model_path.string() + " " + namb_path.string() + " 2>&1"; + int ret = system(cmd.c_str()); + if (ret != 0) + { + // Try relative path from where tests might be run + cmd = "nam2namb " + model_path.string() + " " + namb_path.string() + " 2>&1"; + ret = system(cmd.c_str()); + } + assert(ret == 0); + assert(std::filesystem::exists(namb_path)); + + // Load .namb model + std::unique_ptr namb_model = nam::get_dsp_namb(namb_path); + assert(namb_model != nullptr); + + // Verify same channel counts + assert(json_model->NumInputChannels() == namb_model->NumInputChannels()); + assert(json_model->NumOutputChannels() == namb_model->NumOutputChannels()); + + // Verify same metadata + assert(json_model->HasLoudness() == namb_model->HasLoudness()); + if (json_model->HasLoudness()) + { + assert(json_model->GetLoudness() == namb_model->GetLoudness()); + } + + // Process with NAMB model + auto namb_outputs = process_model(namb_model.get(), num_buffers, buffer_size); + + // Compare outputs - should be bit-identical + assert(json_outputs.size() == namb_outputs.size()); + for (size_t ch = 0; ch < json_outputs.size(); ch++) + { + assert(json_outputs[ch].size() == namb_outputs[ch].size()); + for (size_t i = 0; i < json_outputs[ch].size(); i++) + { + if (json_outputs[ch][i] != namb_outputs[ch][i]) + { + std::cerr << " Output mismatch at ch=" << ch << " sample=" << i << ": JSON=" << json_outputs[ch][i] + << " NAMB=" << namb_outputs[ch][i] << std::endl; + assert(false); + } + } + } + + // Clean up temp file + std::filesystem::remove(namb_path); + + std::cout << " PASS" << std::endl; +} + +void test_roundtrip() +{ + std::cout << "test_namb::test_roundtrip" << std::endl; + + // Test all available example models + const std::vector models = {"example_models/wavenet.nam", "example_models/lstm.nam", + "example_models/wavenet_condition_dsp.nam", + "example_models/wavenet_a2_max.nam"}; + + for (const auto& model : models) + { + 
test_roundtrip_for_file(model); + } +} + +// ============================================================================= +// Format validation tests +// ============================================================================= + +void test_bad_magic() +{ + std::cout << "test_namb::test_bad_magic" << std::endl; + + // Create minimal data with wrong magic + std::vector data(128, 0); + data[0] = 'X'; // Wrong magic + + bool threw = false; + try + { + nam::get_dsp_namb(data.data(), data.size()); + } + catch (const std::runtime_error& e) + { + threw = true; + assert(std::string(e.what()).find("magic") != std::string::npos); + } + assert(threw); + std::cout << " PASS" << std::endl; +} + +void test_truncated_file() +{ + std::cout << "test_namb::test_truncated_file" << std::endl; + + // File too small for header + std::vector data(16, 0); + // Set magic correctly + uint32_t magic = nam::namb::MAGIC; + std::memcpy(data.data(), &magic, 4); + + bool threw = false; + try + { + nam::get_dsp_namb(data.data(), data.size()); + } + catch (const std::runtime_error&) + { + threw = true; + } + assert(threw); + std::cout << " PASS" << std::endl; +} + +void test_wrong_version() +{ + std::cout << "test_namb::test_wrong_version" << std::endl; + + // Create data with wrong format version + std::vector data(128, 0); + uint32_t magic = nam::namb::MAGIC; + std::memcpy(data.data(), &magic, 4); + uint16_t bad_version = 99; + std::memcpy(data.data() + 4, &bad_version, 2); + + bool threw = false; + try + { + nam::get_dsp_namb(data.data(), data.size()); + } + catch (const std::runtime_error& e) + { + threw = true; + assert(std::string(e.what()).find("version") != std::string::npos); + } + assert(threw); + std::cout << " PASS" << std::endl; +} + +void test_bad_checksum() +{ + std::cout << "test_namb::test_bad_checksum" << std::endl; + + // First create a valid .namb file, then corrupt it + std::filesystem::path nam_path("example_models/lstm.nam"); + if (!std::filesystem::exists(nam_path)) + { 
+ std::cerr << " Skipping (lstm.nam not found)" << std::endl; + return; + } + + std::filesystem::path namb_path("example_models/lstm_test_bad_crc.namb"); + std::string cmd = "./build/tools/nam2namb " + nam_path.string() + " " + namb_path.string() + " 2>&1"; + int ret = system(cmd.c_str()); + if (ret != 0) + { + std::cerr << " Skipping (nam2namb not available)" << std::endl; + return; + } + + // Read the .namb file + auto data = read_file_bytes(namb_path); + + // Corrupt a byte in the weight data (after the checksum) + if (data.size() > 100) + { + data[data.size() - 1] ^= 0xFF; + } + + bool threw = false; + try + { + nam::get_dsp_namb(data.data(), data.size()); + } + catch (const std::runtime_error& e) + { + threw = true; + assert(std::string(e.what()).find("checksum") != std::string::npos); + } + assert(threw); + + std::filesystem::remove(namb_path); + std::cout << " PASS" << std::endl; +} + +// ============================================================================= +// Size comparison test +// ============================================================================= + +void test_size_reduction() +{ + std::cout << "test_namb::test_size_reduction" << std::endl; + + const std::vector models = {"example_models/wavenet.nam", "example_models/lstm.nam", + "example_models/wavenet_condition_dsp.nam", + "example_models/wavenet_a2_max.nam"}; + + for (const auto& nam_path_str : models) + { + std::filesystem::path nam_path(nam_path_str); + if (!std::filesystem::exists(nam_path)) + continue; + + std::filesystem::path namb_path = nam_path; + namb_path.replace_extension(".namb"); + + std::string cmd = "./build/tools/nam2namb " + nam_path.string() + " " + namb_path.string() + " 2>&1"; + int ret = system(cmd.c_str()); + if (ret != 0) + continue; + + size_t nam_size = std::filesystem::file_size(nam_path); + size_t namb_size = std::filesystem::file_size(namb_path); + double reduction = 100.0 * (1.0 - (double)namb_size / (double)nam_size); + + std::cout << " " << 
nam_path.filename().string() << ": " << nam_size << " -> " << namb_size << " (" + << std::fixed << std::setprecision(1) << reduction << "% reduction)" << std::endl; + + // .namb should always be smaller than .nam + assert(namb_size < nam_size); + + // Should be at least 50% reduction (typically ~85%) + assert(reduction > 50.0); + + std::filesystem::remove(namb_path); + } + std::cout << " PASS" << std::endl; +} + +// ============================================================================= +// CRC32 test +// ============================================================================= + +void test_crc32() +{ + std::cout << "test_namb::test_crc32" << std::endl; + + // Test known CRC32 values + const uint8_t test1[] = "123456789"; + uint32_t crc1 = nam::namb::crc32(test1, 9); + // CRC32 of "123456789" is 0xCBF43926 + assert(crc1 == 0xCBF43926u); + + // Empty data + uint32_t crc_empty = nam::namb::crc32(nullptr, 0); + assert(crc_empty == 0x00000000u); + + std::cout << " PASS" << std::endl; +} + +}; // namespace test_namb From d3487ad3f184d9b06b26a5578092291136221f97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Felipe=20Santos?= Date: Thu, 12 Feb 2026 10:48:30 -0800 Subject: [PATCH 2/4] [REFACTOR] Unify JSON and binary model loader construction paths MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce typed config structs per architecture (LinearConfig, LSTMConfig, ConvNetConfig, WaveNetConfig) and a single create_dsp() function that both the JSON and .namb binary loaders feed into. This eliminates duplicated construction logic — adding a new architecture now only requires a format- specific parser, not separate factory code in each loader. 
- Add config structs and parse_config_json() to each architecture header/impl - Add NAM/model_config.h with ModelConfig variant, ModelMetadata, create_dsp() - Refactor get_dsp(dspData&) to use parse_model_config_json() → create_dsp() - Refactor get_dsp_namb.cpp load_*() to return typed configs → create_dsp() - Register missing Linear factory in FactoryRegistry - Silence test_namb.cpp output on success to match other tests --- NAM/convnet.cpp | 29 ++++--- NAM/convnet.h | 17 ++++ NAM/dsp.cpp | 26 ++++-- NAM/dsp.h | 15 ++++ NAM/get_dsp.cpp | 129 ++++++++++++++++++---------- NAM/get_dsp_namb.cpp | 181 +++++++++++++++++---------------------- NAM/lstm.cpp | 23 +++-- NAM/lstm.h | 15 ++++ NAM/model_config.h | 51 +++++++++++ NAM/wavenet.cpp | 75 ++++++++-------- NAM/wavenet.h | 23 +++++ tools/test/test_namb.cpp | 46 +--------- 12 files changed, 377 insertions(+), 253 deletions(-) create mode 100644 NAM/model_config.h diff --git a/NAM/convnet.cpp b/NAM/convnet.cpp index fc7c151..3be4ca6 100644 --- a/NAM/convnet.cpp +++ b/NAM/convnet.cpp @@ -322,22 +322,27 @@ void nam::convnet::ConvNet::_rewind_buffers_() this->Buffer::_rewind_buffers_(); } +// Config parser +nam::convnet::ConvNetConfig nam::convnet::parse_config_json(const nlohmann::json& config) +{ + ConvNetConfig c; + c.channels = config["channels"]; + c.dilations = config["dilations"].get>(); + c.batchnorm = config["batchnorm"]; + c.activation = activations::ActivationConfig::from_json(config["activation"]); + c.groups = config.value("groups", 1); + c.in_channels = config.value("in_channels", 1); + c.out_channels = config.value("out_channels", 1); + return c; +} + // Factory std::unique_ptr nam::convnet::Factory(const nlohmann::json& config, std::vector& weights, const double expectedSampleRate) { - const int channels = config["channels"]; - const std::vector dilations = config["dilations"]; - const bool batchnorm = config["batchnorm"]; - // Parse JSON into typed ActivationConfig at model loading boundary - const 
activations::ActivationConfig activation_config = - activations::ActivationConfig::from_json(config["activation"]); - const int groups = config.value("groups", 1); // defaults to 1 - // Default to 1 channel in/out for backward compatibility - const int in_channels = config.value("in_channels", 1); - const int out_channels = config.value("out_channels", 1); - return std::make_unique( - in_channels, out_channels, channels, dilations, batchnorm, activation_config, weights, expectedSampleRate, groups); + auto c = parse_config_json(config); + return std::make_unique(c.in_channels, c.out_channels, c.channels, c.dilations, c.batchnorm, + c.activation, weights, expectedSampleRate, c.groups); } namespace diff --git a/NAM/convnet.h b/NAM/convnet.h index 0d963df..c1a961d 100644 --- a/NAM/convnet.h +++ b/NAM/convnet.h @@ -165,6 +165,23 @@ class ConvNet : public Buffer int PrewarmSamples() override { return mPrewarmSamples; }; }; +/// \brief Configuration for a ConvNet model +struct ConvNetConfig +{ + int channels; + std::vector dilations; + bool batchnorm; + activations::ActivationConfig activation; + int groups; + int in_channels; + int out_channels; +}; + +/// \brief Parse ConvNet configuration from JSON +/// \param config JSON configuration object +/// \return ConvNetConfig +ConvNetConfig parse_config_json(const nlohmann::json& config); + /// \brief Factory function to instantiate ConvNet from JSON /// \param config JSON configuration object /// \param weights Model weights vector diff --git a/NAM/dsp.cpp b/NAM/dsp.cpp index 05dab09..8bc4c3e 100644 --- a/NAM/dsp.cpp +++ b/NAM/dsp.cpp @@ -300,16 +300,30 @@ void nam::Linear::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num nam::Buffer::_advance_input_buffer_(num_frames); } +// Config parser +nam::linear::LinearConfig nam::linear::parse_config_json(const nlohmann::json& config) +{ + LinearConfig c; + c.receptive_field = config["receptive_field"]; + c.bias = config["bias"]; + c.in_channels = 
config.value("in_channels", 1); + c.out_channels = config.value("out_channels", 1); + return c; +} + // Factory std::unique_ptr nam::linear::Factory(const nlohmann::json& config, std::vector& weights, const double expectedSampleRate) { - const int receptive_field = config["receptive_field"]; - const bool bias = config["bias"]; - // Default to 1 channel in/out for backward compatibility - const int in_channels = config.value("in_channels", 1); - const int out_channels = config.value("out_channels", 1); - return std::make_unique(in_channels, out_channels, receptive_field, bias, weights, expectedSampleRate); + auto c = parse_config_json(config); + return std::make_unique(c.in_channels, c.out_channels, c.receptive_field, c.bias, weights, + expectedSampleRate); +} + +// Register the factory +namespace +{ +static nam::factory::Helper _register_Linear("Linear", nam::linear::Factory); } // NN modules ================================================================= diff --git a/NAM/dsp.h b/NAM/dsp.h index 1313ad9..7c47d38 100644 --- a/NAM/dsp.h +++ b/NAM/dsp.h @@ -258,6 +258,21 @@ class Linear : public Buffer namespace linear { + +/// \brief Configuration for a Linear model +struct LinearConfig +{ + int receptive_field; + bool bias; + int in_channels; + int out_channels; +}; + +/// \brief Parse Linear configuration from JSON +/// \param config JSON configuration object +/// \return LinearConfig +LinearConfig parse_config_json(const nlohmann::json& config); + /// \brief Factory function to instantiate Linear model from JSON /// \param config JSON configuration object /// \param weights Model weights vector diff --git a/NAM/get_dsp.cpp b/NAM/get_dsp.cpp index 57d0fbd..3ee5606 100644 --- a/NAM/get_dsp.cpp +++ b/NAM/get_dsp.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "dsp.h" #include "registry.h" @@ -11,6 +12,7 @@ #include "convnet.h" #include "wavenet.h" #include "get_dsp.h" +#include "model_config.h" namespace nam { @@ -146,62 +148,103 @@ 
std::unique_ptr get_dsp(const nlohmann::json& config, dspData& returnedConf return get_dsp(conf); } -struct OptionalValue +// ============================================================================= +// Unified construction path +// ============================================================================= + +ModelConfig parse_model_config_json(const std::string& architecture, const nlohmann::json& config, double sample_rate) { - bool have = false; - double value = 0.0; -}; + if (architecture == "Linear") + return linear::parse_config_json(config); + else if (architecture == "LSTM") + return lstm::parse_config_json(config); + else if (architecture == "ConvNet") + return convnet::parse_config_json(config); + else if (architecture == "WaveNet") + return wavenet::parse_config_json(config, sample_rate); + else + throw std::runtime_error("Unknown architecture: " + architecture); +} -std::unique_ptr get_dsp(dspData& conf) +namespace { - verify_config_version(conf.version); - auto& architecture = conf.architecture; - nlohmann::json& config = conf.config; - std::vector& weights = conf.weights; - OptionalValue loudness, inputLevel, outputLevel; +void apply_metadata(DSP& dsp, const ModelMetadata& metadata) +{ + if (metadata.loudness.has_value()) + dsp.SetLoudness(metadata.loudness.value()); + if (metadata.input_level.has_value()) + dsp.SetInputLevel(metadata.input_level.value()); + if (metadata.output_level.has_value()) + dsp.SetOutputLevel(metadata.output_level.value()); +} + +} // anonymous namespace + +std::unique_ptr create_dsp(ModelConfig config, std::vector weights, const ModelMetadata& metadata) +{ + const double sample_rate = metadata.sample_rate; - auto AssignOptional = [&conf](const std::string key, OptionalValue& v) { - if (conf.metadata.find(key) != conf.metadata.end()) - { - if (!conf.metadata[key].is_null()) + std::unique_ptr out = std::visit( + [&](auto&& cfg) -> std::unique_ptr { + using T = std::decay_t; + if constexpr (std::is_same_v) + { + return 
std::make_unique(cfg.in_channels, cfg.out_channels, cfg.receptive_field, cfg.bias, weights, + sample_rate); + } + else if constexpr (std::is_same_v) + { + return std::make_unique(cfg.in_channels, cfg.out_channels, cfg.num_layers, cfg.input_size, + cfg.hidden_size, weights, sample_rate); + } + else if constexpr (std::is_same_v) { - v.value = conf.metadata[key]; - v.have = true; + return std::make_unique(cfg.in_channels, cfg.out_channels, cfg.channels, cfg.dilations, + cfg.batchnorm, cfg.activation, weights, sample_rate, cfg.groups); } - } - }; + else if constexpr (std::is_same_v) + { + return std::make_unique(cfg.in_channels, cfg.layer_array_params, cfg.head_scale, + cfg.with_head, std::move(weights), std::move(cfg.condition_dsp), + sample_rate); + } + }, + std::move(config)); - if (!conf.metadata.is_null()) - { - AssignOptional("loudness", loudness); - AssignOptional("input_level_dbu", inputLevel); - AssignOptional("output_level_dbu", outputLevel); - } - const double expectedSampleRate = conf.expected_sample_rate; + apply_metadata(*out, metadata); + // FIXME should we remove prewarming from model load? 
+ out->prewarm(); + return out; +} - // Initialize using registry-based factory - std::unique_ptr out = - nam::factory::FactoryRegistry::instance().create(architecture, config, weights, expectedSampleRate); +// ============================================================================= +// get_dsp(dspData&) — now uses unified path +// ============================================================================= - if (loudness.have) - { - out->SetLoudness(loudness.value); - } - if (inputLevel.have) - { - out->SetInputLevel(inputLevel.value); - } - if (outputLevel.have) +std::unique_ptr get_dsp(dspData& conf) +{ + verify_config_version(conf.version); + + // Extract metadata from JSON + ModelMetadata metadata; + metadata.version = conf.version; + metadata.sample_rate = conf.expected_sample_rate; + + if (!conf.metadata.is_null()) { - out->SetOutputLevel(outputLevel.value); + auto extract = [&conf](const std::string& key) -> std::optional { + if (conf.metadata.find(key) != conf.metadata.end() && !conf.metadata[key].is_null()) + return conf.metadata[key].get(); + return std::nullopt; + }; + metadata.loudness = extract("loudness"); + metadata.input_level = extract("input_level_dbu"); + metadata.output_level = extract("output_level_dbu"); } - // "pre-warm" the model to settle initial conditions - // Can this be removed now that it's part of Reset()? 
- out->prewarm(); - - return out; + ModelConfig model_config = parse_model_config_json(conf.architecture, conf.config, conf.expected_sample_rate); + return create_dsp(std::move(model_config), std::move(conf.weights), metadata); } double get_sample_rate_from_nam_file(const nlohmann::json& j) diff --git a/NAM/get_dsp_namb.cpp b/NAM/get_dsp_namb.cpp index 459ee39..37c1a5c 100644 --- a/NAM/get_dsp_namb.cpp +++ b/NAM/get_dsp_namb.cpp @@ -1,5 +1,5 @@ // Binary .namb loader for NAM models -// No dependency on nlohmann/json +// Uses the unified create_dsp() path shared with the JSON loader #include #include @@ -10,12 +10,9 @@ #include "get_dsp_namb.h" -// Architecture headers (no json.hpp dependency needed - we only call constructors) #include "activations.h" -#include "convnet.h" -#include "lstm.h" +#include "model_config.h" #include "namb_format.h" -#include "wavenet.h" using namespace nam::namb; @@ -132,100 +129,106 @@ ParsedMetadata read_metadata_block(BinaryReader& r) return m; } +nam::ModelMetadata to_model_metadata(const ParsedMetadata& pm) +{ + nam::ModelMetadata meta; + meta.version = std::to_string(pm.version_major) + "." + std::to_string(pm.version_minor) + "." 
+ + std::to_string(pm.version_patch); + meta.sample_rate = pm.sample_rate; + if (pm.meta_flags & META_HAS_LOUDNESS) + meta.loudness = pm.loudness; + if (pm.meta_flags & META_HAS_INPUT_LEVEL) + meta.input_level = pm.input_level; + if (pm.meta_flags & META_HAS_OUTPUT_LEVEL) + meta.output_level = pm.output_level; + return meta; +} + // ============================================================================= -// Model construction (recursive for condition_dsp) +// Binary parsing into typed configs // ============================================================================= -// Result of loading a model: the DSP object and how many weights were consumed struct LoadResult { - std::unique_ptr dsp; - size_t weights_consumed; + nam::ModelConfig config; + std::vector weights; }; // Forward declaration -LoadResult load_model(BinaryReader& r, const float* weights, size_t weight_count, const ParsedMetadata& meta); +LoadResult load_model(BinaryReader& r, const float* weights, size_t weight_count, const nam::ModelMetadata& meta); // --- Linear --- -LoadResult load_linear(BinaryReader& r, const float* weights, size_t weight_count, double sample_rate) +LoadResult load_linear(BinaryReader& r, const float* weights, size_t weight_count) { - int32_t receptive_field = r.read_i32(); - bool bias = r.read_u8() != 0; - int in_channels = r.read_u8(); - int out_channels = r.read_u8(); + nam::linear::LinearConfig cfg; + cfg.receptive_field = r.read_i32(); + cfg.bias = r.read_u8() != 0; + cfg.in_channels = r.read_u8(); + cfg.out_channels = r.read_u8(); r.read_u8(); // reserved - std::vector w(weights, weights + weight_count); - std::unique_ptr dsp = - std::make_unique(in_channels, out_channels, receptive_field, bias, w, sample_rate); LoadResult result; - result.dsp = std::move(dsp); - result.weights_consumed = weight_count; + result.config = std::move(cfg); + result.weights.assign(weights, weights + weight_count); return result; } // --- LSTM --- -LoadResult load_lstm(BinaryReader& r, 
const float* weights, size_t weight_count, double sample_rate) +LoadResult load_lstm(BinaryReader& r, const float* weights, size_t weight_count) { - uint16_t num_layers = r.read_u16(); - uint16_t input_size = r.read_u16(); - uint16_t hidden_size = r.read_u16(); - uint8_t in_channels = r.read_u8(); - uint8_t out_channels = r.read_u8(); + nam::lstm::LSTMConfig cfg; + cfg.num_layers = r.read_u16(); + cfg.input_size = r.read_u16(); + cfg.hidden_size = r.read_u16(); + cfg.in_channels = r.read_u8(); + cfg.out_channels = r.read_u8(); r.skip(2); // reserved - std::vector w(weights, weights + weight_count); - std::unique_ptr dsp = std::make_unique(in_channels, out_channels, num_layers, input_size, - hidden_size, w, sample_rate); LoadResult result; - result.dsp = std::move(dsp); - result.weights_consumed = weight_count; + result.config = std::move(cfg); + result.weights.assign(weights, weights + weight_count); return result; } // --- ConvNet --- -LoadResult load_convnet(BinaryReader& r, const float* weights, size_t weight_count, double sample_rate) +LoadResult load_convnet(BinaryReader& r, const float* weights, size_t weight_count) { - uint16_t channels = r.read_u16(); - bool batchnorm = r.read_u8() != 0; + nam::convnet::ConvNetConfig cfg; + cfg.channels = r.read_u16(); + cfg.batchnorm = r.read_u8() != 0; uint8_t num_dilations = r.read_u8(); - uint16_t groups = r.read_u16(); - uint8_t in_channels = r.read_u8(); - uint8_t out_channels = r.read_u8(); + cfg.groups = r.read_u16(); + cfg.in_channels = r.read_u8(); + cfg.out_channels = r.read_u8(); - nam::activations::ActivationConfig activation_config = read_activation_config(r); + cfg.activation = read_activation_config(r); - std::vector dilations; - dilations.reserve(num_dilations); + cfg.dilations.reserve(num_dilations); for (int i = 0; i < num_dilations; i++) - dilations.push_back(r.read_i32()); + cfg.dilations.push_back(r.read_i32()); - std::vector w(weights, weights + weight_count); - std::unique_ptr dsp = std::make_unique( 
- in_channels, out_channels, channels, dilations, batchnorm, activation_config, w, sample_rate, groups); LoadResult result; - result.dsp = std::move(dsp); - result.weights_consumed = weight_count; + result.config = std::move(cfg); + result.weights.assign(weights, weights + weight_count); return result; } // --- WaveNet --- -LoadResult load_wavenet(BinaryReader& r, const float* weights, size_t weight_count, const ParsedMetadata& meta) +LoadResult load_wavenet(BinaryReader& r, const float* weights, size_t weight_count, const nam::ModelMetadata& meta) { - uint8_t in_channels = r.read_u8(); + nam::wavenet::WaveNetConfig wc; + wc.in_channels = r.read_u8(); uint8_t has_head = r.read_u8(); uint8_t num_layer_arrays = r.read_u8(); uint8_t has_condition_dsp = r.read_u8(); - bool with_head = (has_head != 0); - double sample_rate = meta.sample_rate; + wc.with_head = (has_head != 0); // Condition DSP - std::unique_ptr condition_dsp; size_t cdsp_weights_consumed = 0; if (has_condition_dsp) @@ -233,25 +236,16 @@ LoadResult load_wavenet(BinaryReader& r, const float* weights, size_t weight_cou uint32_t cdsp_weight_count = r.read_u32(); // Read condition DSP metadata (48 bytes) - ParsedMetadata cdsp_meta = read_metadata_block(r); + ParsedMetadata cdsp_pm = read_metadata_block(r); + nam::ModelMetadata cdsp_meta = to_model_metadata(cdsp_pm); - // Load condition DSP model recursively + // Load condition DSP model recursively via create_dsp LoadResult cdsp_result = load_model(r, weights, cdsp_weight_count, cdsp_meta); - condition_dsp = std::move(cdsp_result.dsp); - cdsp_weights_consumed = cdsp_result.weights_consumed; - - // Apply metadata to condition DSP - if (cdsp_meta.meta_flags & META_HAS_LOUDNESS) - condition_dsp->SetLoudness(cdsp_meta.loudness); - if (cdsp_meta.meta_flags & META_HAS_INPUT_LEVEL) - condition_dsp->SetInputLevel(cdsp_meta.input_level); - if (cdsp_meta.meta_flags & META_HAS_OUTPUT_LEVEL) - condition_dsp->SetOutputLevel(cdsp_meta.output_level); - 
condition_dsp->prewarm(); + wc.condition_dsp = nam::create_dsp(std::move(cdsp_result.config), std::move(cdsp_result.weights), cdsp_meta); + cdsp_weights_consumed = cdsp_weight_count; } // Parse layer array params - std::vector layer_array_params; for (int la = 0; la < num_layer_arrays; la++) { uint16_t input_size = r.read_u16(); @@ -322,28 +316,25 @@ LoadResult load_wavenet(BinaryReader& r, const float* weights, size_t weight_cou nam::wavenet::Layer1x1Params layer1x1_params(layer1x1_active, layer1x1_groups); nam::wavenet::Head1x1Params head1x1_params(head1x1_active, head1x1_out_channels, head1x1_groups); - layer_array_params.emplace_back(input_size, condition_size, head_size, la_channels, bottleneck, kernel_size, - std::move(dilations), std::move(activation_configs), std::move(gating_modes), - head_bias, groups_input, groups_input_mixin, layer1x1_params, head1x1_params, - std::move(secondary_activation_configs), conv_pre_film, conv_post_film, - input_mixin_pre_film, input_mixin_post_film, activation_pre_film, - activation_post_film, layer1x1_post_film, head1x1_post_film); + wc.layer_array_params.emplace_back(input_size, condition_size, head_size, la_channels, bottleneck, kernel_size, + std::move(dilations), std::move(activation_configs), std::move(gating_modes), + head_bias, groups_input, groups_input_mixin, layer1x1_params, head1x1_params, + std::move(secondary_activation_configs), conv_pre_film, conv_post_film, + input_mixin_pre_film, input_mixin_post_film, activation_pre_film, + activation_post_film, layer1x1_post_film, head1x1_post_film); } + // head_scale is the last weight value, but set_weights_ will overwrite it. + // Pass 0.0f; set_weights_ will set the correct value from weights. 
+ wc.head_scale = 0.0f; + // Build wavenet weights (excluding condition DSP weights which were consumed earlier) const float* wavenet_weights_ptr = weights + cdsp_weights_consumed; size_t wavenet_weight_count = weight_count - cdsp_weights_consumed; - std::vector wavenet_weights(wavenet_weights_ptr, wavenet_weights_ptr + wavenet_weight_count); - - // head_scale is the last weight value, but the constructor takes it as a param too - // (it gets overridden by set_weights_). Pass 0.0f; set_weights_ will set the correct value. - std::unique_ptr dsp = std::make_unique( - in_channels, layer_array_params, 0.0f, with_head, std::move(wavenet_weights), std::move(condition_dsp), - sample_rate); LoadResult result; - result.dsp = std::move(dsp); - result.weights_consumed = weight_count; + result.config = std::move(wc); + result.weights.assign(wavenet_weights_ptr, wavenet_weights_ptr + wavenet_weight_count); return result; } @@ -351,7 +342,7 @@ LoadResult load_wavenet(BinaryReader& r, const float* weights, size_t weight_cou // Dispatch to architecture-specific loader // ============================================================================= -LoadResult load_model(BinaryReader& r, const float* weights, size_t weight_count, const ParsedMetadata& meta) +LoadResult load_model(BinaryReader& r, const float* weights, size_t weight_count, const nam::ModelMetadata& meta) { uint8_t arch = r.read_u8(); r.read_u8(); // reserved @@ -359,9 +350,9 @@ LoadResult load_model(BinaryReader& r, const float* weights, size_t weight_count switch (arch) { - case ARCH_LINEAR: return load_linear(r, weights, weight_count, meta.sample_rate); - case ARCH_LSTM: return load_lstm(r, weights, weight_count, meta.sample_rate); - case ARCH_CONVNET: return load_convnet(r, weights, weight_count, meta.sample_rate); + case ARCH_LINEAR: return load_linear(r, weights, weight_count); + case ARCH_LSTM: return load_lstm(r, weights, weight_count); + case ARCH_CONVNET: return load_convnet(r, weights, weight_count); case 
ARCH_WAVENET: return load_wavenet(r, weights, weight_count, meta); default: throw std::runtime_error("NAMB: unknown architecture ID " + std::to_string(arch)); } @@ -414,12 +405,11 @@ std::unique_ptr nam::get_dsp_namb(const uint8_t* data, size_t size) // Read metadata block (at offset 32) BinaryReader meta_reader(data + FILE_HEADER_SIZE, METADATA_BLOCK_SIZE); - ParsedMetadata meta = read_metadata_block(meta_reader); + ParsedMetadata pm = read_metadata_block(meta_reader); + ModelMetadata meta = to_model_metadata(pm); // Verify config version - std::string version_str = std::to_string(meta.version_major) + "." + std::to_string(meta.version_minor) + "." - + std::to_string(meta.version_patch); - nam::verify_config_version(version_str); + nam::verify_config_version(meta.version); // Get weight data pointer const float* weights = reinterpret_cast(data + weights_offset); @@ -428,20 +418,9 @@ std::unique_ptr nam::get_dsp_namb(const uint8_t* data, size_t size) size_t model_data_size = weights_offset - MODEL_BLOCK_OFFSET; BinaryReader model_reader(data + MODEL_BLOCK_OFFSET, model_data_size); - // Load the model + // Load model config and weights, then construct via unified path LoadResult result = load_model(model_reader, weights, total_weight_count, meta); - - // Apply metadata - if (meta.meta_flags & META_HAS_LOUDNESS) - result.dsp->SetLoudness(meta.loudness); - if (meta.meta_flags & META_HAS_INPUT_LEVEL) - result.dsp->SetInputLevel(meta.input_level); - if (meta.meta_flags & META_HAS_OUTPUT_LEVEL) - result.dsp->SetOutputLevel(meta.output_level); - - result.dsp->prewarm(); - - return std::move(result.dsp); + return create_dsp(std::move(result.config), std::move(result.weights), meta); } std::unique_ptr nam::get_dsp_namb(const std::filesystem::path& filename) diff --git a/NAM/lstm.cpp b/NAM/lstm.cpp index d162d55..93e4d33 100644 --- a/NAM/lstm.cpp +++ b/NAM/lstm.cpp @@ -163,18 +163,25 @@ void nam::lstm::LSTM::_process_sample() this->_output.noalias() += this->_head_bias; } 
+// Config parser +nam::lstm::LSTMConfig nam::lstm::parse_config_json(const nlohmann::json& config) +{ + LSTMConfig c; + c.num_layers = config["num_layers"]; + c.input_size = config["input_size"]; + c.hidden_size = config["hidden_size"]; + c.in_channels = config.value("in_channels", 1); + c.out_channels = config.value("out_channels", 1); + return c; +} + // Factory to instantiate from nlohmann json std::unique_ptr nam::lstm::Factory(const nlohmann::json& config, std::vector& weights, const double expectedSampleRate) { - const int num_layers = config["num_layers"]; - const int input_size = config["input_size"]; - const int hidden_size = config["hidden_size"]; - // Default to 1 channel in/out for backward compatibility - const int in_channels = config.value("in_channels", 1); - const int out_channels = config.value("out_channels", 1); - return std::make_unique( - in_channels, out_channels, num_layers, input_size, hidden_size, weights, expectedSampleRate); + auto c = parse_config_json(config); + return std::make_unique(c.in_channels, c.out_channels, c.num_layers, c.input_size, c.hidden_size, + weights, expectedSampleRate); } // Register the factory diff --git a/NAM/lstm.h b/NAM/lstm.h index d97de20..fa00d4d 100644 --- a/NAM/lstm.h +++ b/NAM/lstm.h @@ -95,6 +95,21 @@ class LSTM : public DSP Eigen::VectorXf _output; }; +/// \brief Configuration for an LSTM model +struct LSTMConfig +{ + int num_layers; + int input_size; + int hidden_size; + int in_channels; + int out_channels; +}; + +/// \brief Parse LSTM configuration from JSON +/// \param config JSON configuration object +/// \return LSTMConfig +LSTMConfig parse_config_json(const nlohmann::json& config); + /// \brief Factory function to instantiate LSTM from JSON /// \param config JSON configuration object /// \param weights Model weights vector diff --git a/NAM/model_config.h b/NAM/model_config.h new file mode 100644 index 0000000..cfc5b86 --- /dev/null +++ b/NAM/model_config.h @@ -0,0 +1,51 @@ +#pragma once +// 
Unified model configuration types for both JSON and binary loaders. +// No circular dependencies: architecture headers define config structs, +// this header combines them into a variant. + +#include +#include +#include +#include +#include + +#include "convnet.h" +#include "dsp.h" +#include "lstm.h" +#include "wavenet.h" + +namespace nam +{ + +/// \brief Metadata common to all model formats +struct ModelMetadata +{ + std::string version; + double sample_rate = -1.0; + std::optional loudness; + std::optional input_level; + std::optional output_level; +}; + +/// \brief Variant of all architecture configs +using ModelConfig = std::variant; + +/// \brief Construct a DSP object from a typed config, weights, and metadata +/// +/// This is the single construction path used by both JSON and binary loaders. +/// Handles construction, metadata application, and prewarm. +/// \param config Architecture-specific configuration (variant) +/// \param weights Model weights (taken by value to allow move for WaveNet) +/// \param metadata Model metadata (version, sample rate, loudness, levels) +/// \return Unique pointer to a DSP object +std::unique_ptr create_dsp(ModelConfig config, std::vector weights, const ModelMetadata& metadata); + +/// \brief Parse a ModelConfig from a JSON architecture name and config block +/// \param architecture Architecture name string (e.g., "WaveNet", "LSTM") +/// \param config JSON config block for this architecture +/// \param sample_rate Expected sample rate from metadata +/// \return ModelConfig variant +ModelConfig parse_model_config_json(const std::string& architecture, const nlohmann::json& config, + double sample_rate); + +} // namespace nam diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp index 7d9b5d0..0fc04b4 100644 --- a/NAM/wavenet.cpp +++ b/NAM/wavenet.cpp @@ -568,41 +568,43 @@ void nam::wavenet::WaveNet::process(NAM_SAMPLE** input, NAM_SAMPLE** output, con } } -// Factory to instantiate from nlohmann json -std::unique_ptr 
nam::wavenet::Factory(const nlohmann::json& config, std::vector& weights, - const double expectedSampleRate) +// Config parser - extracts all configuration from JSON without constructing the DSP +nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json& config, + const double expectedSampleRate) { - std::unique_ptr condition_dsp = nullptr; + WaveNetConfig wc; + + // Condition DSP (eagerly built via get_dsp) if (config.find("condition_dsp") != config.end()) { const nlohmann::json& condition_dsp_json = config["condition_dsp"]; - condition_dsp = nam::get_dsp(condition_dsp_json); - if (condition_dsp->GetExpectedSampleRate() != expectedSampleRate) + wc.condition_dsp = nam::get_dsp(condition_dsp_json); + if (wc.condition_dsp->GetExpectedSampleRate() != expectedSampleRate) { std::stringstream ss; - ss << "Condition DSP expected sample rate (" << condition_dsp->GetExpectedSampleRate() + ss << "Condition DSP expected sample rate (" << wc.condition_dsp->GetExpectedSampleRate() << ") doesn't match WaveNet expected sample rate (" << expectedSampleRate << "!\n"; throw std::runtime_error(ss.str().c_str()); } } - std::vector layer_array_params; + for (size_t i = 0; i < config["layers"].size(); i++) { nlohmann::json layer_config = config["layers"][i]; - const int groups = layer_config.value("groups_input", 1); // defaults to 1 - const int groups_input_mixin = layer_config.value("groups_input_mixin", 1); // defaults to 1 + const int groups = layer_config.value("groups_input", 1); + const int groups_input_mixin = layer_config.value("groups_input_mixin", 1); const int channels = layer_config["channels"]; - const int bottleneck = layer_config.value("bottleneck", channels); // defaults to channels if not present + const int bottleneck = layer_config.value("bottleneck", channels); // Parse layer1x1 parameters - bool layer1x1_active = true; // default to active if not present + bool layer1x1_active = true; int layer1x1_groups = 1; if (layer_config.find("layer1x1") 
!= layer_config.end()) { const auto& layer1x1_config = layer_config["layer1x1"]; - layer1x1_active = layer1x1_config["active"]; // default to active + layer1x1_active = layer1x1_config["active"]; layer1x1_groups = layer1x1_config["groups"]; } nam::wavenet::Layer1x1Params layer1x1_params(layer1x1_active, layer1x1_groups); @@ -618,7 +620,6 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st std::vector activation_configs; if (layer_config["activation"].is_array()) { - // Array of activation configs for (const auto& activation_json : layer_config["activation"]) { activation_configs.push_back(activations::ActivationConfig::from_json(activation_json)); @@ -632,12 +633,12 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else { - // Single activation config - duplicate it for all layers const activations::ActivationConfig activation_config = activations::ActivationConfig::from_json(layer_config["activation"]); activation_configs.resize(num_layers, activation_config); } - // Parse gating mode(s) - support both single value and array, and old "gated" boolean + + // Parse gating mode(s) std::vector gating_modes; std::vector secondary_activation_configs; @@ -656,21 +657,18 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st { if (layer_config["gating_mode"].is_array()) { - // Array of gating modes for (const auto& gating_mode_json : layer_config["gating_mode"]) { std::string gating_mode_str = gating_mode_json.get(); GatingMode mode = parse_gating_mode_str(gating_mode_str); gating_modes.push_back(mode); - // Parse corresponding secondary activation if gating is enabled if (mode != GatingMode::NONE) { if (layer_config.find("secondary_activation") != layer_config.end()) { if (layer_config["secondary_activation"].is_array()) { - // Array of secondary activations - use corresponding index if (gating_modes.size() > layer_config["secondary_activation"].size()) { throw std::runtime_error("Layer array " + 
std::to_string(i) @@ -682,21 +680,18 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else { - // Single secondary activation - use for all gated layers secondary_activation_configs.push_back( activations::ActivationConfig::from_json(layer_config["secondary_activation"])); } } else { - // Default to Sigmoid for backward compatibility secondary_activation_configs.push_back( activations::ActivationConfig::simple(activations::ActivationType::Sigmoid)); } } else { - // NONE mode - use empty config secondary_activation_configs.push_back(activations::ActivationConfig{}); } } @@ -706,7 +701,6 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st + std::to_string(gating_modes.size()) + ") must match dilations size (" + std::to_string(num_layers) + ")"); } - // Validate secondary_activation array size if it's an array if (layer_config.find("secondary_activation") != layer_config.end() && layer_config["secondary_activation"].is_array()) { @@ -720,12 +714,10 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else { - // Single gating mode - duplicate for all layers std::string gating_mode_str = layer_config["gating_mode"].get(); GatingMode gating_mode = parse_gating_mode_str(gating_mode_str); gating_modes.resize(num_layers, gating_mode); - // Parse secondary activation activations::ActivationConfig secondary_activation_config; if (gating_mode != GatingMode::NONE) { @@ -736,7 +728,6 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else { - // Default to Sigmoid for backward compatibility secondary_activation_config = activations::ActivationConfig::simple(activations::ActivationType::Sigmoid); } } @@ -745,7 +736,6 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else if (layer_config.find("gated") != layer_config.end()) { - // Backward compatibility: convert old "gated" boolean to new enum bool gated = layer_config["gated"]; GatingMode gating_mode = 
gated ? GatingMode::GATED : GatingMode::NONE; gating_modes.resize(num_layers, gating_mode); @@ -763,7 +753,6 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else { - // Default to NONE for all layers gating_modes.resize(num_layers, GatingMode::NONE); secondary_activation_configs.resize(num_layers, activations::ActivationConfig{}); } @@ -792,11 +781,10 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st const nlohmann::json& film_config = layer_config[key]; bool active = film_config.value("active", true); bool shift = film_config.value("shift", true); - int groups = film_config.value("groups", 1); - return nam::wavenet::_FiLMParams(active, shift, groups); + int film_groups = film_config.value("groups", 1); + return nam::wavenet::_FiLMParams(active, shift, film_groups); }; - // Parse FiLM parameters nam::wavenet::_FiLMParams conv_pre_film_params = parse_film_params("conv_pre_film"); nam::wavenet::_FiLMParams conv_post_film_params = parse_film_params("conv_post_film"); nam::wavenet::_FiLMParams input_mixin_pre_film_params = parse_film_params("input_mixin_pre_film"); @@ -806,32 +794,37 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st nam::wavenet::_FiLMParams _layer1x1_post_film_params = parse_film_params("layer1x1_post_film"); nam::wavenet::_FiLMParams head1x1_post_film_params = parse_film_params("head1x1_post_film"); - // Validation: if layer1x1_post_film is active, layer1x1 must also be active if (_layer1x1_post_film_params.active && !layer1x1_active) { throw std::runtime_error("Layer array " + std::to_string(i) + ": layer1x1_post_film cannot be active when layer1x1.active is false"); } - layer_array_params.push_back(nam::wavenet::LayerArrayParams( + wc.layer_array_params.push_back(nam::wavenet::LayerArrayParams( input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, std::move(activation_configs), std::move(gating_modes), head_bias, groups, groups_input_mixin, 
layer1x1_params, head1x1_params, std::move(secondary_activation_configs), conv_pre_film_params, conv_post_film_params, input_mixin_pre_film_params, input_mixin_post_film_params, activation_pre_film_params, activation_post_film_params, _layer1x1_post_film_params, head1x1_post_film_params)); } - const bool with_head = !config["head"].is_null(); - const float head_scale = config["head_scale"]; - if (layer_array_params.empty()) + wc.with_head = !config["head"].is_null(); + wc.head_scale = config["head_scale"]; + wc.in_channels = config.value("in_channels", 1); + + if (wc.layer_array_params.empty()) throw std::runtime_error("WaveNet config requires at least one layer array"); - // Backward compatibility: assume 1 input channel - const int in_channels = config.value("in_channels", 1); + return wc; +} - // out_channels is determined from last layer array's head_size - return std::make_unique( - in_channels, layer_array_params, head_scale, with_head, weights, std::move(condition_dsp), expectedSampleRate); +// Factory to instantiate from nlohmann json +std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, std::vector& weights, + const double expectedSampleRate) +{ + auto wc = parse_config_json(config, expectedSampleRate); + return std::make_unique(wc.in_channels, wc.layer_array_params, wc.head_scale, wc.with_head, + weights, std::move(wc.condition_dsp), expectedSampleRate); } // Register the factory diff --git a/NAM/wavenet.h b/NAM/wavenet.h index 63e1378..b75d15f 100644 --- a/NAM/wavenet.h +++ b/NAM/wavenet.h @@ -713,6 +713,29 @@ class WaveNet : public DSP int PrewarmSamples() override { return mPrewarmSamples; }; }; +/// \brief Configuration for a WaveNet model +struct WaveNetConfig +{ + int in_channels; + std::vector layer_array_params; + float head_scale; + bool with_head; + std::unique_ptr condition_dsp; + + // Move-only due to unique_ptr + WaveNetConfig() = default; + WaveNetConfig(WaveNetConfig&&) = default; + WaveNetConfig& operator=(WaveNetConfig&&) = 
default; + WaveNetConfig(const WaveNetConfig&) = delete; + WaveNetConfig& operator=(const WaveNetConfig&) = delete; +}; + +/// \brief Parse WaveNet configuration from JSON +/// \param config JSON configuration object +/// \param expectedSampleRate Expected sample rate in Hz (-1.0 if unknown) +/// \return WaveNetConfig +WaveNetConfig parse_config_json(const nlohmann::json& config, const double expectedSampleRate); + /// \brief Factory function to instantiate WaveNet from JSON configuration /// \param config JSON configuration object /// \param weights Model weights vector diff --git a/tools/test/test_namb.cpp b/tools/test/test_namb.cpp index 88c36cb..819fcf4 100644 --- a/tools/test/test_namb.cpp +++ b/tools/test/test_namb.cpp @@ -4,7 +4,6 @@ // - Size verification: NAMB < NAM for all example models #include -#include #include #include #include @@ -106,12 +105,7 @@ static void test_roundtrip_for_file(const std::string& nam_path) { std::filesystem::path model_path(nam_path); if (!std::filesystem::exists(model_path)) - { - std::cerr << " Skipping (not found): " << nam_path << std::endl; return; - } - - std::cout << " Testing round-trip: " << model_path.filename().string() << std::endl; // Load JSON model std::unique_ptr json_model = nam::get_dsp(model_path); @@ -130,12 +124,12 @@ static void test_roundtrip_for_file(const std::string& nam_path) // Use the nam2namb tool to create the .namb file // Since we can't easily call the tool's function, we'll use system() or // construct the binary ourselves. For the test, let's shell out. 
- std::string cmd = "./build/tools/nam2namb " + model_path.string() + " " + namb_path.string() + " 2>&1"; + std::string cmd = "./build/tools/nam2namb " + model_path.string() + " " + namb_path.string() + " > /dev/null 2>&1"; int ret = system(cmd.c_str()); if (ret != 0) { // Try relative path from where tests might be run - cmd = "nam2namb " + model_path.string() + " " + namb_path.string() + " 2>&1"; + cmd = "nam2namb " + model_path.string() + " " + namb_path.string() + " > /dev/null 2>&1"; ret = system(cmd.c_str()); } assert(ret == 0); @@ -177,14 +171,10 @@ static void test_roundtrip_for_file(const std::string& nam_path) // Clean up temp file std::filesystem::remove(namb_path); - - std::cout << " PASS" << std::endl; } void test_roundtrip() { - std::cout << "test_namb::test_roundtrip" << std::endl; - // Test all available example models const std::vector models = {"example_models/wavenet.nam", "example_models/lstm.nam", "example_models/wavenet_condition_dsp.nam", @@ -202,8 +192,6 @@ void test_roundtrip() void test_bad_magic() { - std::cout << "test_namb::test_bad_magic" << std::endl; - // Create minimal data with wrong magic std::vector data(128, 0); data[0] = 'X'; // Wrong magic @@ -219,13 +207,10 @@ void test_bad_magic() assert(std::string(e.what()).find("magic") != std::string::npos); } assert(threw); - std::cout << " PASS" << std::endl; } void test_truncated_file() { - std::cout << "test_namb::test_truncated_file" << std::endl; - // File too small for header std::vector data(16, 0); // Set magic correctly @@ -242,13 +227,10 @@ void test_truncated_file() threw = true; } assert(threw); - std::cout << " PASS" << std::endl; } void test_wrong_version() { - std::cout << "test_namb::test_wrong_version" << std::endl; - // Create data with wrong format version std::vector data(128, 0); uint32_t magic = nam::namb::MAGIC; @@ -267,29 +249,20 @@ void test_wrong_version() assert(std::string(e.what()).find("version") != std::string::npos); } assert(threw); - std::cout << " 
PASS" << std::endl; } void test_bad_checksum() { - std::cout << "test_namb::test_bad_checksum" << std::endl; - // First create a valid .namb file, then corrupt it std::filesystem::path nam_path("example_models/lstm.nam"); if (!std::filesystem::exists(nam_path)) - { - std::cerr << " Skipping (lstm.nam not found)" << std::endl; return; - } std::filesystem::path namb_path("example_models/lstm_test_bad_crc.namb"); - std::string cmd = "./build/tools/nam2namb " + nam_path.string() + " " + namb_path.string() + " 2>&1"; + std::string cmd = "./build/tools/nam2namb " + nam_path.string() + " " + namb_path.string() + " > /dev/null 2>&1"; int ret = system(cmd.c_str()); if (ret != 0) - { - std::cerr << " Skipping (nam2namb not available)" << std::endl; return; - } // Read the .namb file auto data = read_file_bytes(namb_path); @@ -313,7 +286,6 @@ void test_bad_checksum() assert(threw); std::filesystem::remove(namb_path); - std::cout << " PASS" << std::endl; } // ============================================================================= @@ -322,8 +294,6 @@ void test_bad_checksum() void test_size_reduction() { - std::cout << "test_namb::test_size_reduction" << std::endl; - const std::vector models = {"example_models/wavenet.nam", "example_models/lstm.nam", "example_models/wavenet_condition_dsp.nam", "example_models/wavenet_a2_max.nam"}; @@ -337,7 +307,7 @@ void test_size_reduction() std::filesystem::path namb_path = nam_path; namb_path.replace_extension(".namb"); - std::string cmd = "./build/tools/nam2namb " + nam_path.string() + " " + namb_path.string() + " 2>&1"; + std::string cmd = "./build/tools/nam2namb " + nam_path.string() + " " + namb_path.string() + " > /dev/null 2>&1"; int ret = system(cmd.c_str()); if (ret != 0) continue; @@ -346,9 +316,6 @@ void test_size_reduction() size_t namb_size = std::filesystem::file_size(namb_path); double reduction = 100.0 * (1.0 - (double)namb_size / (double)nam_size); - std::cout << " " << nam_path.filename().string() << ": " << nam_size 
<< " -> " << namb_size << " (" - << std::fixed << std::setprecision(1) << reduction << "% reduction)" << std::endl; - // .namb should always be smaller than .nam assert(namb_size < nam_size); @@ -357,7 +324,6 @@ void test_size_reduction() std::filesystem::remove(namb_path); } - std::cout << " PASS" << std::endl; } // ============================================================================= @@ -366,8 +332,6 @@ void test_size_reduction() void test_crc32() { - std::cout << "test_namb::test_crc32" << std::endl; - // Test known CRC32 values const uint8_t test1[] = "123456789"; uint32_t crc1 = nam::namb::crc32(test1, 9); @@ -377,8 +341,6 @@ void test_crc32() // Empty data uint32_t crc_empty = nam::namb::crc32(nullptr, 0); assert(crc_empty == 0x00000000u); - - std::cout << " PASS" << std::endl; } }; // namespace test_namb From d68ea1689984f7331520c6d84a91a0891826d520 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Felipe=20Santos?= Date: Thu, 12 Feb 2026 13:56:43 -0800 Subject: [PATCH 3/4] Small fix to botched merge --- NAM/wavenet.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp index d951c8b..6f39a9d 100644 --- a/NAM/wavenet.cpp +++ b/NAM/wavenet.cpp @@ -575,7 +575,7 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json WaveNetConfig wc; // Condition DSP (eagerly built via get_dsp) - if (config.find("condition_dsp") != config.end()) && !config["condition_dsp"].is_null()) + if ((config.find("condition_dsp") != config.end()) && !config["condition_dsp"].is_null()) { const nlohmann::json& condition_dsp_json = config["condition_dsp"]; wc.condition_dsp = nam::get_dsp(condition_dsp_json); From ff8cd166efd28d13cd9208ebdb5c8c2df5d6d40a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Felipe=20Santos?= Date: Thu, 12 Feb 2026 19:40:09 -0800 Subject: [PATCH 4/4] Fixing another merge issue --- NAM/wavenet.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/NAM/wavenet.cpp 
b/NAM/wavenet.cpp index cc7bcc7..6f39a9d 100644 --- a/NAM/wavenet.cpp +++ b/NAM/wavenet.cpp @@ -807,8 +807,6 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json input_mixin_pre_film_params, input_mixin_post_film_params, activation_pre_film_params, activation_post_film_params, _layer1x1_post_film_params, head1x1_post_film_params)); } - const bool with_head = config.find("head") != config.end() && !config["head"].is_null(); - const float head_scale = config["head_scale"]; wc.with_head = !config["head"].is_null(); wc.head_scale = config["head_scale"];