Skip to content

Commit 2a65db1

Browse files
Ami11111vsian
authored andcommitted
Support Regex function (infiniflow#2059)
### What problem does this PR solve? Support Regex function ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Test cases
1 parent 7c7fbc7 commit 2a65db1

File tree

7 files changed

+187
-1
lines changed

7 files changed

+187
-1
lines changed

example/functions.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
table_obj.insert(
2020
[{"c1": 'a', "c2": 'a'}, {"c1": 'b', "c2": 'b'}, {"c1": 'c', "c2": 'c'}, {"c1": 'd', "c2": 'd'},
2121
{"c1": 'abc', "c2": 'abc'}, {"c1": 'bbc', "c2": 'bbc'}, {"c1": 'cbc', "c2": 'cbc'}, {"c1": 'dbc', "c2": 'dbc'},
22-
{"c1": 'abcd', "c2": 'abc'}])
22+
{"c1": 'abcd', "c2": 'abc'},
23+
{"c1": '[email protected]', "c2": 'email'}, {"c1": '[email protected]', "c2": 'email'}])
2324

2425
#function char_length
2526
res = table_obj.output(["*"]).filter("char_length(c1) = 1").to_df()
@@ -34,6 +35,13 @@
3435
res = table_obj.output(["*"]).filter("char_length(c1) = char_length(c2)").to_df()
3536
print(res)
3637

38+
#functin regex
39+
res = table_obj.output(["*"]).filter("regex(c1, 'bc')").to_df()
40+
print(res)
41+
42+
res = table_obj.output(["*"]).filter("regex(c1, '(\w+([-+.]\w+)*)@(\w+([-.]\w+)*)\.(\w+([-.]\w+)*)')").to_df()
43+
print(res)
44+
3745
res = db_obj.drop_table("function_example")
3846

3947
infinity_obj.disconnect()

example/http/functions.sh

+69
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,30 @@ curl --request POST \
8787
"sparse_column": {"20":7.7, "80":7.8, "90": 97.9},
8888
"year": 2018,
8989
"tensor": [[5.0, 4.2, 4.3, 4.5], [4.0, 4.2, 4.3, 4.4]]
90+
},
91+
{
92+
"num": 5,
93+
"body": "[email protected]",
94+
"vec": [4.0, 4.2, 4.3, 4.5],
95+
"sparse_column": {"20":7.7, "80":7.8, "90": 97.9},
96+
"year": 2018,
97+
"tensor": [[5.0, 4.2, 4.3, 4.5], [4.0, 4.2, 4.3, 4.4]]
98+
},
99+
{
100+
"num": 6,
101+
"body": "test@hotmailcom",
102+
"vec": [4.0, 4.2, 4.3, 4.5],
103+
"sparse_column": {"20":7.7, "80":7.8, "90": 97.9},
104+
"year": 2018,
105+
"tensor": [[5.0, 4.2, 4.3, 4.5], [4.0, 4.2, 4.3, 4.4]]
106+
},
107+
{
108+
"num": 7,
109+
"body": "this is a sentence including a mail address, [email protected]",
110+
"vec": [4.0, 4.2, 4.3, 4.5],
111+
"sparse_column": {"20":7.7, "80":7.8, "90": 97.9},
112+
"year": 2018,
113+
"tensor": [[5.0, 4.2, 4.3, 4.5], [4.0, 4.2, 4.3, 4.4]]
90114
}
91115
] '
92116

@@ -134,6 +158,51 @@ curl --request GET \
134158
"filter": "char_length(body) > 4"
135159
} '
136160

161+
# show rows of 'tbl1' where body inluding 'com'
162+
echo -e '\n\n-- show rows of 'tbl1' where body inluding '\'com\'''
163+
curl --request GET \
164+
--url http://localhost:23820/databases/default_db/tables/tbl1/docs \
165+
--header 'accept: application/json' \
166+
--header 'content-type: application/json' \
167+
--data '
168+
{
169+
"output":
170+
[
171+
"body"
172+
],
173+
"filter": "regex(body, '\'com\'')"
174+
} '
175+
176+
# show rows of 'tbl1' where body including a mail address (using regex)
177+
echo -e '\n\n-- show rows of 'tbl1' where body including a mail address (using regex)'
178+
curl --request GET \
179+
--url http://localhost:23820/databases/default_db/tables/tbl1/docs \
180+
--header 'accept: application/json' \
181+
--header 'content-type: application/json' \
182+
--data '
183+
{
184+
"output":
185+
[
186+
"body"
187+
],
188+
"filter": "regex(body, '\''('.'*'')'@'('.'*'')''\\''.'com\'')"
189+
} '
190+
191+
# show rows of 'tbl1' where body including a mail address (using regex)
192+
echo -e '\n\n-- show rows of 'tbl1' where body including a mail address (using regex)'
193+
curl --request GET \
194+
--url http://localhost:23820/databases/default_db/tables/tbl1/docs \
195+
--header 'accept: application/json' \
196+
--header 'content-type: application/json' \
197+
--data '
198+
{
199+
"output":
200+
[
201+
"body"
202+
],
203+
"filter": "regex(body, '\''('[0-9A-Za-z_]+'('[-+.][0-9A-Za-z_]+')''*'')'@'('[0-9A-Za-z_]+'('[-.][0-9A-Za-z_]+')''*'')''\\'.'('[0-9A-Za-z_]+'('[-.][0-9A-Za-z_]+')''*'')'\'')"
204+
} '
205+
137206
# drop tbl1
138207
echo -e '\n\n-- drop tbl1'
139208
curl --request DELETE \

python/test_pysdk/test_select.py

+21
Original file line numberDiff line numberDiff line change
@@ -831,4 +831,25 @@ def test_select_varchar_length(self, suffix):
831831
.astype({'c1': dtype('O'), 'c2': dtype('O')}))
832832

