Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cpp/src/gandiva/function_registry_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,12 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
NativeFunction::kNeedsFunctionHolder |
NativeFunction::kCanReturnErrors),

NativeFunction("regexp_extract", {}, DataTypeVector{utf8(), utf8()}, utf8(),
kResultNullIfNull, "gdv_fn_regexp_extract_utf8_utf8",
NativeFunction::kNeedsContext |
NativeFunction::kNeedsFunctionHolder |
NativeFunction::kCanReturnErrors),

NativeFunction("regexp_extract", {}, DataTypeVector{utf8(), utf8(), int32()},
utf8(), kResultNullIfNull, "gdv_fn_regexp_extract_utf8_utf8_int32",
NativeFunction::kNeedsContext |
Expand Down
23 changes: 23 additions & 0 deletions cpp/src/gandiva/gdv_string_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ const char* gdv_fn_regexp_replace_utf8_utf8(
out_length);
}

const char* gdv_fn_regexp_extract_utf8_utf8(int64_t ptr, int64_t holder_ptr,
const char* data, int32_t data_len,
const char* /*pattern*/,
int32_t /*pattern_len*/,
int32_t* out_length) {
gandiva::ExecutionContext* context = reinterpret_cast<gandiva::ExecutionContext*>(ptr);
gandiva::ExtractHolder* holder = reinterpret_cast<gandiva::ExtractHolder*>(holder_ptr);
return (*holder)(context, data, data_len, 1, out_length);
}

const char* gdv_fn_regexp_extract_utf8_utf8_int32(int64_t ptr, int64_t holder_ptr,
const char* data, int32_t data_len,
const char* /*pattern*/,
Expand Down Expand Up @@ -855,6 +865,19 @@ arrow::Status ExportedStringFunctions::AddMappings(Engine* engine) const {
"gdv_fn_regexp_extract_utf8_utf8_int32", types->i8_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_regexp_extract_utf8_utf8_int32));

// gdv_fn_regexp_extract_utf8_utf8
args = {types->i64_type(), // int64_t ptr
types->i64_type(), // int64_t holder_ptr
types->i8_ptr_type(), // const char* data
types->i32_type(), // int data_len
types->i8_ptr_type(), // const char* pattern
types->i32_type(), // int pattern_len
types->i32_ptr_type()}; // int32_t* out_length

engine->AddGlobalMappingForFunc(
"gdv_fn_regexp_extract_utf8_utf8", types->i8_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_regexp_extract_utf8_utf8));

// gdv_fn_castVARCHAR_int32_int64
args = {types->i64_type(), // int64_t execution_context
types->i32_type(), // int32_t value
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/gandiva/regex_functions_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ void ReplaceHolder::return_error(ExecutionContext* context, std::string& data,
}

Result<std::shared_ptr<ExtractHolder>> ExtractHolder::Make(const FunctionNode& node) {
ARROW_RETURN_IF(node.children().size() != 3,
Status::Invalid("'extract' function requires three parameters"));
ARROW_RETURN_IF(node.children().size() != 2 && node.children().size() != 3,
Status::Invalid("'extract' function requires two or three parameters"));

auto literal = dynamic_cast<LiteralNode*>(node.children().at(1).get());
ARROW_RETURN_IF(
Expand Down
141 changes: 137 additions & 4 deletions cpp/src/gandiva/regex_functions_holder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -604,24 +604,100 @@ TEST_F(TestExtractHolder, TestExtractInvalidPattern) {
execution_context_.Reset();
}

TEST_F(TestExtractHolder, TestErrorWhileBuildingHolder) {
// Create function with incorrect number of params
TEST_F(TestExtractHolder, TestEmptyInput) {
EXPECT_OK_AND_ASSIGN(auto extract_holder, ExtractHolder::Make(R"((\w+))"));
auto& extract = *extract_holder;
int32_t out_length = 0;

const char* ret = extract(&execution_context_, "", 0, 0, &out_length);
EXPECT_EQ(std::string(ret, out_length), "");
EXPECT_FALSE(execution_context_.has_error());
}

TEST_F(TestExtractHolder, TestOptionalGroup) {
// (a)?(b): group 1 is optional; when input is "b" it doesn't participate
EXPECT_OK_AND_ASSIGN(auto extract_holder, ExtractHolder::Make(R"((a)?(b))"));
auto& extract = *extract_holder;
int32_t out_length = 0;

std::string input = "b";
const char* ret = extract(&execution_context_, input.c_str(),
static_cast<int32_t>(input.size()), 1, &out_length);
EXPECT_EQ(std::string(ret, out_length), "");
EXPECT_FALSE(execution_context_.has_error());

ret = extract(&execution_context_, input.c_str(), static_cast<int32_t>(input.size()), 2,
&out_length);
EXPECT_EQ(std::string(ret, out_length), "b");

input = "ab";
ret = extract(&execution_context_, input.c_str(), static_cast<int32_t>(input.size()), 1,
&out_length);
EXPECT_EQ(std::string(ret, out_length), "a");
}

TEST_F(TestExtractHolder, TestNoUserGroups) {
// Pattern with no user capturing groups — only the outer wrapper group exists.
// Index 0 returns the full match; index 1 is out of range.
EXPECT_OK_AND_ASSIGN(auto extract_holder, ExtractHolder::Make(R"(\d+)"));
auto& extract = *extract_holder;
int32_t out_length = 0;

std::string input = "abc123def";
const char* ret = extract(&execution_context_, input.c_str(),
static_cast<int32_t>(input.size()), 0, &out_length);
EXPECT_EQ(std::string(ret, out_length), "123");
EXPECT_FALSE(execution_context_.has_error());

ret = extract(&execution_context_, input.c_str(), static_cast<int32_t>(input.size()), 1,
&out_length);
EXPECT_EQ(out_length, 0);
EXPECT_TRUE(execution_context_.has_error());
execution_context_.Reset();
}

TEST_F(TestExtractHolder, TestDefaultIndexExtract) {
// 2-arg form defaults to index 1 (first capture group)
auto field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8()));
auto pattern_node = std::make_shared<LiteralNode>(
arrow::utf8(), LiteralHolder(R"((\w+) (\w+))"), false);
auto function_node =
FunctionNode("regexp_extract", {field, pattern_node}, arrow::utf8());

EXPECT_OK_AND_ASSIGN(auto extract_holder, ExtractHolder::Make(function_node));

std::string input_string = "John Doe";
int32_t out_length = 0;

auto& extract = *extract_holder;
const char* ret = extract(&execution_context_, input_string.c_str(),
static_cast<int32_t>(input_string.length()), 1, &out_length);
EXPECT_EQ(std::string(ret, out_length), "John");

input_string = "Ringo Beast";
ret = extract(&execution_context_, input_string.c_str(),
static_cast<int32_t>(input_string.length()), 1, &out_length);
Comment thread
lriggs marked this conversation as resolved.
EXPECT_EQ(std::string(ret, out_length), "Ringo");
}

TEST_F(TestExtractHolder, TestErrorWhileBuildingHolder) {
// Create function with incorrect number of params (one arg)
auto field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8()));
NodeVector one_arg = {field};
auto function_node = FunctionNode("regexp_extract", one_arg, arrow::utf8());

auto extract_holder = ExtractHolder::Make(function_node);
EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid, ::testing::HasSubstr("'extract' function requires three parameters"),
Invalid,
::testing::HasSubstr("'extract' function requires two or three parameters"),
extract_holder.status());

