Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 97 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,106 @@ else()
endif()

set(NAM_DEPS_PATH "${CMAKE_CURRENT_SOURCE_DIR}/Dependencies")
include_directories(SYSTEM "${NAM_DEPS_PATH}/eigen")

set(NAM_CORE_SOURCES
NAM/activations.cpp
NAM/conv1d.cpp
NAM/convnet.cpp
NAM/dsp.cpp
NAM/get_dsp.cpp
NAM/lstm.cpp
NAM/ring_buffer.cpp
NAM/util.cpp
NAM/wavenet.cpp
NAM/nam_c_api.cpp
)

add_library(nam_static STATIC ${NAM_CORE_SOURCES})
add_library(nam_shared SHARED ${NAM_CORE_SOURCES})
add_library(nam::static ALIAS nam_static)
add_library(nam::shared ALIAS nam_shared)

set(NAM_LIBRARY_TARGETS nam_static nam_shared)
foreach(target ${NAM_LIBRARY_TARGETS})
target_compile_features(${target} PUBLIC cxx_std_20)
if (MSVC)
target_compile_options(${target} PRIVATE /UNAM_SAMPLE_FLOAT)
else()
target_compile_options(${target} PRIVATE -UNAM_SAMPLE_FLOAT)
endif()
target_include_directories(${target}
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
$<INSTALL_INTERFACE:include>
PRIVATE
${NAM_DEPS_PATH}/eigen
${NAM_DEPS_PATH}/nlohmann
)
set_target_properties(${target}
PROPERTIES
CXX_VISIBILITY_PRESET hidden
VISIBILITY_INLINES_HIDDEN YES
POSITION_INDEPENDENT_CODE ON
)
if (MSVC)
target_compile_options(${target} PRIVATE
"$<$<CONFIG:DEBUG>:/W4>"
"$<$<CONFIG:RELEASE>:/O2>"
)
else()
target_compile_options(${target} PRIVATE
-Wall -Wextra -Wpedantic -Wstrict-aliasing -Wunreachable-code -Weffc++
-Wno-unused-parameter
"$<$<CONFIG:DEBUG>:-Og;-ggdb;-Werror>"
"$<$<CONFIG:RELEASE>:-Ofast>"
)
endif()
endforeach()

target_compile_definitions(nam_shared PRIVATE NAM_BUILD_SHARED)
set_target_properties(nam_static PROPERTIES OUTPUT_NAME nam)
set_target_properties(nam_shared PROPERTIES
OUTPUT_NAME nam
VERSION ${PROJECT_VERSION}
SOVERSION ${PROJECT_VERSION_MAJOR}
)

if (CMAKE_SYSTEM_NAME STREQUAL "Windows")
set_target_properties(nam_shared PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS ON)
target_compile_definitions(nam_static PRIVATE NOMINMAX WIN32_LEAN_AND_MEAN)
target_compile_definitions(nam_shared PRIVATE NOMINMAX WIN32_LEAN_AND_MEAN)
endif()

# There's an error in eigen's GeneralBlockPanelKernel.h in some debug builds.
set_source_files_properties(NAM/dsp.cpp PROPERTIES COMPILE_FLAGS "-Wno-error")
set_source_files_properties(NAM/conv1d.cpp PROPERTIES COMPILE_FLAGS "-Wno-error")

install(TARGETS nam_static nam_shared EXPORT NAMTargets
ARCHIVE DESTINATION lib
LIBRARY DESTINATION lib
RUNTIME DESTINATION bin
COMPONENT nam
)
install(FILES NAM/nam_c_api.h DESTINATION include/NAM COMPONENT nam)
install(EXPORT NAMTargets
FILE NAMTargets.cmake
NAMESPACE nam::
DESTINATION lib/cmake/NAM
COMPONENT nam
)

add_subdirectory(tools)

add_custom_target(install_nam
COMMAND ${CMAKE_COMMAND} --install ${CMAKE_BINARY_DIR} --config $<CONFIG> --component nam
COMMENT "Install NAM libraries and C API"
)

