Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tmva/sofie/inc/TMVA/ROperator_Swish.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ namespace TMVA{
namespace Experimental{
namespace SOFIE{

template <typename T>
class ROperator_Swish final : public ROperator
{

Expand Down
21 changes: 21 additions & 0 deletions tmva/sofie/test/TestCustomModelsFromONNX.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ constexpr auto modelDataSuffix = "_FromONNX.dat";
#include "input_models/references/RangeFloat.ref.hxx"
#include "input_models/references/RangeInt.ref.hxx"
#include "input_models/references/Tile5D.ref.hxx"
#include "input_models/references/Swish.ref.hxx"

#include "gtest/gtest.h"

Expand Down Expand Up @@ -3279,3 +3280,23 @@ TEST(ONNX, Gelu)
EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
}
}

TEST(ONNX, Swish)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;

// Input spanning negative and positive values
std::vector<float> input{1.0, -2.0, 3.0, 0.5, -1.0, 2.0};

ASSERT_INCLUDE_AND_RUN(std::vector<float>, "Swish", input);

// Checking output size
EXPECT_EQ(output.size(), std::size(Swish_ExpectedOutput::outputs));

float *correct = Swish_ExpectedOutput::outputs;

// Checking every output value, one by one
for (size_t i = 0; i < output.size(); ++i) {
EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
}
}
Binary file added tmva/sofie/test/input_models/Swish.onnx
Binary file not shown.
3 changes: 3 additions & 0 deletions tmva/sofie/test/input_models/references/Swish.ref.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
namespace Swish_ExpectedOutput {
float outputs[] = {0.731059f, -0.238406f, 2.857723f, 0.311230f, -0.268941f, 1.761594f};
} // namespace Swish_ExpectedOutput
1 change: 1 addition & 0 deletions tmva/sofie_parsers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofieParser
src/ParseSelu.cxx
src/ParseShape.cxx
src/ParseSigmoid.cxx
src/ParseSwish.cxx
src/ParseSlice.cxx
src/ParseSoftmax.cxx
src/ParseTanh.cxx
Expand Down
48 changes: 48 additions & 0 deletions tmva/sofie_parsers/src/ParseSwish.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include "TMVA/RModelParser_ONNX.hxx"
#include "TMVA/ROperator_Swish.hxx"
#include "onnx_proto3.pb.h"

namespace TMVA {
namespace Experimental {
namespace SOFIE {

ParserFuncSignature ParseSwish = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
ETensorType input_type;

auto input_name = nodeproto.input(0);
if (parser.IsRegisteredTensorType(input_name)) {
input_type = parser.GetTensorType(input_name);
} else {
throw std::runtime_error("TMVA::SOFIE ONNX Parser Swish op has input tensor" + input_name +
" but its type is not yet registered");
}

std::unique_ptr<ROperator> op;

// alpha attribute for Swish
float attr_alpha = 1;

for (int_t i = 0; i < nodeproto.attribute_size(); i++) {
std::string attribute_name = nodeproto.attribute(i).name();
if (attribute_name == "alpha")
attr_alpha = nodeproto.attribute(i).f();
}
// ROperator_Swish implements alpha = 1 (x * sigmoid(x)); reject other values
if (attr_alpha != 1.0) {
throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Swish does not yet support alpha != 1");
}

std::string output_name = nodeproto.output(0);

op.reset(new ROperator_Swish(input_name, output_name));

if (!parser.IsRegisteredTensorType(output_name)) {
parser.RegisterTensorType(output_name, input_type);
}

return op;
};

} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA
2 changes: 2 additions & 0 deletions tmva/sofie_parsers/src/RModelParser_ONNX.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ extern ParserFuncSignature ParseLeakyRelu;
extern ParserFuncSignature ParseGelu;
extern ParserFuncSignature ParseSelu;
extern ParserFuncSignature ParseSigmoid;
extern ParserFuncSignature ParseSwish;
extern ParserFuncSignature ParseGemm;
extern ParserFuncSignature ParseRNN;
extern ParserFuncSignature ParseLSTM;
Expand Down Expand Up @@ -312,6 +313,7 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un
RegisterOperator("Selu", ParseSelu);
RegisterOperator("Shape", ParseShape);
RegisterOperator("Sigmoid", ParseSigmoid);
RegisterOperator("Swish", ParseSwish);
RegisterOperator("Slice", ParseSlice);
RegisterOperator("Softmax", ParseSoftmax);
RegisterOperator("LogSoftmax", ParseSoftmax);
Expand Down