diff --git a/tmva/sofie/inc/TMVA/ROperator_Swish.hxx b/tmva/sofie/inc/TMVA/ROperator_Swish.hxx index 21d1ac5310d59..2bb8b8df41a76 100644 --- a/tmva/sofie/inc/TMVA/ROperator_Swish.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_Swish.hxx @@ -11,7 +11,6 @@ namespace TMVA{ namespace Experimental{ namespace SOFIE{ -template class ROperator_Swish final : public ROperator { diff --git a/tmva/sofie/test/TestCustomModelsFromONNX.cxx b/tmva/sofie/test/TestCustomModelsFromONNX.cxx index 33eb135bcdbd7..44a8e63b77c73 100644 --- a/tmva/sofie/test/TestCustomModelsFromONNX.cxx +++ b/tmva/sofie/test/TestCustomModelsFromONNX.cxx @@ -106,6 +106,7 @@ constexpr auto modelDataSuffix = "_FromONNX.dat"; #include "input_models/references/RangeFloat.ref.hxx" #include "input_models/references/RangeInt.ref.hxx" #include "input_models/references/Tile5D.ref.hxx" +#include "input_models/references/Swish.ref.hxx" #include "gtest/gtest.h" @@ -3279,3 +3280,23 @@ TEST(ONNX, Gelu) EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE); } } + +TEST(ONNX, Swish) +{ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + + // Input spanning negative and positive values + std::vector input{1.0, -2.0, 3.0, 0.5, -1.0, 2.0}; + + ASSERT_INCLUDE_AND_RUN(std::vector, "Swish", input); + + // Checking output size + EXPECT_EQ(output.size(), std::size(Swish_ExpectedOutput::outputs)); + + float *correct = Swish_ExpectedOutput::outputs; + + // Checking every output value, one by one + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE); + } +} diff --git a/tmva/sofie/test/input_models/Swish.onnx b/tmva/sofie/test/input_models/Swish.onnx new file mode 100644 index 0000000000000..c2eff9dfdd845 Binary files /dev/null and b/tmva/sofie/test/input_models/Swish.onnx differ diff --git a/tmva/sofie/test/input_models/references/Swish.ref.hxx b/tmva/sofie/test/input_models/references/Swish.ref.hxx new file mode 100644 index 0000000000000..d7fc0427c4b49 --- /dev/null +++ b/tmva/sofie/test/input_models/references/Swish.ref.hxx @@ -0,0 +1,3 @@ +namespace Swish_ExpectedOutput { + float outputs[] = {0.731059f, -0.238406f, 2.857723f, 0.311230f, -0.268941f, 1.761594f}; +} // namespace Swish_ExpectedOutput diff --git a/tmva/sofie_parsers/CMakeLists.txt b/tmva/sofie_parsers/CMakeLists.txt index 4814b62c6ec51..2530b5b8e2ff4 100644 --- a/tmva/sofie_parsers/CMakeLists.txt +++ b/tmva/sofie_parsers/CMakeLists.txt @@ -47,6 +47,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofieParser src/ParseSelu.cxx src/ParseShape.cxx src/ParseSigmoid.cxx + src/ParseSwish.cxx src/ParseSlice.cxx src/ParseSoftmax.cxx src/ParseTanh.cxx diff --git a/tmva/sofie_parsers/src/ParseSwish.cxx b/tmva/sofie_parsers/src/ParseSwish.cxx new file mode 100644 index 0000000000000..366b410fd5f2c --- /dev/null +++ b/tmva/sofie_parsers/src/ParseSwish.cxx @@ -0,0 +1,48 @@ +#include "TMVA/RModelParser_ONNX.hxx" +#include "TMVA/ROperator_Swish.hxx" +#include "onnx_proto3.pb.h" + +namespace TMVA { +namespace Experimental { +namespace SOFIE { + +ParserFuncSignature ParseSwish = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { + ETensorType input_type; + + auto input_name = nodeproto.input(0); + if (parser.IsRegisteredTensorType(input_name)) { + input_type = parser.GetTensorType(input_name); + } else { + throw std::runtime_error("TMVA::SOFIE ONNX Parser Swish op has input tensor" + input_name + + " but its type is not yet registered"); + } + + std::unique_ptr op; + + // alpha attribute for Swish + float attr_alpha = 1; + + for (int_t i = 0; i < nodeproto.attribute_size(); i++) { + std::string attribute_name = nodeproto.attribute(i).name(); + if (attribute_name == "alpha") + attr_alpha = nodeproto.attribute(i).f(); + } + // ROperator_Swish implements alpha = 1 (x * sigmoid(x)); reject other values + if (attr_alpha != 1.0) { + throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Swish does not yet support alpha != 1"); + } + + std::string output_name = nodeproto.output(0); + + op.reset(new ROperator_Swish(input_name, output_name)); + + if (!parser.IsRegisteredTensorType(output_name)) { + parser.RegisterTensorType(output_name, input_type); + } + + return op; +}; + +} // namespace SOFIE +} // namespace Experimental +} // namespace TMVA diff --git a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx index aa196c510ad60..6502f06410564 100644 --- a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx +++ b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx @@ -67,6 +67,7 @@ extern ParserFuncSignature ParseLeakyRelu; extern ParserFuncSignature ParseGelu; extern ParserFuncSignature ParseSelu; extern ParserFuncSignature ParseSigmoid; +extern ParserFuncSignature ParseSwish; extern ParserFuncSignature ParseGemm; extern ParserFuncSignature ParseRNN; extern ParserFuncSignature ParseLSTM; @@ -312,6 +313,7 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un RegisterOperator("Selu", ParseSelu); RegisterOperator("Shape", ParseShape); RegisterOperator("Sigmoid", ParseSigmoid); + RegisterOperator("Swish", ParseSwish); RegisterOperator("Slice", ParseSlice); RegisterOperator("Softmax", ParseSoftmax); RegisterOperator("LogSoftmax", ParseSoftmax);