diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp index efd02ff..6eb74a3 100644 --- a/NAM/wavenet.cpp +++ b/NAM/wavenet.cpp @@ -820,7 +820,7 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st input_mixin_pre_film_params, input_mixin_post_film_params, activation_pre_film_params, activation_post_film_params, _layer1x1_post_film_params, head1x1_post_film_params)); } - const bool with_head = !config["head"].is_null(); + const bool with_head = config.find("head") != config.end() && !config["head"].is_null(); const float head_scale = config["head_scale"]; if (layer_array_params.empty()) diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index 1c8c34a..f26edb6 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -19,6 +19,7 @@ #include "test/test_wavenet/test_condition_processing.cpp" #include "test/test_wavenet/test_head1x1.cpp" #include "test/test_wavenet/test_layer1x1.cpp" +#include "test/test_wavenet/test_factory.cpp" #include "test/test_gating_activations.cpp" #include "test/test_wavenet_gating_compatibility.cpp" #include "test/test_blending_detailed.cpp" @@ -169,6 +170,7 @@ int main() test_wavenet::test_layer1x1::test_layer1x1_post_film_inactive_with_layer1x1_inactive(); test_wavenet::test_layer1x1::test_layer1x1_gated(); test_wavenet::test_layer1x1::test_layer1x1_groups(); + test_wavenet::test_factory::test_factory_without_head_key(); test_wavenet::test_allocation_tracking_pass(); test_wavenet::test_allocation_tracking_fail(); test_wavenet::test_conv1d_process_realtime_safe(); diff --git a/tools/test/test_wavenet/test_factory.cpp b/tools/test/test_wavenet/test_factory.cpp new file mode 100644 index 0000000..ed06c22 --- /dev/null +++ b/tools/test/test_wavenet/test_factory.cpp @@ -0,0 +1,73 @@ +// Tests for WaveNet Factory + +#include +#include +#include +#include + +#include "json.hpp" + +#include "NAM/get_dsp.h" +#include "NAM/wavenet.h" + +namespace test_wavenet +{ +namespace test_factory +{ +/// Asserts that the model is instantiated correctly when no "head" key is provided. +/// The deprecated "head" key is optional; when absent, with_head should be false. +void test_factory_without_head_key() +{ + // Minimal WaveNet config - deliberately omits the "head" key entirely. + // Same structure as wavenet.nam but without "head" in config. + const std::string configStr = R"({ + "version": "0.5.4", + "metadata": {}, + "architecture": "WaveNet", + "config": { + "layers": [{ + "input_size": 1, + "condition_size": 1, + "head_size": 1, + "channels": 1, + "kernel_size": 1, + "dilations": [1], + "activation": "ReLU", + "gated": false, + "head_bias": false + }], + "head_scale": 1.0 + }, + "weights": [1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0], + "sample_rate": 48000 + })"; + + nlohmann::json j = nlohmann::json::parse(configStr); + + // Verify the config does not contain "head" key + assert(j["config"].find("head") == j["config"].end()); + + // Load model via get_dsp - exercises Factory path + std::unique_ptr dsp = nam::get_dsp(j); + assert(dsp != nullptr); + + // Process audio to verify model works correctly + const int numFrames = 4; + const int maxBufferSize = 64; + dsp->Reset(48000.0, maxBufferSize); + + std::vector input(numFrames, 1.0f); + std::vector output(numFrames, 0.0f); + NAM_SAMPLE* inputPtrs[] = {input.data()}; + NAM_SAMPLE* outputPtrs[] = {output.data()}; + + dsp->process(inputPtrs, outputPtrs, numFrames); + + assert(static_cast(output.size()) == numFrames); + for (int i = 0; i < numFrames; i++) + { + assert(std::isfinite(output[i])); + } +} +}; // namespace test_factory +}; // namespace test_wavenet