diff --git a/NAM/convnet.cpp b/NAM/convnet.cpp index fc7c151..5d02b21 100644 --- a/NAM/convnet.cpp +++ b/NAM/convnet.cpp @@ -322,25 +322,38 @@ void nam::convnet::ConvNet::_rewind_buffers_() this->Buffer::_rewind_buffers_(); } -// Factory -std::unique_ptr nam::convnet::Factory(const nlohmann::json& config, std::vector& weights, - const double expectedSampleRate) +// Config parser +nam::convnet::ConvNetConfig nam::convnet::parse_config_json(const nlohmann::json& config) { - const int channels = config["channels"]; - const std::vector dilations = config["dilations"]; - const bool batchnorm = config["batchnorm"]; - // Parse JSON into typed ActivationConfig at model loading boundary - const activations::ActivationConfig activation_config = - activations::ActivationConfig::from_json(config["activation"]); - const int groups = config.value("groups", 1); // defaults to 1 - // Default to 1 channel in/out for backward compatibility - const int in_channels = config.value("in_channels", 1); - const int out_channels = config.value("out_channels", 1); - return std::make_unique( - in_channels, out_channels, channels, dilations, batchnorm, activation_config, weights, expectedSampleRate, groups); + ConvNetConfig c; + c.channels = config["channels"]; + c.dilations = config["dilations"].get>(); + c.batchnorm = config["batchnorm"]; + c.activation = activations::ActivationConfig::from_json(config["activation"]); + c.groups = config.value("groups", 1); + c.in_channels = config.value("in_channels", 1); + c.out_channels = config.value("out_channels", 1); + return c; +} + +// ConvNetConfig::create() +std::unique_ptr nam::convnet::ConvNetConfig::create(std::vector weights, double sampleRate) +{ + return std::make_unique(in_channels, out_channels, channels, dilations, batchnorm, activation, + weights, sampleRate, groups); +} + +// Config parser for ConfigParserRegistry +std::unique_ptr nam::convnet::create_config(const nlohmann::json& config, double sampleRate) +{ + (void)sampleRate; + auto c = std::make_unique(); + auto parsed = parse_config_json(config); + *c = parsed; + return c; } namespace { -static nam::factory::Helper _register_ConvNet("ConvNet", nam::convnet::Factory); +static nam::ConfigParserHelper _register_ConvNet("ConvNet", nam::convnet::create_config); } diff --git a/NAM/convnet.h b/NAM/convnet.h index 0d963df..c1d7c1a 100644 --- a/NAM/convnet.h +++ b/NAM/convnet.h @@ -165,13 +165,27 @@ class ConvNet : public Buffer int PrewarmSamples() override { return mPrewarmSamples; }; }; -/// \brief Factory function to instantiate ConvNet from JSON +/// \brief Configuration for a ConvNet model +struct ConvNetConfig : public ModelConfig +{ + int channels; + std::vector dilations; + bool batchnorm; + activations::ActivationConfig activation; + int groups; + int in_channels; + int out_channels; + + std::unique_ptr create(std::vector weights, double sampleRate) override; +}; + +/// \brief Parse ConvNet configuration from JSON /// \param config JSON configuration object -/// \param weights Model weights vector -/// \param expectedSampleRate Expected sample rate in Hz (-1.0 if unknown) -/// \return Unique pointer to a DSP object (ConvNet instance) -std::unique_ptr Factory(const nlohmann::json& config, std::vector& weights, - const double expectedSampleRate); +/// \return ConvNetConfig +ConvNetConfig parse_config_json(const nlohmann::json& config); + +/// \brief Config parser for ConfigParserRegistry +std::unique_ptr create_config(const nlohmann::json& config, double sampleRate); }; // namespace convnet }; // namespace nam diff --git a/NAM/dsp.cpp b/NAM/dsp.cpp index 05dab09..e9e3629 100644 --- a/NAM/dsp.cpp +++ b/NAM/dsp.cpp @@ -300,16 +300,37 @@ void nam::Linear::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num nam::Buffer::_advance_input_buffer_(num_frames); } -// Factory -std::unique_ptr nam::linear::Factory(const nlohmann::json& config, std::vector& weights, - const double expectedSampleRate) -{ - const int receptive_field = config["receptive_field"]; - const bool bias = config["bias"]; - // Default to 1 channel in/out for backward compatibility - const int in_channels = config.value("in_channels", 1); - const int out_channels = config.value("out_channels", 1); - return std::make_unique(in_channels, out_channels, receptive_field, bias, weights, expectedSampleRate); +// Config parser +nam::linear::LinearConfig nam::linear::parse_config_json(const nlohmann::json& config) +{ + LinearConfig c; + c.receptive_field = config["receptive_field"]; + c.bias = config["bias"]; + c.in_channels = config.value("in_channels", 1); + c.out_channels = config.value("out_channels", 1); + return c; +} + +// LinearConfig::create() +std::unique_ptr nam::linear::LinearConfig::create(std::vector weights, double sampleRate) +{ + return std::make_unique(in_channels, out_channels, receptive_field, bias, weights, sampleRate); +} + +// Config parser for ConfigParserRegistry +std::unique_ptr nam::linear::create_config(const nlohmann::json& config, double sampleRate) +{ + (void)sampleRate; + auto c = std::make_unique(); + auto parsed = parse_config_json(config); + *c = parsed; + return c; +} + +// Register the config parser +namespace +{ +static nam::ConfigParserHelper _register_Linear("Linear", nam::linear::create_config); } // NN modules ================================================================= diff --git a/NAM/dsp.h b/NAM/dsp.h index 1313ad9..15bc9b8 100644 --- a/NAM/dsp.h +++ b/NAM/dsp.h @@ -11,6 +11,7 @@ #include "activations.h" #include "json.hpp" +#include "model_config.h" #ifdef NAM_SAMPLE_FLOAT #define NAM_SAMPLE float @@ -258,13 +259,28 @@ class Linear : public Buffer namespace linear { -/// \brief Factory function to instantiate Linear model from JSON + +/// \brief Configuration for a Linear model +struct LinearConfig : public ModelConfig +{ + int receptive_field; + bool bias; + int in_channels; + int out_channels; + + std::unique_ptr create(std::vector weights, double sampleRate) override; +}; + +/// \brief Parse Linear configuration from JSON +/// \param config JSON configuration object +/// \return LinearConfig +LinearConfig parse_config_json(const nlohmann::json& config); + +/// \brief Config parser for ConfigParserRegistry /// \param config JSON configuration object -/// \param weights Model weights vector -/// \param expectedSampleRate Expected sample rate in Hz (-1.0 if unknown) -/// \return Unique pointer to a DSP object (Linear instance) -std::unique_ptr Factory(const nlohmann::json& config, std::vector& weights, - const double expectedSampleRate); +/// \param sampleRate Expected sample rate in Hz +/// \return unique_ptr wrapping a LinearConfig +std::unique_ptr create_config(const nlohmann::json& config, double sampleRate); } // namespace linear // NN modules ================================================================= diff --git a/NAM/get_dsp.cpp b/NAM/get_dsp.cpp index 57d0fbd..e1e6409 100644 --- a/NAM/get_dsp.cpp +++ b/NAM/get_dsp.cpp @@ -2,15 +2,12 @@ #include #include #include -#include #include "dsp.h" #include "registry.h" #include "json.hpp" -#include "lstm.h" -#include "convnet.h" -#include "wavenet.h" #include "get_dsp.h" +#include "model_config.h" namespace nam { @@ -146,62 +143,67 @@ std::unique_ptr get_dsp(const nlohmann::json& config, dspData& returnedConf return get_dsp(conf); } -struct OptionalValue +// ============================================================================= +// Unified construction path +// ============================================================================= + +std::unique_ptr parse_model_config_json(const std::string& architecture, const nlohmann::json& config, + double sample_rate) +{ + return ConfigParserRegistry::instance().parse(architecture, config, sample_rate); +} + +namespace { - bool have = false; - double value = 0.0; -}; + +void apply_metadata(DSP& dsp, const ModelMetadata& metadata) +{ + if (metadata.loudness.has_value()) + dsp.SetLoudness(metadata.loudness.value()); + if (metadata.input_level.has_value()) + dsp.SetInputLevel(metadata.input_level.value()); + if (metadata.output_level.has_value()) + dsp.SetOutputLevel(metadata.output_level.value()); +} + +} // anonymous namespace + +std::unique_ptr create_dsp(std::unique_ptr config, std::vector weights, + const ModelMetadata& metadata) +{ + auto out = config->create(std::move(weights), metadata.sample_rate); + apply_metadata(*out, metadata); + out->prewarm(); + return out; +} + +// ============================================================================= +// get_dsp(dspData&) — now uses unified path +// ============================================================================= std::unique_ptr get_dsp(dspData& conf) { verify_config_version(conf.version); - auto& architecture = conf.architecture; - nlohmann::json& config = conf.config; - std::vector& weights = conf.weights; - OptionalValue loudness, inputLevel, outputLevel; - - auto AssignOptional = [&conf](const std::string key, OptionalValue& v) { - if (conf.metadata.find(key) != conf.metadata.end()) - { - if (!conf.metadata[key].is_null()) - { - v.value = conf.metadata[key]; - v.have = true; - } - } - }; + // Extract metadata from JSON + ModelMetadata metadata; + metadata.version = conf.version; + metadata.sample_rate = conf.expected_sample_rate; if (!conf.metadata.is_null()) { - AssignOptional("loudness", loudness); - AssignOptional("input_level_dbu", inputLevel); - AssignOptional("output_level_dbu", outputLevel); - } - const double expectedSampleRate = conf.expected_sample_rate; - - // Initialize using registry-based factory - std::unique_ptr out = - nam::factory::FactoryRegistry::instance().create(architecture, config, weights, expectedSampleRate); - - if (loudness.have) - { - out->SetLoudness(loudness.value); - } - if (inputLevel.have) - { - out->SetInputLevel(inputLevel.value); - } - if (outputLevel.have) - { - out->SetOutputLevel(outputLevel.value); + auto extract = [&conf](const std::string& key) -> std::optional { + if (conf.metadata.find(key) != conf.metadata.end() && !conf.metadata[key].is_null()) + return conf.metadata[key].get(); + return std::nullopt; + }; + metadata.loudness = extract("loudness"); + metadata.input_level = extract("input_level_dbu"); + metadata.output_level = extract("output_level_dbu"); } - // "pre-warm" the model to settle initial conditions - // Can this be removed now that it's part of Reset()? - out->prewarm(); - - return out; + auto model_config = ConfigParserRegistry::instance().parse(conf.architecture, conf.config, conf.expected_sample_rate); + return create_dsp(std::move(model_config), std::move(conf.weights), metadata); } double get_sample_rate_from_nam_file(const nlohmann::json& j) diff --git a/NAM/lstm.cpp b/NAM/lstm.cpp index d162d55..dba6f45 100644 --- a/NAM/lstm.cpp +++ b/NAM/lstm.cpp @@ -163,22 +163,37 @@ void nam::lstm::LSTM::_process_sample() this->_output.noalias() += this->_head_bias; } -// Factory to instantiate from nlohmann json -std::unique_ptr nam::lstm::Factory(const nlohmann::json& config, std::vector& weights, - const double expectedSampleRate) +// Config parser +nam::lstm::LSTMConfig nam::lstm::parse_config_json(const nlohmann::json& config) { - const int num_layers = config["num_layers"]; - const int input_size = config["input_size"]; - const int hidden_size = config["hidden_size"]; - // Default to 1 channel in/out for backward compatibility - const int in_channels = config.value("in_channels", 1); - const int out_channels = config.value("out_channels", 1); - return std::make_unique( - in_channels, out_channels, num_layers, input_size, hidden_size, weights, expectedSampleRate); + LSTMConfig c; + c.num_layers = config["num_layers"]; + c.input_size = config["input_size"]; + c.hidden_size = config["hidden_size"]; + c.in_channels = config.value("in_channels", 1); + c.out_channels = config.value("out_channels", 1); + return c; } -// Register the factory +// LSTMConfig::create() +std::unique_ptr nam::lstm::LSTMConfig::create(std::vector weights, double sampleRate) +{ + return std::make_unique(in_channels, out_channels, num_layers, input_size, hidden_size, weights, + sampleRate); +} + +// Config parser for ConfigParserRegistry +std::unique_ptr nam::lstm::create_config(const nlohmann::json& config, double sampleRate) +{ + (void)sampleRate; + auto c = std::make_unique(); + auto parsed = parse_config_json(config); + *c = parsed; + return c; +} + +// Register the config parser namespace { -static nam::factory::Helper _register_LSTM("LSTM", nam::lstm::Factory); +static nam::ConfigParserHelper _register_LSTM("LSTM", nam::lstm::create_config); } diff --git a/NAM/lstm.h b/NAM/lstm.h index d97de20..88b7527 100644 --- a/NAM/lstm.h +++ b/NAM/lstm.h @@ -95,13 +95,25 @@ class LSTM : public DSP Eigen::VectorXf _output; }; -/// \brief Factory function to instantiate LSTM from JSON +/// \brief Configuration for an LSTM model +struct LSTMConfig : public ModelConfig +{ + int num_layers; + int input_size; + int hidden_size; + int in_channels; + int out_channels; + + std::unique_ptr create(std::vector weights, double sampleRate) override; +}; + +/// \brief Parse LSTM configuration from JSON /// \param config JSON configuration object -/// \param weights Model weights vector -/// \param expectedSampleRate Expected sample rate in Hz (-1.0 if unknown) -/// \return Unique pointer to a DSP object (LSTM instance) -std::unique_ptr Factory(const nlohmann::json& config, std::vector& weights, - const double expectedSampleRate); +/// \return LSTMConfig +LSTMConfig parse_config_json(const nlohmann::json& config); + +/// \brief Config parser for ConfigParserRegistry +std::unique_ptr create_config(const nlohmann::json& config, double sampleRate); }; // namespace lstm }; // namespace nam diff --git a/NAM/model_config.h b/NAM/model_config.h new file mode 100644 index 0000000..32825a5 --- /dev/null +++ b/NAM/model_config.h @@ -0,0 +1,125 @@ +#pragma once +// Unified model configuration: abstract base class + config parser registry. +// No circular dependencies: forward-declares DSP (no #include "dsp.h"). + +#include +#include +#include +#include +#include +#include +#include + +#include "json.hpp" + +namespace nam +{ + +// Forward declaration — no #include "dsp.h" +class DSP; + +/// \brief Metadata common to all model formats +struct ModelMetadata +{ + std::string version; + double sample_rate = -1.0; + std::optional loudness; + std::optional input_level; + std::optional output_level; +}; + +/// \brief Abstract base class for architecture-specific configuration +/// +/// Each architecture defines a concrete config struct that inherits from this +/// and implements create() to construct the DSP object. +class ModelConfig +{ +public: + virtual ~ModelConfig() = default; + + /// \brief Construct a DSP object from this configuration + /// \param weights Model weights (taken by value to allow move for WaveNet) + /// \param sampleRate Expected sample rate in Hz + /// \return Unique pointer to a DSP object + virtual std::unique_ptr create(std::vector weights, double sampleRate) = 0; +}; + +/// \brief Function type for parsing a ModelConfig from JSON +using ConfigParserFunction = std::function(const nlohmann::json&, double)>; + +/// \brief Singleton registry mapping architecture names to config parser functions +/// +/// Both built-in and external architectures register here. There is one +/// construction path for all architectures. +class ConfigParserRegistry +{ +public: + static ConfigParserRegistry& instance() + { + static ConfigParserRegistry inst; + return inst; + } + + /// \brief Register a config parser for an architecture + /// \param name Architecture name (e.g., "WaveNet", "LSTM") + /// \param func Parser function that returns a unique_ptr + /// \throws std::runtime_error If the name is already registered + void registerParser(const std::string& name, ConfigParserFunction func) + { + if (parsers_.find(name) != parsers_.end()) + { + throw std::runtime_error("Config parser already registered for: " + name); + } + parsers_[name] = std::move(func); + } + + /// \brief Check whether an architecture name is registered + bool has(const std::string& name) const { return parsers_.find(name) != parsers_.end(); } + + /// \brief Parse a ModelConfig from an architecture name, JSON config, and sample rate + /// \throws std::runtime_error If no parser is registered for the given name + std::unique_ptr parse(const std::string& name, const nlohmann::json& config, double sampleRate) const + { + auto it = parsers_.find(name); + if (it == parsers_.end()) + { + throw std::runtime_error("No config parser registered for architecture: " + name); + } + return it->second(config, sampleRate); + } + +private: + std::unordered_map parsers_; +}; + +/// \brief Auto-registration helper for config parsers +/// +/// Create a static instance to register a config parser at program startup. +struct ConfigParserHelper +{ + ConfigParserHelper(const std::string& name, ConfigParserFunction func) + { + ConfigParserRegistry::instance().registerParser(name, std::move(func)); + } +}; + +/// \brief Construct a DSP object from a typed config, weights, and metadata +/// +/// This is the single construction path used by both JSON and binary loaders. +/// Handles construction, metadata application, and prewarm. +/// \param config Architecture-specific configuration (abstract base) +/// \param weights Model weights (taken by value to allow move for WaveNet) +/// \param metadata Model metadata (version, sample rate, loudness, levels) +/// \return Unique pointer to a DSP object +std::unique_ptr create_dsp(std::unique_ptr config, std::vector weights, + const ModelMetadata& metadata); + +/// \brief Parse a ModelConfig from a JSON architecture name and config block +/// \param architecture Architecture name string (e.g., "WaveNet", "LSTM") +/// \param config JSON config block for this architecture +/// \param sample_rate Expected sample rate from metadata +/// \return unique_ptr +std::unique_ptr parse_model_config_json(const std::string& architecture, const nlohmann::json& config, + double sample_rate); + +} // namespace nam diff --git a/NAM/registry.h b/NAM/registry.h index 0e90699..546772b 100644 --- a/NAM/registry.h +++ b/NAM/registry.h @@ -1,10 +1,14 @@ #pragma once -// Registry for DSP objects +// Registry for DSP objects — external code compatibility layer. +// +// The primary registration mechanism is ConfigParserRegistry in model_config.h. +// This header provides FactoryConfig (a ModelConfig wrapper around legacy factory +// functions) and factory::Helper (which registers into ConfigParserRegistry). #include #include -#include +#include #include #include "dsp.h" @@ -13,72 +17,50 @@ namespace nam { namespace factory { -/// \brief Factory function type for creating DSP objects +/// \brief Factory function type for creating DSP objects (legacy interface) using FactoryFunction = std::function(const nlohmann::json&, std::vector&, const double)>; -/// \brief Registry for factories that instantiate DSP objects +/// \brief ModelConfig wrapper around a legacy FactoryFunction /// -/// Singleton registry that maps architecture names to factory functions. -/// Allows dynamic registration of new DSP architectures. -class FactoryRegistry +/// Stores the factory function, JSON config, and sample rate so that +/// create() can delegate to the factory. This allows external code that +/// registers a FactoryFunction to work transparently with ConfigParserRegistry. +class FactoryConfig : public ModelConfig { public: - /// \brief Get the singleton instance - /// \return Reference to the factory registry instance - static FactoryRegistry& instance() + FactoryConfig(FactoryFunction factory, nlohmann::json config, double sampleRate) + : _factory(std::move(factory)) + , _config(std::move(config)) + , _sampleRate(sampleRate) { - static FactoryRegistry inst; - return inst; } - /// \brief Register a factory function for an architecture - /// \param key Architecture name (e.g., "WaveNet", "LSTM") - /// \param func Factory function that creates DSP instances - /// \throws std::runtime_error If the key is already registered - void registerFactory(const std::string& key, FactoryFunction func) + std::unique_ptr create(std::vector weights, double sampleRate) override { - // Assert that the key is not already registered - if (factories_.find(key) != factories_.end()) - { - throw std::runtime_error("Factory already registered for key: " + key); - } - factories_[key] = func; - } - - /// \brief Create a DSP object using a registered factory - /// \param name Architecture name - /// \param config JSON configuration object - /// \param weights Model weights vector - /// \param expectedSampleRate Expected sample rate in Hz - /// \return Unique pointer to a DSP object - /// \throws std::runtime_error If no factory is registered for the given name - std::unique_ptr create(const std::string& name, const nlohmann::json& config, std::vector& weights, - const double expectedSampleRate) const - { - auto it = factories_.find(name); - if (it != factories_.end()) - { - return it->second(config, weights, expectedSampleRate); - } - throw std::runtime_error("Factory not found for name: " + name); + (void)sampleRate; // Use stored sample rate from construction + return _factory(_config, weights, _sampleRate); } private: - std::unordered_map factories_; + FactoryFunction _factory; + nlohmann::json _config; + double _sampleRate; }; -/// \brief Registration helper for factories +/// \brief Registration helper for factories (external code compatibility) /// -/// Use this to register your factories. Create a static instance to automatically -/// register a factory when the program starts. +/// Wraps a FactoryFunction into a ConfigParserRegistry entry via FactoryConfig. +/// Use this to register external architectures. Create a static instance to +/// automatically register a factory when the program starts. struct Helper { - /// \brief Constructor that registers a factory - /// \param name Architecture name - /// \param factory Factory function Helper(const std::string& name, FactoryFunction factory) { - FactoryRegistry::instance().registerFactory(name, std::move(factory)); + // Capture factory by value in the lambda + ConfigParserRegistry::instance().registerParser( + name, [f = std::move(factory)](const nlohmann::json& config, double sampleRate) -> std::unique_ptr { + return std::make_unique(f, config, sampleRate); + }); } }; } // namespace factory diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp index 6eb74a3..3da0f72 100644 --- a/NAM/wavenet.cpp +++ b/NAM/wavenet.cpp @@ -568,41 +568,43 @@ void nam::wavenet::WaveNet::process(NAM_SAMPLE** input, NAM_SAMPLE** output, con } } -// Factory to instantiate from nlohmann json -std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, std::vector& weights, - const double expectedSampleRate) +// Config parser - extracts all configuration from JSON without constructing the DSP +nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json& config, + const double expectedSampleRate) { - std::unique_ptr condition_dsp = nullptr; - if (config.find("condition_dsp") != config.end() && !config["condition_dsp"].is_null()) + WaveNetConfig wc; + + // Condition DSP (eagerly built via get_dsp) + if ((config.find("condition_dsp") != config.end()) && !config["condition_dsp"].is_null()) { const nlohmann::json& condition_dsp_json = config["condition_dsp"]; - condition_dsp = nam::get_dsp(condition_dsp_json); - if (condition_dsp->GetExpectedSampleRate() != expectedSampleRate) + wc.condition_dsp = nam::get_dsp(condition_dsp_json); + if (wc.condition_dsp->GetExpectedSampleRate() != expectedSampleRate) { std::stringstream ss; - ss << "Condition DSP expected sample rate (" << condition_dsp->GetExpectedSampleRate() + ss << "Condition DSP expected sample rate (" << wc.condition_dsp->GetExpectedSampleRate() << ") doesn't match WaveNet expected sample rate (" << expectedSampleRate << "!\n"; throw std::runtime_error(ss.str().c_str()); } } - std::vector layer_array_params; + for (size_t i = 0; i < config["layers"].size(); i++) { nlohmann::json layer_config = config["layers"][i]; - const int groups = layer_config.value("groups_input", 1); // defaults to 1 - const int groups_input_mixin = layer_config.value("groups_input_mixin", 1); // defaults to 1 + const int groups = layer_config.value("groups_input", 1); + const int groups_input_mixin = layer_config.value("groups_input_mixin", 1); const int channels = layer_config["channels"]; - const int bottleneck = layer_config.value("bottleneck", channels); // defaults to channels if not present + const int bottleneck = layer_config.value("bottleneck", channels); // Parse layer1x1 parameters - bool layer1x1_active = true; // default to active if not present + bool layer1x1_active = true; int layer1x1_groups = 1; if (layer_config.find("layer1x1") != layer_config.end()) { const auto& layer1x1_config = layer_config["layer1x1"]; - layer1x1_active = layer1x1_config["active"]; // default to active + layer1x1_active = layer1x1_config["active"]; layer1x1_groups = layer1x1_config["groups"]; } nam::wavenet::Layer1x1Params layer1x1_params(layer1x1_active, layer1x1_groups); @@ -618,7 +620,6 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st std::vector activation_configs; if (layer_config["activation"].is_array()) { - // Array of activation configs for (const auto& activation_json : layer_config["activation"]) { activation_configs.push_back(activations::ActivationConfig::from_json(activation_json)); @@ -632,12 +633,12 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else { - // Single activation config - duplicate it for all layers const activations::ActivationConfig activation_config = activations::ActivationConfig::from_json(layer_config["activation"]); activation_configs.resize(num_layers, activation_config); } - // Parse gating mode(s) - support both single value and array, and old "gated" boolean + + // Parse gating mode(s) std::vector gating_modes; std::vector secondary_activation_configs; @@ -656,21 +657,18 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st { if (layer_config["gating_mode"].is_array()) { - // Array of gating modes for (const auto& gating_mode_json : layer_config["gating_mode"]) { std::string gating_mode_str = gating_mode_json.get(); GatingMode mode = parse_gating_mode_str(gating_mode_str); gating_modes.push_back(mode); - // Parse corresponding secondary activation if gating is enabled if (mode != GatingMode::NONE) { if (layer_config.find("secondary_activation") != layer_config.end()) { if (layer_config["secondary_activation"].is_array()) { - // Array of secondary activations - use corresponding index if (gating_modes.size() > layer_config["secondary_activation"].size()) { throw std::runtime_error("Layer array " + std::to_string(i) @@ -682,21 +680,18 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else { - // Single secondary activation - use for all gated layers secondary_activation_configs.push_back( activations::ActivationConfig::from_json(layer_config["secondary_activation"])); } } else { - // Default to Sigmoid for backward compatibility secondary_activation_configs.push_back( activations::ActivationConfig::simple(activations::ActivationType::Sigmoid)); } } else { - // NONE mode - use empty config secondary_activation_configs.push_back(activations::ActivationConfig{}); } } @@ -706,7 +701,6 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st + std::to_string(gating_modes.size()) + ") must match dilations size (" + std::to_string(num_layers) + ")"); } - // Validate secondary_activation array size if it's an array if (layer_config.find("secondary_activation") != layer_config.end() && layer_config["secondary_activation"].is_array()) { @@ -720,12 +714,10 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else { - // Single gating mode - duplicate for all layers std::string gating_mode_str = layer_config["gating_mode"].get(); GatingMode gating_mode = parse_gating_mode_str(gating_mode_str); gating_modes.resize(num_layers, gating_mode); - // Parse secondary activation activations::ActivationConfig secondary_activation_config; if (gating_mode != GatingMode::NONE) { @@ -736,7 +728,6 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else { - // Default to Sigmoid for backward compatibility secondary_activation_config = activations::ActivationConfig::simple(activations::ActivationType::Sigmoid); } } @@ -745,7 +736,6 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else if (layer_config.find("gated") != layer_config.end()) { - // Backward compatibility: convert old "gated" boolean to new enum bool gated = layer_config["gated"]; GatingMode gating_mode = gated ? GatingMode::GATED : GatingMode::NONE; gating_modes.resize(num_layers, gating_mode); @@ -763,7 +753,6 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st } else { - // Default to NONE for all layers gating_modes.resize(num_layers, GatingMode::NONE); secondary_activation_configs.resize(num_layers, activations::ActivationConfig{}); } @@ -792,11 +781,10 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st const nlohmann::json& film_config = layer_config[key]; bool active = film_config.value("active", true); bool shift = film_config.value("shift", true); - int groups = film_config.value("groups", 1); - return nam::wavenet::_FiLMParams(active, shift, groups); + int film_groups = film_config.value("groups", 1); + return nam::wavenet::_FiLMParams(active, shift, film_groups); }; - // Parse FiLM parameters nam::wavenet::_FiLMParams conv_pre_film_params = parse_film_params("conv_pre_film"); nam::wavenet::_FiLMParams conv_post_film_params = parse_film_params("conv_post_film"); nam::wavenet::_FiLMParams input_mixin_pre_film_params = parse_film_params("input_mixin_pre_film"); @@ -806,36 +794,48 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st nam::wavenet::_FiLMParams _layer1x1_post_film_params = parse_film_params("layer1x1_post_film"); nam::wavenet::_FiLMParams head1x1_post_film_params = parse_film_params("head1x1_post_film"); - // Validation: if layer1x1_post_film is active, layer1x1 must also be active if (_layer1x1_post_film_params.active && !layer1x1_active) { throw std::runtime_error("Layer array " + std::to_string(i) + ": layer1x1_post_film cannot be active when layer1x1.active is false"); } - layer_array_params.push_back(nam::wavenet::LayerArrayParams( + wc.layer_array_params.push_back(nam::wavenet::LayerArrayParams( input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, std::move(activation_configs), std::move(gating_modes), head_bias, groups, groups_input_mixin, layer1x1_params, head1x1_params, std::move(secondary_activation_configs), conv_pre_film_params, conv_post_film_params, input_mixin_pre_film_params, input_mixin_post_film_params, activation_pre_film_params, activation_post_film_params, _layer1x1_post_film_params, head1x1_post_film_params)); } - const bool with_head = config.find("head") != config.end() && !config["head"].is_null(); - const float head_scale = config["head_scale"]; - if (layer_array_params.empty()) + wc.with_head = config.find("head") != config.end() && !config["head"].is_null(); + wc.head_scale = config["head_scale"]; + wc.in_channels = config.value("in_channels", 1); + + if (wc.layer_array_params.empty()) throw std::runtime_error("WaveNet config requires at least one layer array"); - // Backward compatibility: assume 1 input channel - const int in_channels = config.value("in_channels", 1); + return wc; +} - // out_channels is determined from last layer array's head_size - return std::make_unique( - in_channels, layer_array_params, head_scale, with_head, weights, std::move(condition_dsp), expectedSampleRate); +// WaveNetConfig::create() +std::unique_ptr nam::wavenet::WaveNetConfig::create(std::vector weights, double sampleRate) +{ + return std::make_unique(in_channels, layer_array_params, head_scale, with_head, + std::move(weights), std::move(condition_dsp), sampleRate); +} + +// Config parser for ConfigParserRegistry +std::unique_ptr nam::wavenet::create_config(const nlohmann::json& config, double sampleRate) +{ + auto wc = std::make_unique(); + auto parsed = parse_config_json(config, sampleRate); + *wc = std::move(parsed); + return wc; } -// Register the factory +// Register the config parser namespace { -static nam::factory::Helper _register_WaveNet("WaveNet", nam::wavenet::Factory); +static nam::ConfigParserHelper _register_WaveNet("WaveNet", nam::wavenet::create_config); } diff --git a/NAM/wavenet.h b/NAM/wavenet.h index 63e1378..cf84bd2 100644 --- a/NAM/wavenet.h +++ b/NAM/wavenet.h @@ -713,12 +713,32 @@ class WaveNet : public DSP int PrewarmSamples() override { return mPrewarmSamples; }; }; -/// \brief Factory function to instantiate WaveNet from JSON configuration +/// \brief Configuration for a WaveNet model +struct WaveNetConfig : public ModelConfig +{ + int in_channels; + std::vector layer_array_params; + float head_scale; + bool with_head; + std::unique_ptr condition_dsp; + + // Move-only due to unique_ptr + WaveNetConfig() = default; + WaveNetConfig(WaveNetConfig&&) = default; + WaveNetConfig& operator=(WaveNetConfig&&) = default; + WaveNetConfig(const WaveNetConfig&) = delete; + WaveNetConfig& operator=(const WaveNetConfig&) = delete; + + std::unique_ptr create(std::vector weights, double sampleRate) override; +}; + +/// \brief Parse WaveNet configuration from JSON /// \param config JSON configuration object -/// \param weights Model weights vector /// \param expectedSampleRate Expected sample rate in Hz (-1.0 if unknown) -/// \return Unique pointer to a DSP object (WaveNet instance) -std::unique_ptr Factory(const nlohmann::json& config, std::vector& weights, - const double expectedSampleRate); +/// \return WaveNetConfig +WaveNetConfig parse_config_json(const nlohmann::json& config, const double expectedSampleRate); + +/// \brief Config parser for ConfigParserRegistry +std::unique_ptr create_config(const nlohmann::json& config, double sampleRate); }; // namespace wavenet }; // namespace nam