Skip to content

Commit ca339df

Browse files
authored
feat: handle watermark in hash join (#7379)
- align (take minimal) watermark from left and right - state cleaning Approved-By: soundOfDestiny Approved-By: st1page
1 parent b63df21 commit ca339df

File tree

4 files changed

+182
-33
lines changed

4 files changed

+182
-33
lines changed

src/stream/src/executor/hash_join.rs

Lines changed: 163 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
use std::collections::HashSet;
15+
use std::collections::{BTreeMap, HashSet};
1616
use std::sync::Arc;
1717
use std::time::Duration;
1818

@@ -21,6 +21,7 @@ use fixedbitset::FixedBitSet;
2121
use futures::{pin_mut, StreamExt};
2222
use futures_async_stream::try_stream;
2323
use itertools::Itertools;
24+
use multimap::MultiMap;
2425
use risingwave_common::array::{Op, RowRef, StreamChunk};
2526
use risingwave_common::catalog::Schema;
2627
use risingwave_common::hash::HashKey;
@@ -35,8 +36,10 @@ use super::barrier_align::*;
3536
use super::error::{StreamExecutorError, StreamExecutorResult};
3637
use super::managed_state::join::*;
3738
use super::monitor::StreamingMetrics;
39+
use super::watermark::*;
3840
use super::{
3941
ActorContextRef, BoxedExecutor, BoxedMessageStream, Executor, Message, PkIndices, PkIndicesRef,
42+
Watermark,
4043
};
4144
use crate::common::table::state_table::StateTable;
4245
use crate::common::{InfallibleExpression, StreamChunkBuilder};
@@ -163,6 +166,7 @@ struct JoinSide<K: HashKey, S: StateStore> {
163166
start_pos: usize,
164167
/// The mapping from input indices of a side to output columns.
165168
i2o_mapping: Vec<(usize, usize)>,
169+
i2o_mapping_indexed: MultiMap<usize, usize>,
166170
/// Whether degree table is needed for this side.
167171
need_degree_table: bool,
168172
}
@@ -175,6 +179,7 @@ impl<K: HashKey, S: StateStore> std::fmt::Debug for JoinSide<K, S> {
175179
.field("col_types", &self.all_data_types)
176180
.field("start_pos", &self.start_pos)
177181
.field("i2o_mapping", &self.i2o_mapping)
182+
.field("need_degree_table", &self.need_degree_table)
178183
.finish()
179184
}
180185
}
@@ -236,6 +241,9 @@ pub struct HashJoinExecutor<K: HashKey, S: StateStore, const T: JoinTypePrimitiv
236241
metrics: Arc<StreamingMetrics>,
237242
/// The maximum size of the chunk produced by executor at a time
238243
chunk_size: usize,
244+
245+
/// position of the watermark column within the join key -> `BufferedWatermarks`
246+
watermark_buffers: BTreeMap<usize, BufferedWatermarks<SideTypePrimitive>>,
239247
}
240248

241249
impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> std::fmt::Debug
@@ -524,6 +532,11 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
524532
StreamChunkBuilder::get_i2o_mapping(output_indices.iter().cloned(), left_len, right_len)
525533
};
526534

535+
let l2o_indexed = MultiMap::from_iter(left_to_output.iter().copied());
536+
let r2o_indexed = MultiMap::from_iter(right_to_output.iter().copied());
537+
538+
let watermark_buffers = BTreeMap::new();
539+
527540
Self {
528541
ctx: ctx.clone(),
529542
input_l: Some(input_l),
@@ -548,9 +561,10 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
548561
), // TODO: decide the target cap
549562
join_key_indices: join_key_indices_l,
550563
all_data_types: state_all_data_types_l,
564+
i2o_mapping: left_to_output,
565+
i2o_mapping_indexed: l2o_indexed,
551566
pk_indices: state_pk_indices_l,
552567
start_pos: 0,
553-
i2o_mapping: left_to_output,
554568
need_degree_table: need_degree_table_l,
555569
},
556570
side_r: JoinSide {
@@ -574,6 +588,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
574588
pk_indices: state_pk_indices_r,
575589
start_pos: side_l_column_n,
576590
i2o_mapping: right_to_output,
591+
i2o_mapping_indexed: r2o_indexed,
577592
need_degree_table: need_degree_table_r,
578593
},
579594
pk_indices,
@@ -583,6 +598,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
583598
append_only_optimize,
584599
metrics,
585600
chunk_size,
601+
watermark_buffers,
586602
}
587603
}
588604