add_custom_target(install_tools
COMMAND ${CMAKE_COMMAND} --install ${CMAKE_BINARY_DIR} --config $<CONFIG> --component tools
COMMENT "Install NAM tool binaries"
)

#file(MAKE_DIRECTORY build/tools)

#add_custom_target(copy_tools ALL
Expand Down
8 changes: 6 additions & 2 deletions NAM/convnet.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <algorithm> // std::max_element
#include <algorithm>
#include <cmath> // pow, tanh, expf
#include <mutex>
#include <filesystem>
#include <fstream>
#include <string>
Expand Down Expand Up @@ -340,7 +341,10 @@ std::unique_ptr<nam::DSP> nam::convnet::Factory(const nlohmann::json& config, st
in_channels, out_channels, channels, dilations, batchnorm, activation_config, weights, expectedSampleRate, groups);
}

namespace
void nam::convnet::RegisterFactory()
{
static nam::factory::Helper _register_ConvNet("ConvNet", nam::convnet::Factory);
static std::once_flag once;
std::call_once(once, []() {
nam::factory::FactoryRegistry::instance().registerFactory("ConvNet", nam::convnet::Factory);
});
}
3 changes: 3 additions & 0 deletions NAM/convnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,5 +173,8 @@ class ConvNet : public Buffer
std::unique_ptr<DSP> Factory(const nlohmann::json& config, std::vector<float>& weights,
const double expectedSampleRate);

/// \brief Register ConvNet factory in the global registry
void RegisterFactory();

}; // namespace convnet
}; // namespace nam
5 changes: 5 additions & 0 deletions NAM/get_dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ std::unique_ptr<DSP> get_dsp(dspData& conf)
{
verify_config_version(conf.version);

// Explicit registration avoids missing factories when NAM is linked as a static library.
nam::lstm::RegisterFactory();
nam::convnet::RegisterFactory();
nam::wavenet::RegisterFactory();

auto& architecture = conf.architecture;
nlohmann::json& config = conf.config;
std::vector<float>& weights = conf.weights;
Expand Down
9 changes: 6 additions & 3 deletions NAM/lstm.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <algorithm>
#include <mutex>
#include <string>
#include <vector>
#include <memory>
Expand Down Expand Up @@ -177,8 +178,10 @@ std::unique_ptr<nam::DSP> nam::lstm::Factory(const nlohmann::json& config, std::
in_channels, out_channels, num_layers, input_size, hidden_size, weights, expectedSampleRate);
}

// Register the factory
namespace
void nam::lstm::RegisterFactory()
{
static nam::factory::Helper _register_LSTM("LSTM", nam::lstm::Factory);
static std::once_flag once;
std::call_once(once, []() {
nam::factory::FactoryRegistry::instance().registerFactory("LSTM", nam::lstm::Factory);
});
}
3 changes: 3 additions & 0 deletions NAM/lstm.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,8 @@ class LSTM : public DSP
std::unique_ptr<DSP> Factory(const nlohmann::json& config, std::vector<float>& weights,
const double expectedSampleRate);

/// \brief Register LSTM factory in the global registry
void RegisterFactory();

}; // namespace lstm
}; // namespace nam
214 changes: 214 additions & 0 deletions NAM/nam_c_api.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
#include "nam_c_api.h"

#include <exception>
#include <filesystem>
#include <memory>
#include <string>
#include <utility>

#include "activations.h"
#include "get_dsp.h"

struct nam_model
{
std::unique_ptr<nam::DSP> dsp;
};

namespace
{
thread_local std::string g_last_error;

void set_error(std::string message)
{
g_last_error = std::move(message);
}

nam_status_t handle_exception()
{
try
{
throw;
}
catch (const std::exception& e)
{
set_error(e.what());
}
catch (...)
{
set_error("Unknown exception");
}
return NAM_STATUS_EXCEPTION;
}
} // namespace

