diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index be57ce4f476..4f063d8f472 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -257,6 +257,12 @@ std::vector GetStringFunctionRegistry() { NativeFunction::kNeedsFunctionHolder | NativeFunction::kCanReturnErrors), + NativeFunction("regexp_extract", {}, DataTypeVector{utf8(), utf8()}, utf8(), + kResultNullIfNull, "gdv_fn_regexp_extract_utf8_utf8", + NativeFunction::kNeedsContext | + NativeFunction::kNeedsFunctionHolder | + NativeFunction::kCanReturnErrors), + NativeFunction("regexp_extract", {}, DataTypeVector{utf8(), utf8(), int32()}, utf8(), kResultNullIfNull, "gdv_fn_regexp_extract_utf8_utf8_int32", NativeFunction::kNeedsContext | diff --git a/cpp/src/gandiva/gdv_string_function_stubs.cc b/cpp/src/gandiva/gdv_string_function_stubs.cc index d271834fb47..7cfbecf7735 100644 --- a/cpp/src/gandiva/gdv_string_function_stubs.cc +++ b/cpp/src/gandiva/gdv_string_function_stubs.cc @@ -70,6 +70,16 @@ const char* gdv_fn_regexp_replace_utf8_utf8( out_length); } +const char* gdv_fn_regexp_extract_utf8_utf8(int64_t ptr, int64_t holder_ptr, + const char* data, int32_t data_len, + const char* /*pattern*/, + int32_t /*pattern_len*/, + int32_t* out_length) { + gandiva::ExecutionContext* context = reinterpret_cast(ptr); + gandiva::ExtractHolder* holder = reinterpret_cast(holder_ptr); + return (*holder)(context, data, data_len, 1, out_length); +} + const char* gdv_fn_regexp_extract_utf8_utf8_int32(int64_t ptr, int64_t holder_ptr, const char* data, int32_t data_len, const char* /*pattern*/, @@ -855,6 +865,19 @@ arrow::Status ExportedStringFunctions::AddMappings(Engine* engine) const { "gdv_fn_regexp_extract_utf8_utf8_int32", types->i8_ptr_type() /*return_type*/, args, reinterpret_cast(gdv_fn_regexp_extract_utf8_utf8_int32)); + // gdv_fn_regexp_extract_utf8_utf8 + args = {types->i64_type(), // int64_t ptr + types->i64_type(), // int64_t holder_ptr + types->i8_ptr_type(), // const char* data + types->i32_type(), // int data_len + types->i8_ptr_type(), // const char* pattern + types->i32_type(), // int pattern_len + types->i32_ptr_type()}; // int32_t* out_length + + engine->AddGlobalMappingForFunc( + "gdv_fn_regexp_extract_utf8_utf8", types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_regexp_extract_utf8_utf8)); + // gdv_fn_castVARCHAR_int32_int64 args = {types->i64_type(), // int64_t execution_context types->i32_type(), // int32_t value diff --git a/cpp/src/gandiva/regex_functions_holder.cc b/cpp/src/gandiva/regex_functions_holder.cc index 6c0c3d40f12..334c640833e 100644 --- a/cpp/src/gandiva/regex_functions_holder.cc +++ b/cpp/src/gandiva/regex_functions_holder.cc @@ -212,8 +212,8 @@ void ReplaceHolder::return_error(ExecutionContext* context, std::string& data, } Result> ExtractHolder::Make(const FunctionNode& node) { - ARROW_RETURN_IF(node.children().size() != 3, - Status::Invalid("'extract' function requires three parameters")); + ARROW_RETURN_IF(node.children().size() != 2 && node.children().size() != 3, + Status::Invalid("'extract' function requires two or three parameters")); auto literal = dynamic_cast(node.children().at(1).get()); ARROW_RETURN_IF( diff --git a/cpp/src/gandiva/regex_functions_holder_test.cc b/cpp/src/gandiva/regex_functions_holder_test.cc index 4d7b0fd3192..a78206bb847 100644 --- a/cpp/src/gandiva/regex_functions_holder_test.cc +++ b/cpp/src/gandiva/regex_functions_holder_test.cc @@ -604,24 +604,100 @@ TEST_F(TestExtractHolder, TestExtractInvalidPattern) { execution_context_.Reset(); } -TEST_F(TestExtractHolder, TestErrorWhileBuildingHolder) { - // Create function with incorrect number of params +TEST_F(TestExtractHolder, TestEmptyInput) { + EXPECT_OK_AND_ASSIGN(auto extract_holder, ExtractHolder::Make(R"((\w+))")); + auto& extract = *extract_holder; + int32_t out_length = 0; + + const char* ret = extract(&execution_context_, "", 0, 0, &out_length); + EXPECT_EQ(std::string(ret, out_length), ""); + EXPECT_FALSE(execution_context_.has_error()); +} + +TEST_F(TestExtractHolder, TestOptionalGroup) { + // (a)?(b): group 1 is optional; when input is "b" it doesn't participate + EXPECT_OK_AND_ASSIGN(auto extract_holder, ExtractHolder::Make(R"((a)?(b))")); + auto& extract = *extract_holder; + int32_t out_length = 0; + + std::string input = "b"; + const char* ret = extract(&execution_context_, input.c_str(), + static_cast(input.size()), 1, &out_length); + EXPECT_EQ(std::string(ret, out_length), ""); + EXPECT_FALSE(execution_context_.has_error()); + + ret = extract(&execution_context_, input.c_str(), static_cast(input.size()), 2, + &out_length); + EXPECT_EQ(std::string(ret, out_length), "b"); + + input = "ab"; + ret = extract(&execution_context_, input.c_str(), static_cast(input.size()), 1, + &out_length); + EXPECT_EQ(std::string(ret, out_length), "a"); +} + +TEST_F(TestExtractHolder, TestNoUserGroups) { + // Pattern with no user capturing groups — only the outer wrapper group exists. + // Index 0 returns the full match; index 1 is out of range. + EXPECT_OK_AND_ASSIGN(auto extract_holder, ExtractHolder::Make(R"(\d+)")); + auto& extract = *extract_holder; + int32_t out_length = 0; + + std::string input = "abc123def"; + const char* ret = extract(&execution_context_, input.c_str(), + static_cast(input.size()), 0, &out_length); + EXPECT_EQ(std::string(ret, out_length), "123"); + EXPECT_FALSE(execution_context_.has_error()); + + ret = extract(&execution_context_, input.c_str(), static_cast(input.size()), 1, + &out_length); + EXPECT_EQ(out_length, 0); + EXPECT_TRUE(execution_context_.has_error()); + execution_context_.Reset(); +} + +TEST_F(TestExtractHolder, TestDefaultIndexExtract) { + // 2-arg form defaults to index 1 (first capture group) auto field = std::make_shared(arrow::field("in", arrow::utf8())); auto pattern_node = std::make_shared( arrow::utf8(), LiteralHolder(R"((\w+) (\w+))"), false); auto function_node = FunctionNode("regexp_extract", {field, pattern_node}, arrow::utf8()); + EXPECT_OK_AND_ASSIGN(auto extract_holder, ExtractHolder::Make(function_node)); + + std::string input_string = "John Doe"; + int32_t out_length = 0; + + auto& extract = *extract_holder; + const char* ret = extract(&execution_context_, input_string.c_str(), + static_cast(input_string.length()), 1, &out_length); + EXPECT_EQ(std::string(ret, out_length), "John"); + + input_string = "Ringo Beast"; + ret = extract(&execution_context_, input_string.c_str(), + static_cast(input_string.length()), 1, &out_length); + EXPECT_EQ(std::string(ret, out_length), "Ringo"); +} + +TEST_F(TestExtractHolder, TestErrorWhileBuildingHolder) { + // Create function with incorrect number of params (one arg) + auto field = std::make_shared(arrow::field("in", arrow::utf8())); + NodeVector one_arg = {field}; + auto function_node = FunctionNode("regexp_extract", one_arg, arrow::utf8()); + auto extract_holder = ExtractHolder::Make(function_node); EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, ::testing::HasSubstr("'extract' function requires three parameters"), + Invalid, + ::testing::HasSubstr("'extract' function requires two or three parameters"), extract_holder.status()); execution_context_.Reset(); // Create function with non-utf8 literal parameter as pattern field = std::make_shared(arrow::field("in", arrow::utf8())); - pattern_node = std::make_shared(arrow::int32(), LiteralHolder(2), false); + auto pattern_node = + std::make_shared(arrow::int32(), LiteralHolder(2), false); auto index_node = std::make_shared(arrow::field("idx", arrow::int32())); function_node = FunctionNode("regexp_extract", {field, pattern_node, index_node}, arrow::utf8()); @@ -654,3 +730,60 @@ TEST_F(TestExtractHolder, TestErrorWhileBuildingHolder) { } } // namespace gandiva + +extern "C" const char* gdv_fn_regexp_extract_utf8_utf8(int64_t ptr, int64_t holder_ptr, + const char* data, int32_t data_len, + const char* pattern, + int32_t pattern_len, + int32_t* out_length); + +TEST(TestRegexpExtractStub, TestDefaultIndexStub) { + gandiva::ExecutionContext ctx; + auto ctx_ptr = reinterpret_cast(&ctx); + + EXPECT_OK_AND_ASSIGN(auto holder, gandiva::ExtractHolder::Make(R"((\w+) (\w+))")); + auto holder_ptr = reinterpret_cast(holder.get()); + + std::string pattern = R"((\w+) (\w+))"; + int32_t out_length = 0; + + std::string input = "John Doe"; + const char* ret = gdv_fn_regexp_extract_utf8_utf8( + ctx_ptr, holder_ptr, input.c_str(), static_cast(input.size()), + pattern.c_str(), static_cast(pattern.size()), &out_length); + EXPECT_EQ(std::string(ret, out_length), "John"); + + input = "Ringo Beast"; + ret = gdv_fn_regexp_extract_utf8_utf8( + ctx_ptr, holder_ptr, input.c_str(), static_cast(input.size()), + pattern.c_str(), static_cast(pattern.size()), &out_length); + EXPECT_EQ(std::string(ret, out_length), "Ringo"); + + // no match returns empty string + input = "--- ---"; + ret = gdv_fn_regexp_extract_utf8_utf8( + ctx_ptr, holder_ptr, input.c_str(), static_cast(input.size()), + pattern.c_str(), static_cast(pattern.size()), &out_length); + EXPECT_EQ(out_length, 0); +} + +extern "C" const char* gdv_fn_regexp_extract_utf8_utf8_int32( + int64_t ptr, int64_t holder_ptr, const char* data, int32_t data_len, + const char* pattern, int32_t pattern_len, int32_t extract_index, int32_t* out_length); + +TEST(TestRegexpExtractStub, TestIndexStub) { + gandiva::ExecutionContext ctx; + auto ctx_ptr = reinterpret_cast(&ctx); + + EXPECT_OK_AND_ASSIGN(auto holder, gandiva::ExtractHolder::Make(R"((\w+) (\w+))")); + auto holder_ptr = reinterpret_cast(holder.get()); + + std::string pattern = R"((\w+) (\w+))"; + int32_t out_length = 0; + + std::string input = "John Doe"; + const char* ret = gdv_fn_regexp_extract_utf8_utf8_int32( + ctx_ptr, holder_ptr, input.c_str(), static_cast(input.size()), + pattern.c_str(), static_cast(pattern.size()), 2, &out_length); + EXPECT_EQ(std::string(ret, out_length), "Doe"); +}