Skip to content

Commit ae62c08

Browse files
authored
Add Reverse Function (#2449)
### What problem does this PR solve? _Add Reverse Function_ Issue link: #2033 [Issue](#2033) ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Test cases
1 parent 3c4d5bc commit ae62c08

File tree

7 files changed

+267
-6
lines changed

7 files changed

+267
-6
lines changed

python/test_pysdk/test_select.py

+29-6
Original file line numberDiff line numberDiff line change
@@ -1029,18 +1029,41 @@ def test_select_truncate(self, suffix):
10291029
db_obj.drop_table("test_select_truncate" + suffix, ConflictType.Ignore)
10301030
db_obj.create_table("test_select_truncate" + suffix,
10311031
{"c1": {"type": "double"},
1032-
"c2": {"type": "float"}}, ConflictType.Error)
1032+
"c2": {"type": "double"},
1033+
"c3": {"type": "float"}
1034+
}, ConflictType.Error)
10331035
table_obj = db_obj.get_table("test_select_truncate" + suffix)
10341036
table_obj.insert(
1035-
[{"c1": "2.123", "c2": "2.123"}, {"c1": "-2.123", "c2": "-2.123"}, {"c1": "2", "c2": "2"}, {"c1": "2.1", "c2":" 2.1"}])
1037+
[{"c1": "2.123", "c2": "2.123", "c3": "2.123"}, {"c1": "-2.123", "c2": "-2.123", "c3": "-2.123"}, {"c1": "2", "c2": "2", "c3": "2"}, {"c1": "2.1", "c2":" 2.1", "c3": "2.1"}])
10361038

1037-
res, extra_res = table_obj.output(["trunc(c1, 2)", "trunc(c2, 2)"]).to_df()
1039+
res, extra_res = table_obj.output(["trunc(c1, 14)", "trunc(c2, 2)", "trunc(c3, 2)"]).to_df()
10381040
print(res)
1039-
pd.testing.assert_frame_equal(res, pd.DataFrame({'(c1 trunc 2)': (" 2.12", " -2.12", " 2.00", " 2.10"),
1040-
'(c2 trunc 2)': (" 2.12", " -2.12", " 2.00", " 2.10")})
1041-
.astype({'(c1 trunc 2)': dtype('str_'), '(c2 trunc 2)': dtype('str_')}))
1041+
pd.testing.assert_frame_equal(res, pd.DataFrame({'(c1 trunc 14)': (" 2.12300000000000", " -2.12300000000000", " 2.00000000000000", " 2.10000000000000"),
1042+
'(c2 trunc 2)': (" 2.12", " -2.12", " 2.00", " 2.10"),
1043+
'(c3 trunc 2)': (" 2.12", " -2.12", " 2.00", " 2.10")
1044+
})
1045+
.astype({'(c1 trunc 14)': dtype('str_'), '(c2 trunc 2)': dtype('str_'), '(c3 trunc 2)': dtype('str_')}))
10421046

10431047

10441048
res = db_obj.drop_table("test_select_truncate" + suffix)
10451049
assert res.error_code == ErrorCode.OK
10461050

1051+
1052+
def test_select_reverse(self, suffix):
1053+
db_obj = self.infinity_obj.get_database("default_db")
1054+
db_obj.drop_table("test_select_reverse" + suffix, ConflictType.Ignore)
1055+
db_obj.create_table("test_select_reverse" + suffix,
1056+
{"c1": {"type": "varchar", "constraints": ["primary key", "not null"]},
1057+
"c2": {"type": "varchar", "constraints": ["not null"]}}, ConflictType.Error)
1058+
table_obj = db_obj.get_table("test_select_reverse" + suffix)
1059+
table_obj.insert(
1060+
[{"c1": 'abc', "c2": 'ABC'}, {"c1": 'a123', "c2": 'a123'}, {"c1": 'c', "c2": 'C'}, {"c1": 'abcdefghijklmn', "c2": 'ABCDEFGHIJKLMN'}])
1061+
1062+
res, extra_res = table_obj.output(["reverse(c1)", "reverse(c2)"]).to_df()
1063+
print(res)
1064+
pd.testing.assert_frame_equal(res, pd.DataFrame({'reverse(c1)': ('cba', '321a', 'c', 'nmlkjihgfedcba'),
1065+
'reverse(c2)': ('CBA', '321a', 'C', 'NMLKJIHGFEDCBA')})
1066+
.astype({'reverse(c1)': dtype('str_'), 'reverse(c2)': dtype('str_')}))
1067+
1068+
res = db_obj.drop_table("test_select_reverse" + suffix)
1069+
assert res.error_code == ErrorCode.OK