extern "C"
{

nam_status_t nam_create_model_from_file(const char* model_path, nam_model_t** out_model)
{
if (model_path == nullptr || out_model == nullptr)
{
set_error("Invalid argument: model_path and out_model must be non-null.");
return NAM_STATUS_INVALID_ARGUMENT;
}

*out_model = nullptr;
try
{
auto model = std::make_unique<nam_model>();
model->dsp = nam::get_dsp(std::filesystem::path(model_path));
*out_model = model.release();
g_last_error.clear();
return NAM_STATUS_OK;
}
catch (...)
{
return handle_exception();
}
}

void nam_destroy_model(nam_model_t* model)
{
delete model;
}

nam_status_t nam_reset(nam_model_t* model, double sample_rate, int max_buffer_size)
{
if (model == nullptr || model->dsp == nullptr || max_buffer_size < 0)
{
set_error("Invalid argument: model must be valid and max_buffer_size must be >= 0.");
return NAM_STATUS_INVALID_ARGUMENT;
}

try
{
model->dsp->Reset(sample_rate, max_buffer_size);
g_last_error.clear();
return NAM_STATUS_OK;
}
catch (...)
{
return handle_exception();
}
}

nam_status_t nam_process(nam_model_t* model, nam_sample_t** input, nam_sample_t** output, int num_frames)
{
if (model == nullptr || model->dsp == nullptr || input == nullptr || output == nullptr || num_frames < 0)
{
set_error("Invalid argument: model/input/output must be valid and num_frames must be >= 0.");
return NAM_STATUS_INVALID_ARGUMENT;
}

try
{
model->dsp->process(input, output, num_frames);
g_last_error.clear();
return NAM_STATUS_OK;
}
catch (...)
{
return handle_exception();
}
}

int nam_num_input_channels(const nam_model_t* model)
{
if (model == nullptr || model->dsp == nullptr)
return 0;
return model->dsp->NumInputChannels();
}

int nam_num_output_channels(const nam_model_t* model)
{
if (model == nullptr || model->dsp == nullptr)
return 0;
return model->dsp->NumOutputChannels();
}

double nam_expected_sample_rate(const nam_model_t* model)
{
if (model == nullptr || model->dsp == nullptr)
return -1.0;
return model->dsp->GetExpectedSampleRate();
}

nam_status_t nam_enable_fast_tanh(void)
{
try
{
nam::activations::Activation::enable_fast_tanh();
g_last_error.clear();
return NAM_STATUS_OK;
}
catch (...)
{
return handle_exception();
}
}

nam_status_t nam_disable_fast_tanh(void)
{
try
{
nam::activations::Activation::disable_fast_tanh();
g_last_error.clear();
return NAM_STATUS_OK;
}
catch (...)
{
return handle_exception();
}
}

int nam_is_fast_tanh_enabled(void)
{
return nam::activations::Activation::using_fast_tanh ? 1 : 0;
}

nam_status_t nam_enable_lut(const char* function_name, float min, float max, int n_points)
{
if (function_name == nullptr || n_points <= 1 || !(min < max))
{
set_error("Invalid argument: function_name must be non-null, min < max, and n_points > 1.");
return NAM_STATUS_INVALID_ARGUMENT;
}

try
{
nam::activations::Activation::enable_lut(function_name, min, max, static_cast<std::size_t>(n_points));
g_last_error.clear();
return NAM_STATUS_OK;
}
catch (...)
{
return handle_exception();
}
}

nam_status_t nam_disable_lut(const char* function_name)
{
if (function_name == nullptr)
{
set_error("Invalid argument: function_name must be non-null.");
return NAM_STATUS_INVALID_ARGUMENT;
}

try
{
nam::activations::Activation::disable_lut(function_name);
g_last_error.clear();
return NAM_STATUS_OK;
}
catch (...)
{
return handle_exception();
}
}

const char* nam_get_last_error(void)
{
return g_last_error.c_str();
}

} // extern "C"
Loading