@@ -617,8 +633,15 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
617633
.with_label_values(&[&actor_id_str])
618634
.inc_by(start_time.elapsed().as_nanos() as u64);
619635
match msg? {
620-
AlignedMessage::WatermarkLeft(_) | AlignedMessage::WatermarkRight(_) => {
621-
todo!("https://github.com/risingwavelabs/risingwave/issues/6042")
636+
AlignedMessage::WatermarkLeft(watermark) => {
637+
for watermark_to_emit in self.handle_watermark(SideType::Left, watermark)? {
638+
yield Message::Watermark(watermark_to_emit);
639+
}
640+
}
641+
AlignedMessage::WatermarkRight(watermark) => {
642+
for watermark_to_emit in self.handle_watermark(SideType::Right, watermark)? {
643+
yield Message::Watermark(watermark_to_emit);
644+
}
622645
}
623646
AlignedMessage::Left(chunk) => {
624647
let mut left_time = Duration::from_nanos(0);
@@ -636,13 +659,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
636659
self.chunk_size,
637660
) {
638661
left_time += left_start_time.elapsed();
639-
yield chunk.map(|v| match v {
640-
Message::Watermark(_) => {
641-
todo!("https://github.com/risingwavelabs/risingwave/issues/6042")
642-
}
643-
Message::Chunk(chunk) => Message::Chunk(chunk),
644-
barrier @ Message::Barrier(_) => barrier,
645-
})?;
662+
yield Message::Chunk(chunk?);
646663
left_start_time = minstant::Instant::now();
647664
}
648665
left_time += left_start_time.elapsed();
@@ -667,13 +684,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
667684
self.chunk_size,
668685
) {
669686
right_time += right_start_time.elapsed();
670-
yield chunk.map(|v| match v {
671-
Message::Watermark(_) => {
672-
todo!("https://github.com/risingwavelabs/risingwave/issues/6042")
673-
}
674-
Message::Chunk(chunk) => Message::Chunk(chunk),
675-
barrier @ Message::Barrier(_) => barrier,
676-
})?;
687+
yield Message::Chunk(chunk?);
677688
right_start_time = minstant::Instant::now();
678689
}
679690
right_time += right_start_time.elapsed();
@@ -740,6 +751,53 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
740751
Ok(())
741752
}
742753

754+
fn handle_watermark(
755+
&mut self,
756+
side: SideTypePrimitive,
757+
watermark: Watermark,
758+
) -> StreamExecutorResult<Vec<Watermark>> {
759+
let (side_update, side_match) = if side == SideType::Left {
760+
(&mut self.side_l, &mut self.side_r)
761+
} else {
762+
(&mut self.side_r, &mut self.side_l)
763+
};
764+
765+
// State cleaning
766+
if side_update.join_key_indices[0] == watermark.col_idx {
767+
side_match.ht.update_watermark(watermark.val.clone());
768+
}
769+
770+
// Select watermarks to yield.
771+
let wm_in_jk = side_update
772+
.join_key_indices
773+
.iter()
774+
.positions(|idx| *idx == watermark.col_idx);
775+
let mut watermarks_to_emit = vec![];
776+
for idx in wm_in_jk {
777+
let buffers = self.watermark_buffers.entry(idx).or_insert_with(|| {
778+
BufferedWatermarks::with_ids(vec![SideType::Left, SideType::Right])
779+
});
780+
if let Some(selected_watermark) = buffers.handle_watermark(side, watermark.clone()) {
781+
let empty_indices = vec![];
782+
let output_indices = side_update
783+
.i2o_mapping_indexed
784+
.get_vec(&side_update.join_key_indices[idx])
785+
.unwrap_or(&empty_indices)
786+
.iter()
787+
.chain(
788+
side_match
789+
.i2o_mapping_indexed
790+
.get_vec(&side_match.join_key_indices[idx])
791+
.unwrap_or(&empty_indices),
792+
);
793+
for output_idx in output_indices {
794+
watermarks_to_emit.push(selected_watermark.clone().with_idx(*output_idx));
795+
}
796+
};
797+
}
798+
Ok(watermarks_to_emit)
799+
}
800+
743801
/// Look up the data in the hash table and match the incoming
744802
/// data chunk with the executor state
745803
async fn hash_eq_match(
@@ -770,7 +828,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
770828
OwnedRow::new(new_row)
771829
}
772830