src/function/builtin_functions.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ import md5;
5959
import lower;
6060
import upper;
6161
import regex;
62+
import reverse;
6263
import ltrim;
6364
import rtrim;
6465
import trim;
@@ -142,6 +143,7 @@ void BuiltinFunctions::RegisterScalarFunction() {
142143
RegisterLowerFunction(catalog_ptr_);
143144
RegisterUpperFunction(catalog_ptr_);
144145
RegisterRegexFunction(catalog_ptr_);
146+
RegisterReverseFunction(catalog_ptr_);
145147
RegisterLtrimFunction(catalog_ptr_);
146148
RegisterRtrimFunction(catalog_ptr_);
147149
RegisterTrimFunction(catalog_ptr_);

src/function/scalar/reverse.cpp

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
module;
16+
module reverse;
17+
import stl;
18+
import catalog;
19+
import status;
20+
import logical_type;
21+
import infinity_exception;
22+
import scalar_function;
23+
import scalar_function_set;
24+
import third_party;
25+
import internal_types;
26+
import data_type;
27+
import column_vector;
28+
29+
namespace infinity {
30+
31+
struct ReverseFunction {
32+
template <typename TA, typename TB, typename TC, typename TD>
33+
static inline void Run(TA &left, TB &result, TC left_ptr, TD result_ptr) {
34+
Status status = Status::NotSupport("Not implemented");
35+
RecoverableError(status);
36+
}
37+
38+
};
39+
40+
template <>
41+
inline void ReverseFunction::Run(VarcharT &left, VarcharT &result, ColumnVector *left_ptr, ColumnVector *result_ptr) {
42+
Span<const char> left_v = left_ptr->GetVarcharInner(left);
43+
const char *input = left_v.data();
44+
SizeT input_len = left_v.size();
45+
String reversed_str(input, input_len);
46+
std::reverse(reversed_str.begin(), reversed_str.end());
47+
result_ptr->AppendVarcharInner(reversed_str, result);
48+
}
49+
50+
void RegisterReverseFunction(const UniquePtr<Catalog> &catalog_ptr) {
51+
String func_name = "reverse";
52+
53+
SharedPtr<ScalarFunctionSet> function_set_ptr = MakeShared<ScalarFunctionSet>(func_name);
54+
55+
ScalarFunction resverse_function(func_name,
56+
{DataType(LogicalType::kVarchar)},
57+
{DataType(LogicalType::kVarchar)},
58+
&ScalarFunction::UnaryFunctionVarlenToVarlen<VarcharT, VarcharT, ReverseFunction>);
59+
function_set_ptr->AddFunction(resverse_function);
60+
61+
62+
Catalog::AddFunctionSet(catalog_ptr.get(), function_set_ptr);
63+
}
64+
65+
} // namespace infinity

src/function/scalar/reverse.cppm

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module;
2+
3+
export module reverse;
4+
5+
import stl;
6+
7+
namespace infinity {
8+
9+
class Catalog;
10+
export void RegisterReverseFunction(const UniquePtr<Catalog> &catalog_ptr);
11+
12+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
16+
#include "gtest/gtest.h"
17+
#include <string_view>
18+
19+
import stl;
20+
import base_test;
21+
import infinity_exception;
22+
import infinity_context;
23+
24+
import catalog;
25+
import logger;
26+
27+
import default_values;
28+
import value;
29+
30+
import base_expression;
31+
import column_expression;
32+
import column_vector;
33+
import data_block;
34+
35+
import function_set;
36+
import function;
37+
38+
import global_resource_usage;
39+
40+
import data_type;
41+
import internal_types;
42+
import logical_type;
43+
44+
import scalar_function;
45+
import scalar_function_set;
46+
47+
import reverse;
48+
import third_party;
49+
50+
using namespace infinity;
51+
52+
class ReverseFunctionsTest : public BaseTestParamStr {};
53+
54+
INSTANTIATE_TEST_SUITE_P(TestWithDifferentParams, ReverseFunctionsTest, ::testing::Values(BaseTestParamStr::NULL_CONFIG_PATH));
55+
56+
TEST_P(ReverseFunctionsTest, reverse_func) {
57+
using namespace infinity;
58+
59+
UniquePtr<Catalog> catalog_ptr = MakeUnique<Catalog>();
60+
61+
RegisterReverseFunction(catalog_ptr);
62+
63+
String op = "reverse";
64+
65+
SharedPtr<FunctionSet> function_set = Catalog::GetFunctionSetByName(catalog_ptr.get(), op);
66+
EXPECT_EQ(function_set->type_, FunctionType::kScalar);
67+
SharedPtr<ScalarFunctionSet> scalar_function_set = std::static_pointer_cast<ScalarFunctionSet>(function_set);
68+
69+
{
70+
Vector<SharedPtr<BaseExpression>> inputs;
71+
72+
DataType data_type(LogicalType::kVarchar);
73+
SharedPtr<DataType> result_type = MakeShared<DataType>(LogicalType::kVarchar);
74+
SharedPtr<ColumnExpression> col_expr_ptr = MakeShared<ColumnExpression>(data_type, "t1", 1, "c1", 0, 0);
75+
76+
inputs.emplace_back(col_expr_ptr);
77+
78+
ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs);
79+
EXPECT_STREQ("reverse(Varchar)->Varchar", func.ToString().c_str());
80+
81+
Vector<SharedPtr<DataType>> column_types;
82+
column_types.emplace_back(MakeShared<DataType>(data_type));
83+
84+
SizeT row_count = DEFAULT_VECTOR_SIZE;
85+
86+
DataBlock data_block;
87+
data_block.Init(column_types);
88+
89+
for (SizeT i = 0; i < row_count; ++i) {
90+
data_block.AppendValue(0, Value::MakeVarchar(std::to_string(i)));
91+
}
92+
data_block.Finalize();
93+
94+
for (SizeT i = 0; i < row_count; ++i) {
95+
Value v1 = data_block.GetValue(0, i);
96+
EXPECT_EQ(v1.type_.type(), LogicalType::kVarchar);
97+
}
98+
99+
SharedPtr<ColumnVector> result = MakeShared<ColumnVector>(result_type);
100+
result->Initialize();
101+
func.function_(data_block, result);
102+
103+
for (SizeT i = 0; i < row_count; ++i) {
104+
Value v = result->GetValue(i);
105+
EXPECT_EQ(v.type_.type(), LogicalType::kVarchar);
106+
}
107+
}
108+
}

test/sql/dql/type/truncate.slt

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
statement ok
2+
DROP TABLE IF EXISTS test_truncate;
3+
4+
statement ok
5+
CREATE TABLE test_truncate (c1 integer, c2 double, c3 float);
6+
7+
# insert
8+
9+
statement ok
10+
INSERT INTO test_truncate VALUES (1, 2.4, 2.4), (2, -2.4, -2.4), (3, 2.5, 2.5), (4, -2.5, -2.5);
11+
12+
query I
13+
SELECT c2, trunc(c2, 12) FROM test_truncate;
14+
----
15+
2.400000 2.400000000000
16+
-2.400000 -2.400000000000
17+
2.500000 2.500000000000
18+
-2.500000 -2.500000000000
19+
20+
query II
21+
SELECT c3, trunc(c3, 2) FROM test_truncate;
22+
----
23+
2.400000 2.40
24+
-2.400000 -2.40
25+
2.500000 2.50
26+
-2.500000 -2.50
27+
28+
query III
29+
SELECT trunc(c2, 2) FROM test_truncate;
30+
----
31+
2.40
32+
-2.40
33+
2.50
34+
-2.50
35+
36+
statement ok
37+
DROP TABLE test_truncate;

test/sql/dql/type/varchar.slt

+14
Original file line numberDiff line numberDiff line change
@@ -143,5 +143,19 @@ abcddddc abcddddd 2 1
143143
abcdddde abcddddd 3 1
144144
abcdddde abcdddde 4 1
145145

146+
query XXI
147+
SELECT *, reverse(c1) FROM test_varchar_filter where c1 = 'abcddddc';
148+
----
149+
abcddddc abcddddd 2 cddddcba
150+
151+
statement ok
152+
INSERT INTO test_varchar_filter VALUES ('ABCDEFGHIJKLMN', 'ABCDEFGHIJKLMN', 10);
153+
154+
query XXII
155+
SELECT *, reverse(c1) FROM test_varchar_filter where char_length(c1) > 13;
156+
----
157+
158+
ABCDEFGHIJKLMN ABCDEFGHIJKLMN 10 NMLKJIHGFEDCBA
159+
146160
statement ok
147161
DROP TABLE test_varchar_filter;

0 commit comments

Comments
 (0)