execution_context_.Reset();

// Create function with non-utf8 literal parameter as pattern
field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8()));
pattern_node = std::make_shared<LiteralNode>(arrow::int32(), LiteralHolder(2), false);
auto pattern_node =
std::make_shared<LiteralNode>(arrow::int32(), LiteralHolder(2), false);
auto index_node = std::make_shared<FieldNode>(arrow::field("idx", arrow::int32()));
function_node =
FunctionNode("regexp_extract", {field, pattern_node, index_node}, arrow::utf8());
Expand Down Expand Up @@ -654,3 +730,60 @@ TEST_F(TestExtractHolder, TestErrorWhileBuildingHolder) {
}

} // namespace gandiva

extern "C" const char* gdv_fn_regexp_extract_utf8_utf8(int64_t ptr, int64_t holder_ptr,
const char* data, int32_t data_len,
const char* pattern,
int32_t pattern_len,
int32_t* out_length);

TEST(TestRegexpExtractStub, TestDefaultIndexStub) {
gandiva::ExecutionContext ctx;
auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);

EXPECT_OK_AND_ASSIGN(auto holder, gandiva::ExtractHolder::Make(R"((\w+) (\w+))"));
auto holder_ptr = reinterpret_cast<int64_t>(holder.get());

std::string pattern = R"((\w+) (\w+))";
int32_t out_length = 0;

std::string input = "John Doe";
const char* ret = gdv_fn_regexp_extract_utf8_utf8(
ctx_ptr, holder_ptr, input.c_str(), static_cast<int32_t>(input.size()),
pattern.c_str(), static_cast<int32_t>(pattern.size()), &out_length);
EXPECT_EQ(std::string(ret, out_length), "John");

input = "Ringo Beast";
ret = gdv_fn_regexp_extract_utf8_utf8(
ctx_ptr, holder_ptr, input.c_str(), static_cast<int32_t>(input.size()),
pattern.c_str(), static_cast<int32_t>(pattern.size()), &out_length);
EXPECT_EQ(std::string(ret, out_length), "Ringo");

// no match returns empty string
input = "--- ---";
ret = gdv_fn_regexp_extract_utf8_utf8(
ctx_ptr, holder_ptr, input.c_str(), static_cast<int32_t>(input.size()),
pattern.c_str(), static_cast<int32_t>(pattern.size()), &out_length);
EXPECT_EQ(out_length, 0);
}

extern "C" const char* gdv_fn_regexp_extract_utf8_utf8_int32(
int64_t ptr, int64_t holder_ptr, const char* data, int32_t data_len,
const char* pattern, int32_t pattern_len, int32_t extract_index, int32_t* out_length);

TEST(TestRegexpExtractStub, TestIndexStub) {
gandiva::ExecutionContext ctx;
auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);

EXPECT_OK_AND_ASSIGN(auto holder, gandiva::ExtractHolder::Make(R"((\w+) (\w+))"));
auto holder_ptr = reinterpret_cast<int64_t>(holder.get());

std::string pattern = R"((\w+) (\w+))";
int32_t out_length = 0;

std::string input = "John Doe";
const char* ret = gdv_fn_regexp_extract_utf8_utf8_int32(
ctx_ptr, holder_ptr, input.c_str(), static_cast<int32_t>(input.size()),
pattern.c_str(), static_cast<int32_t>(pattern.size()), 2, &out_length);
EXPECT_EQ(std::string(ret, out_length), "Doe");
}
Loading