Commit 8c48501

refine based on comments
Signed-off-by: Runji Wang <[email protected]>
1 parent 3ca95d1 commit 8c48501

1 file changed, +11 -14 lines changed
rfcs/0000-user-defined-aggregate-functions.md

Lines changed: 11 additions & 14 deletions
@@ -26,23 +26,20 @@ from risingwave.udf import udaf
 # It will be serialized into bytes before sending to kernel,
 # and deserialized from bytes after receiving from kernel.
 class State:
-    sum: int
-    count: int
+    sum: int = 0
+    count: int = 0
 
 # The aggregate function is defined as a class.
 # Specify the schema of the aggregate function in the `udaf` decorator.
 @udaf(input_types=['BIGINT', 'INT'], result_type='BIGINT')
 class WeightedAvg:
     # Create an empty state.
     def create_state(self) -> State:
-        state = State()
-        state.sum = 0
-        state.count = 0
-        return state
+        return State()
 
     # Get the aggregate result.
     # The return value should match `result_type`.
-    def get_value(self, state: State) -> int:
+    def get_result(self, state: State) -> int:
         if state.count == 0:
             return None
         else:
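
The hunk above gives `State` default field values so that `create_state` can simply return `State()`. A minimal standalone sketch of that pattern follows. It makes several assumptions that are not part of the diff: `State` is written as a `dataclass`, the `@udaf` decorator is left out so the code runs without the proposed `risingwave.udf` module, the `accumulate` method is hypothetical, and the non-empty branch of `get_result` (truncated in the diff) is assumed to use integer division.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class State:
    # Default values make `State()` a valid empty state.
    sum: int = 0
    count: int = 0


class WeightedAvg:
    # The `@udaf(...)` decorator from the RFC is omitted here so the
    # sketch runs without the proposed `risingwave.udf` module.

    def create_state(self) -> State:
        # No field-by-field initialization needed thanks to the defaults.
        return State()

    def accumulate(self, state: State, value: int, weight: int) -> State:
        # Hypothetical update step; it does not appear in this diff.
        state.sum += value * weight
        state.count += weight
        return state

    def get_result(self, state: State) -> Optional[int]:
        if state.count == 0:
            return None
        # Assumption: integer division, to match a BIGINT result type.
        return state.sum // state.count


agg = WeightedAvg()
state = agg.create_state()
empty_result = agg.get_result(state)  # nothing accumulated yet
for value, weight in [(10, 1), (20, 3)]:
    state = agg.accumulate(state, value, weight)
result = agg.get_result(state)
```

With a dataclass, each `State()` call produces an independent instance, so two aggregation groups never share counters.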
@@ -106,7 +103,7 @@ public class WeightedAvg implements AggregateFunction {
     // Get the aggregate result.
     // The result type is inferred from the signature. (BIGINT)
     // If a Java type can not infer to an unique SQL type, it should be annotated with `@DataTypeHint`.
-    public Long getValue(WeightedAvgState acc) {
+    public Long getResult(WeightedAvgState acc) {
         if (acc.count == 0) {
             return null;
         } else {
@@ -152,7 +149,7 @@ The full syntax is:
 
 ```sql
 CREATE [ OR REPLACE ] AGGREGATE name ( [ argname ] arg_data_type [ , ... ] )
-[ RETURNS return_data_type ] [ APPEND ONLY ]
+[ RETURNS result_data_type ] [ APPEND ONLY ]
 [ LANGUAGE language ] [ AS identifier ] [ USING LINK link ];
 ```
 
@@ -206,9 +203,9 @@ sequenceDiagram
 
 The state of UDAF is managed by the compute node as a single encoded BYTEA value.
 
-Currently, each aggregate operator has a **result table** to store the aggregate result. For most of our built-in aggregate functions, they have the same output as their state, so the result table is actually being used as the state table. However, for general UDAFs, their state may not be the same as their output. Such functions are not supported for now.
+Currently, each aggregate operator has a **result table** to store the aggregate result. For most of our built-in aggregate functions, they have the same output as their state, so the result table is actually being used as the **intermediate state table**. However, for general UDAFs, their state may not be the same as their output. Such functions are not supported for now.
 
-Therefore, we propose to **transform the result table into state table**. The content of the table remains the same for existing functions. But for new functions whose state is different from output, only the state is stored. The output can be computed from the state when needed.
+Therefore, we propose to **transform the result table into intermediate state table**. The content of the table remains the same for existing functions. But for new functions whose state is different from output, only the state is stored. The output can be computed from the state when needed.
 
 For example, given the input:
 
@@ -218,14 +215,14 @@ For example, given the input:
 | 2 | U- | 0 | 1 | 2 | false |
 | 2 | U+ | 0 | 2 | 1 | true |
 
-The new **state table** (derived from old result table) of the agg operator would be like:
+The new **intermediate state table** (migrated from the old result table) of the agg operator would be like:
 
 | Epoch | id | sum(v0) | bool_and(v1) | weighted_avg(v0, w0) | max(v0)* |
 | ----- | ---- | ------- | ------------------- | -------------------- | ---------- |
 | 1 | 0 | sum = 1 | false = 1, true = 0 | encode(1,2) = b'XXX' | output = 1 |
 | 2 | 0 | sum = 2 | false = 0, true = 1 | encode(2,1) = b'YYY' | output = 2 |
 
-* For **append-only** aggregate functions (e.g. max, min, first, last, string_agg...), their states are all input values maintained in seperate "materialized input" tables. For backward compatibility, their values in the state table are still aggregate results.
+Note: for **append-only** aggregate functions (e.g. max, min, first, last, string_agg...), their states are all input values maintained in seperate "materialized input" tables. For backward compatibility, their values in the state table are still aggregate results.
 
 The output would be:
 
@@ -253,7 +250,7 @@ class WeightedAvg:
         self.sum = 0
         self.count = 0
 
-    def get_value(self) -> int:
+    def get_result(self) -> int:
         if self.count == 0:
             return None
         else:
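
The last hunk belongs to the variant in which the aggregate object holds `sum` and `count` itself, and the RFC text in this diff says the UDAF state is kept as a single encoded BYTEA value from which the output is computed when needed. The sketch below combines the two ideas under stated assumptions: the `accumulate`, `encode`, and `decode` methods are hypothetical, the two-little-endian-int64 layout via `struct` is only one possible encoding (the diff does not specify one), and the non-empty branch of `get_result` is assumed to use integer division.

```python
import struct
from typing import Optional


class WeightedAvg:
    def __init__(self) -> None:
        self.sum = 0
        self.count = 0

    def accumulate(self, value: int, weight: int) -> None:
        # Hypothetical update step; it does not appear in this diff.
        self.sum += value * weight
        self.count += weight

    def get_result(self) -> Optional[int]:
        if self.count == 0:
            return None
        # Assumption: integer division.
        return self.sum // self.count

    def encode(self) -> bytes:
        # Assumed layout: two little-endian signed 64-bit integers.
        return struct.pack("<qq", self.sum, self.count)

    @classmethod
    def decode(cls, data: bytes) -> "WeightedAvg":
        agg = cls()
        agg.sum, agg.count = struct.unpack("<qq", data)
        return agg


agg = WeightedAvg()
agg.accumulate(2, 1)
stored = agg.encode()                  # the only value a state table would keep
restored = WeightedAvg.decode(stored)
output = restored.get_result()         # output derived from the stored state
```

Only `stored` needs to be persisted; the result is recomputed from the decoded state rather than stored alongside it.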