833833
res = db_obj.drop_table("test_select_varchar_length"+suffix)
834+
assert res.error_code == ErrorCode.OK
835+
836+
def test_select_regex(self, suffix):
837+
db_obj = self.infinity_obj.get_database("default_db")
838+
db_obj.drop_table("test_select_regex"+suffix, ConflictType.Ignore)
839+
db_obj.create_table("test_select_regex"+suffix,
840+
{"c1": {"type": "varchar", "constraints": ["primary key", "not null"]},
841+
"c2": {"type": "varchar", "constraints": ["not null"]}}, ConflictType.Error)
842+
table_obj = db_obj.get_table("test_select_regex"+suffix)
843+
table_obj.insert(
844+
[{"c1": 'a', "c2": 'a'}, {"c1": 'b', "c2": 'b'}, {"c1": 'c', "c2": 'c'}, {"c1": 'd', "c2": 'd'},
845+
{"c1": 'abc', "c2": 'abc'}, {"c1": 'bbc', "c2": 'bbc'}, {"c1": 'cbc', "c2": 'cbc'}, {"c1": 'dbc', "c2": 'dbc'},])
846+
847+
res = table_obj.output(["*"]).filter("regex(c1, 'bc')").to_df()
848+
print(res)
849+
pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': ('abc', 'bbc', 'cbc', 'dbc'),
850+
'c2': ('abc', 'bbc', 'cbc', 'dbc')})
851+
.astype({'c1': dtype('O'), 'c2': dtype('O')}))
852+
853+
854+
res = db_obj.drop_table("test_select_regex"+suffix)
834855
assert res.error_code == ErrorCode.OK

src/function/builtin_functions.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ import substring;
4848
import substract;
4949
import char_length;
5050
import md5;
51+
import regex;
5152
import default_values;
5253
import special_function;
5354
import internal_types;
@@ -117,6 +118,7 @@ void BuiltinFunctions::RegisterScalarFunction() {
117118
RegisterSubstringFunction(catalog_ptr_);
118119
RegisterCharLengthFunction(catalog_ptr_);
119120
RegisterMd5Function(catalog_ptr_);
121+
RegisterRegexFunction(catalog_ptr_);
120122
}
121123

122124
void BuiltinFunctions::RegisterTableFunction() {}

src/function/scalar/regex.cpp

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
module;
2+
3+
#include <re2/re2.h>
4+
5+
module regex;
6+
7+
import stl;
8+
import catalog;
9+
import status;
10+
import infinity_exception;
11+
import scalar_function;
12+
import scalar_function_set;
13+
14+
import third_party;
15+
import logical_type;
16+
import internal_types;
17+
import data_type;
18+
import logger;
19+
import column_vector;
20+
21+
namespace infinity {
22+
23+
struct RegexFunction {
24+
template <typename TA, typename TB, typename TC>
25+
static inline void Run(TA &left, TB &right, TC &result) {
26+
const char * origin_str;
27+
SizeT origin_len;
28+
const char * pattern_str;
29+
SizeT pattern_len;
30+
GetReaderValue(left, origin_str, origin_len);
31+
GetReaderValue(right, pattern_str, pattern_len);
32+
String origin_str_(origin_str, origin_len);
33+
String pattern_str_(pattern_str, pattern_len);
34+
bool match = re2::RE2::PartialMatch(origin_str_, pattern_str_);
35+
result.SetValue(match);
36+
}
37+
};
38+
39+
40+
void RegisterRegexFunction(const UniquePtr<Catalog> &catalog_ptr){
41+
String func_name = "regex";
42+
43+
SharedPtr<ScalarFunctionSet> function_set_ptr = MakeShared<ScalarFunctionSet>(func_name);
44+
45+
ScalarFunction Regex_function(func_name,
46+
{DataType(LogicalType::kVarchar), DataType(LogicalType::kVarchar)},
47+
DataType(LogicalType::kBoolean),
48+
&ScalarFunction::BinaryFunction<VarcharT, VarcharT, BooleanT, RegexFunction>);
49+
function_set_ptr->AddFunction(Regex_function);
50+
51+
Catalog::AddFunctionSet(catalog_ptr.get(), function_set_ptr);
52+
}
53+
54+
}

src/function/scalar/regex.cppm

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module;
2+
3+
import stl;
4+
5+
export module regex;
6+
7+
namespace infinity {
8+
9+
class Catalog;
10+
11+
export void RegisterRegexFunction(const UniquePtr<Catalog> &catalog_ptr);
12+
13+
}

test/sql/dql/type/varchar.slt

+19
Original file line numberDiff line numberDiff line change
@@ -68,5 +68,24 @@ SELECT * FROM test_varchar_filter where md5(c1) = md5('abcdddde');
6868
abcdddde abcddddd 3
6969
abcdddde abcdddde 4
7070

71+
statement ok
72+
INSERT INTO test_varchar_filter VALUES ('[email protected]', '[email protected]', 6);
73+
74+
query X
75+
SELECT * FROM test_varchar_filter where regex(c1, 'abc\w+e');
76+
----
77+
abcdddde abcddddd 3
78+
abcdddde abcdddde 4
79+
80+
query XI
81+
SELECT * FROM test_varchar_filter where regex(c1, 'ddddc');
82+
----
83+
abcddddc abcddddd 2
84+
85+
query XII
86+
SELECT * FROM test_varchar_filter where regex(c1, '(\w+([-+.]\w+)*)@(\w+([-.]\w+)*)\.(\w+([-.]\w+)*)');
87+
----
88+
89+
7190
statement ok
7291
DROP TABLE test_varchar_filter;

0 commit comments

Comments
 (0)