773-
#[try_stream(ok = Message, error = StreamExecutorError)]
831+
#[try_stream(ok = StreamChunk, error = StreamExecutorError)]
774832
#[expect(clippy::too_many_arguments)]
775833
async fn eq_join_oneside<'a, const SIDE: SideTypePrimitive>(
776834
ctx: &'a ActorContextRef,
@@ -839,7 +897,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
839897
if let Some(chunk) = hashjoin_chunk_builder
840898
.with_match_on_insert(&row, &matched_row)
841899
{
842-
yield Message::Chunk(chunk);
900+
yield chunk;
843901
}
844902
}
845903
if side_match.need_degree_table {
@@ -860,19 +918,19 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
860918
if let Some(chunk) =
861919
hashjoin_chunk_builder.forward_if_not_matched(op, row)
862920
{
863-
yield Message::Chunk(chunk);
921+
yield chunk;
864922
}
865923
} else if let Some(chunk) =
866924
hashjoin_chunk_builder.forward_exactly_once_if_matched(op, row)
867925
{
868-
yield Message::Chunk(chunk);
926+
yield chunk;
869927
}
870928
// Insert back the state taken from ht.
871929
side_match.ht.update_state(key, matched_rows);
872930
} else if let Some(chunk) =
873931
hashjoin_chunk_builder.forward_if_not_matched(op, row)
874932
{
875-
yield Message::Chunk(chunk);
933+
yield chunk;
876934
}
877935

