Skip to content

Commit 6f3eb54

Browse files
wsx-ucbchenzl25
andauthored
feat(temporal-join): Temporal join executor (risingwavelabs#8412)
Co-authored-by: Dylan Chen <[email protected]> Co-authored-by: Dylan <[email protected]>
1 parent 5294f43 commit 6f3eb54

File tree

2 files changed

+288
-0
lines changed

2 files changed

+288
-0
lines changed

src/stream/src/executor/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ mod sort_buffer;
8888
pub mod source;
8989
mod stream_reader;
9090
pub mod subtask;
91+
mod temporal_join;
9192
mod top_n;
9293
mod union;
9394
mod watermark;
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
// Copyright 2023 RisingWave Labs
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use std::alloc::Global;
16+
use std::sync::Arc;
17+
18+
use either::Either;
19+
use futures::stream::{self, PollNext};
20+
use futures::{StreamExt, TryStreamExt};
21+
use futures_async_stream::try_stream;
22+
use local_stats_alloc::{SharedStatsAlloc, StatsAlloc};
23+
use lru::DefaultHasher;
24+
use risingwave_common::array::{Op, StreamChunk};
25+
use risingwave_common::catalog::Schema;
26+
use risingwave_common::row::{OwnedRow, Row, RowExt};
27+
use risingwave_common::util::iter_util::ZipEqFast;
28+
use risingwave_hummock_sdk::{HummockEpoch, HummockReadEpoch};
29+
use risingwave_storage::table::batch_table::storage_table::StorageTable;
30+
use risingwave_storage::StateStore;
31+
32+
use super::{Barrier, Executor, Message, MessageStream, StreamExecutorError, StreamExecutorResult};
33+
use crate::cache::{new_with_hasher_in, ManagedLruCache};
34+
use crate::common::StreamChunkBuilder;
35+
use crate::executor::monitor::StreamingMetrics;
36+
use crate::executor::{ActorContextRef, BoxedExecutor, JoinType, JoinTypePrimitive, PkIndices};
37+
use crate::task::AtomicU64Ref;
38+
39+
pub struct TemporalJoinExecutor<S: StateStore, const T: JoinTypePrimitive> {
40+
// TODO: update metrics
41+
#[allow(dead_code)]
42+
ctx: ActorContextRef,
43+
left: BoxedExecutor,
44+
right: BoxedExecutor,
45+
right_table: TemporalSide<S>,
46+
left_join_keys: Vec<usize>,
47+
right_join_keys: Vec<usize>,
48+
null_safe: Vec<bool>,
49+
output_indices: Vec<usize>,
50+
pk_indices: PkIndices,
51+
schema: Schema,
52+
chunk_size: usize,
53+
identity: String,
54+
// TODO: update metrics
55+
#[allow(dead_code)]
56+
metrics: Arc<StreamingMetrics>,
57+
}
58+
59+
struct TemporalSide<S: StateStore> {
60+
source: StorageTable<S>,
61+
cache: ManagedLruCache<OwnedRow, Option<OwnedRow>, DefaultHasher, SharedStatsAlloc<Global>>,
62+
}
63+
64+
impl<S: StateStore> TemporalSide<S> {
65+
async fn lookup(
66+
&mut self,
67+
key: impl Row,
68+
epoch: HummockEpoch,
69+
) -> StreamExecutorResult<Option<OwnedRow>> {
70+
let key = key.into_owned_row();
71+
Ok(match self.cache.get(&key) {
72+
Some(res) => res.clone(),
73+
None => {
74+
let res = self
75+
.source
76+
.get_row(key.clone(), HummockReadEpoch::NoWait(epoch))
77+
.await?;
78+
self.cache.put(key, res.clone());
79+
res
80+
}
81+
})
82+
}
83+
84+
fn update(&mut self, payload: Vec<StreamChunk>, join_keys: &[usize], epoch: u64) {
85+
payload.iter().flat_map(|c| c.rows()).for_each(|(op, row)| {
86+
let key = row.project(join_keys).into_owned_row();
87+
if let Some(value) = self.cache.get_mut(&key) {
88+
match op {
89+
Op::Insert | Op::UpdateInsert => *value = Some(row.into_owned_row()),
90+
Op::Delete | Op::UpdateDelete => *value = None,
91+
};
92+
}
93+
});
94+
self.cache.update_epoch(epoch);
95+
}
96+
}
97+
98+
enum InternalMessage {
99+
Chunk(StreamChunk),
100+
Barrier(Vec<StreamChunk>, Barrier),
101+
}
102+
103+
#[try_stream(ok = StreamChunk, error = StreamExecutorError)]
104+
pub async fn chunks_until_barrier(stream: impl MessageStream, expected_barrier: Barrier) {
105+
#[for_await]
106+
for item in stream {
107+
match item? {
108+
Message::Watermark(_) => {
109+
todo!("https://github.com/risingwavelabs/risingwave/issues/6042")
110+
}
111+
Message::Chunk(c) => yield c,
112+
Message::Barrier(b) if b.epoch != expected_barrier.epoch => {
113+
return Err(StreamExecutorError::align_barrier(expected_barrier, b));
114+
}
115+
Message::Barrier(_) => return Ok(()),
116+
}
117+
}
118+
}
119+
120+
// Align the left and right inputs according to their barriers,
121+
// such that in the produced stream, an aligned interval starts with
122+
// any number of `InternalMessage::Chunk(left_chunk)` and followed by
123+
// `InternalMessage::Barrier(right_chunks, barrier)`.
124+
#[try_stream(ok = InternalMessage, error = StreamExecutorError)]
125+
async fn align_input(left: Box<dyn Executor>, right: Box<dyn Executor>) {
126+
let mut left = Box::pin(left.execute());
127+
let mut right = Box::pin(right.execute());
128+
// Keep producing intervals until stream exhaustion or errors.
129+
loop {
130+
let mut right_chunks = vec![];
131+
// Produce an aligned interval.
132+
'inner: loop {
133+
let mut combined = stream::select_with_strategy(
134+
left.by_ref().map(Either::Left),
135+
right.by_ref().map(Either::Right),
136+
|_: &mut ()| PollNext::Left,
137+
);
138+
match combined.next().await {
139+
Some(Either::Left(Ok(Message::Chunk(c)))) => yield InternalMessage::Chunk(c),
140+
Some(Either::Right(Ok(Message::Chunk(c)))) => right_chunks.push(c),
141+
Some(Either::Left(Ok(Message::Barrier(b)))) => {
142+
let mut remain = chunks_until_barrier(right.by_ref(), b.clone())
143+
.try_collect()
144+
.await?;
145+
right_chunks.append(&mut remain);
146+
yield InternalMessage::Barrier(right_chunks, b);
147+
break 'inner;
148+
}
149+
Some(Either::Right(Ok(Message::Barrier(b)))) => {
150+
#[for_await]
151+
for chunk in chunks_until_barrier(left.by_ref(), b.clone()) {
152+
yield InternalMessage::Chunk(chunk?);
153+
}
154+
yield InternalMessage::Barrier(right_chunks, b);
155+
break 'inner;
156+
}
157+
Some(Either::Left(Err(e)) | Either::Right(Err(e))) => return Err(e),
158+
Some(
159+
Either::Left(Ok(Message::Watermark(_)))
160+
| Either::Right(Ok(Message::Watermark(_))),
161+
) => todo!("https://github.com/risingwavelabs/risingwave/issues/6042"),
162+
None => return Ok(()),
163+
}
164+
}
165+
}
166+
}
167+
168+
impl<S: StateStore, const T: JoinTypePrimitive> TemporalJoinExecutor<S, T> {
169+
#[allow(dead_code)]
170+
#[allow(clippy::too_many_arguments)]
171+
pub fn new(
172+
ctx: ActorContextRef,
173+
left: BoxedExecutor,
174+
right: BoxedExecutor,
175+
table: StorageTable<S>,
176+
left_join_keys: Vec<usize>,
177+
right_join_keys: Vec<usize>,
178+
null_safe: Vec<bool>,
179+
pk_indices: PkIndices,
180+
output_indices: Vec<usize>,
181+
executor_id: u64,
182+
watermark_epoch: AtomicU64Ref,
183+
metrics: Arc<StreamingMetrics>,
184+
chunk_size: usize,
185+
) -> Self {
186+
let schema_fields = [left.schema().fields.clone(), right.schema().fields.clone()].concat();
187+
188+
let schema: Schema = output_indices
189+
.iter()
190+
.map(|&idx| schema_fields[idx].clone())
191+
.collect();
192+
193+
let alloc = StatsAlloc::new(Global).shared();
194+
195+
let cache = new_with_hasher_in(watermark_epoch, DefaultHasher::default(), alloc);
196+
197+
Self {
198+
ctx,
199+
left,
200+
right,
201+
right_table: TemporalSide {
202+
source: table,
203+
cache,
204+
},
205+
left_join_keys,
206+
right_join_keys,
207+
null_safe,
208+
output_indices,
209+
schema,
210+
chunk_size,
211+
pk_indices,
212+
identity: format!("TemporalJoinExecutor {:X}", executor_id),
213+
metrics,
214+
}
215+
}
216+
217+
#[try_stream(ok = Message, error = StreamExecutorError)]
218+
async fn into_stream(mut self) {
219+
let (left_map, right_map) = StreamChunkBuilder::get_i2o_mapping(
220+
self.output_indices.iter().cloned(),
221+
self.left.schema().len(),
222+
self.right.schema().len(),
223+
);
224+
225+
let mut prev_epoch = None;
226+
#[for_await]
227+
for msg in align_input(self.left, self.right) {
228+
match msg? {
229+
InternalMessage::Chunk(chunk) => {
230+
let mut builder = StreamChunkBuilder::new(
231+
self.chunk_size,
232+
&self.schema.data_types(),
233+
left_map.clone(),
234+
right_map.clone(),
235+
);
236+
let epoch = prev_epoch.expect("Chunk data should come after some barrier.");
237+
for (op, row) in chunk.rows() {
238+
let key = row.project(&self.left_join_keys);
239+
if key
240+
.iter()
241+
.zip_eq_fast(self.null_safe.iter())
242+
.any(|(datum, can_null)| datum.is_none() && !*can_null)
243+
{
244+
continue;
245+
}
246+
if let Some(right_row) = self.right_table.lookup(key, epoch).await? {
247+
if let Some(chunk) = builder.append_row(op, row, &right_row) {
248+
yield Message::Chunk(chunk);
249+
}
250+
} else if T == JoinType::LeftOuter {
251+
if let Some(chunk) = builder.append_row_update(op, row) {
252+
yield Message::Chunk(chunk);
253+
}
254+
}
255+
}
256+
if let Some(chunk) = builder.take() {
257+
yield Message::Chunk(chunk);
258+
}
259+
}
260+
InternalMessage::Barrier(updates, barrier) => {
261+
prev_epoch = Some(barrier.epoch.curr);
262+
self.right_table
263+
.update(updates, &self.right_join_keys, barrier.epoch.curr);
264+
yield Message::Barrier(barrier)
265+
}
266+
}
267+
}
268+
}
269+
}
270+
271+
impl<S: StateStore, const T: JoinTypePrimitive> Executor for TemporalJoinExecutor<S, T> {
272+
fn execute(self: Box<Self>) -> super::BoxedMessageStream {
273+
self.into_stream().boxed()
274+
}
275+
276+
fn schema(&self) -> &Schema {
277+
&self.schema
278+
}
279+
280+
fn pk_indices(&self) -> super::PkIndicesRef<'_> {
281+
&self.pk_indices
282+
}
283+
284+
fn identity(&self) -> &str {
285+
self.identity.as_str()
286+
}
287+
}

0 commit comments

Comments
 (0)