878936
if append_only_optimize && let Some(row) = append_only_matched_row {
@@ -899,7 +957,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
899957
if let Some(chunk) = hashjoin_chunk_builder
900958
.with_match_on_delete(&row, &matched_row)
901959
{
902-
yield Message::Chunk(chunk);
960+
yield chunk;
903961
}
904962
}
905963
}
@@ -908,19 +966,19 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
908966
if let Some(chunk) =
909967
hashjoin_chunk_builder.forward_if_not_matched(op, row)
910968
{
911-
yield Message::Chunk(chunk);
969+
yield chunk;
912970
}
913971
} else if let Some(chunk) =
914972
hashjoin_chunk_builder.forward_exactly_once_if_matched(op, row)
915973
{
916-
yield Message::Chunk(chunk);
974+
yield chunk;
917975
}
918976
// Insert back the state taken from ht.
919977
side_match.ht.update_state(key, matched_rows);
920978
} else if let Some(chunk) =
921979
hashjoin_chunk_builder.forward_if_not_matched(op, row)
922980
{
923-
yield Message::Chunk(chunk);
981+
yield chunk;
924982
}
925983
if append_only_optimize {
926984
unreachable!();
@@ -933,7 +991,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
933991
}
934992
}
935993
if let Some(chunk) = hashjoin_chunk_builder.take() {
936-
yield Message::Chunk(chunk);
994+
yield chunk;
937995
}
938996
}
939997
}
@@ -946,6 +1004,7 @@ mod tests {
9461004
use risingwave_common::array::*;
9471005
use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema, TableId};
9481006
use risingwave_common::hash::{Key128, Key64};
1007+
use risingwave_common::types::ScalarImpl;
9491008
use risingwave_common::util::sort_util::OrderType;
9501009
use risingwave_expr::expr::expr_binary_nonnull::new_binary_expr;
9511010
use risingwave_expr::expr::InputRefExpression;
@@ -1048,11 +1107,13 @@ mod tests {
10481107
2,
10491108
)
10501109
.await;
1110+
10511111
let schema_len = match T {
10521112
JoinType::LeftSemi | JoinType::LeftAnti => source_l.schema().len(),
10531113
JoinType::RightSemi | JoinType::RightAnti => source_r.schema().len(),
10541114
_ => source_l.schema().len() + source_r.schema().len(),
10551115
};
1116+
10561117
let executor = HashJoinExecutor::<Key64, MemoryStateStore, T>::new(
10571118
ActorContext::create(123),
10581119
Box::new(source_l),
@@ -1125,6 +1186,7 @@ mod tests {
11251186
JoinType::RightSemi | JoinType::RightAnti => source_r.schema().len(),
11261187
_ => source_l.schema().len() + source_r.schema().len(),
11271188
};
1189+
11281190
let executor = HashJoinExecutor::<Key128, MemoryStateStore, T>::new(
11291191
ActorContext::create(123),
11301192
Box::new(source_l),
@@ -2923,4 +2985,77 @@ mod tests {
29232985

29242986
Ok(())
29252987
}
2988+
2989+
#[tokio::test]
2990+
async fn test_streaming_hash_join_watermark() -> StreamExecutorResult<()> {
2991+
let (mut tx_l, mut tx_r, mut hash_join) =
2992+
create_executor::<{ JoinType::Inner }>(true, false).await;
2993+
2994+
// push the init barrier for left and right
2995+
tx_l.push_barrier(1, false);
2996+
tx_r.push_barrier(1, false);
2997+
hash_join.next_unwrap_ready_barrier()?;
2998+
2999+
tx_l.push_int64_watermark(0, 100);
3000+
3001+
tx_l.push_int64_watermark(0, 200);
3002+
3003+
tx_l.push_barrier(2, false);
3004+
tx_r.push_barrier(2, false);
3005+
hash_join.next_unwrap_ready_barrier()?;
3006+
3007+
tx_r.push_int64_watermark(0, 50);
3008+
3009+
let w1 = hash_join.next().await.unwrap().unwrap();
3010+
let w1 = w1.as_watermark().unwrap();
3011+
3012+
let w2 = hash_join.next().await.unwrap().unwrap();
3013+
let w2 = w2.as_watermark().unwrap();
3014+
3015+
tx_r.push_int64_watermark(0, 100);
3016+
3017+
let w3 = hash_join.next().await.unwrap().unwrap();
3018+
let w3 = w3.as_watermark().unwrap();
3019+
3020+
let w4 = hash_join.next().await.unwrap().unwrap();
3021+
let w4 = w4.as_watermark().unwrap();
3022+
3023+
assert_eq!(
3024+
w1,
3025+
&Watermark {
3026+
col_idx: 2,
3027+
data_type: DataType::Int64,
3028+
val: ScalarImpl::Int64(50)
3029+
}
3030+
);
3031+
3032+
assert_eq!(
3033+
w2,
3034+
&Watermark {
3035+
col_idx: 0,
3036+
data_type: DataType::Int64,
3037+
val: ScalarImpl::Int64(50)
3038+
}
3039+
);
3040+
3041+
assert_eq!(
3042+
w3,
3043+
&Watermark {
3044+
col_idx: 2,
3045+
data_type: DataType::Int64,
3046+
val: ScalarImpl::Int64(100)
3047+
}
3048+
);
3049+
3050+
assert_eq!(
3051+
w4,
3052+
&Watermark {
3053+
col_idx: 0,
3054+
data_type: DataType::Int64,
3055+
val: ScalarImpl::Int64(100)
3056+
}
3057+
);
3058+
3059+
Ok(())
3060+
}
29263061
}

src/stream/src/executor/managed_state/join/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,12 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
310310
}
311311
}
312312

313+
pub fn update_watermark(&mut self, watermark: ScalarImpl) {
314+
// TODO: remove data in cache.
315+
self.state.table.update_watermark(watermark.clone());
316+
self.degree_state.table.update_watermark(watermark);
317+
}
318+
313319
/// Take the state for the given `key` out of the hash table and return it. One **MUST** call
314320
/// `update_state` after some operations to put the state back.
315321
///

0 commit comments

Comments
 (0)