Skip to content

feat: handle watermark in hash join #7379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 16, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 163 additions & 28 deletions src/stream/src/executor/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;
use std::collections::{BTreeMap, HashSet};
use std::sync::Arc;
use std::time::Duration;

Expand All @@ -21,6 +21,7 @@ use fixedbitset::FixedBitSet;
use futures::{pin_mut, StreamExt};
use futures_async_stream::try_stream;
use itertools::Itertools;
use multimap::MultiMap;
use risingwave_common::array::{Op, RowRef, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::hash::HashKey;
Expand All @@ -35,8 +36,10 @@ use super::barrier_align::*;
use super::error::{StreamExecutorError, StreamExecutorResult};
use super::managed_state::join::*;
use super::monitor::StreamingMetrics;
use super::watermark::*;
use super::{
ActorContextRef, BoxedExecutor, BoxedMessageStream, Executor, Message, PkIndices, PkIndicesRef,
Watermark,
};
use crate::common::table::state_table::StateTable;
use crate::common::{InfallibleExpression, StreamChunkBuilder};
Expand Down Expand Up @@ -163,6 +166,7 @@ struct JoinSide<K: HashKey, S: StateStore> {
start_pos: usize,
/// The mapping from input indices of a side to output columes.
i2o_mapping: Vec<(usize, usize)>,
i2o_mapping_indexed: MultiMap<usize, usize>,
/// Whether degree table is needed for this side.
need_degree_table: bool,
}
Expand All @@ -175,6 +179,7 @@ impl<K: HashKey, S: StateStore> std::fmt::Debug for JoinSide<K, S> {
.field("col_types", &self.all_data_types)
.field("start_pos", &self.start_pos)
.field("i2o_mapping", &self.i2o_mapping)
.field("need_degree_table", &self.need_degree_table)
.finish()
}
}
Expand Down Expand Up @@ -236,6 +241,9 @@ pub struct HashJoinExecutor<K: HashKey, S: StateStore, const T: JoinTypePrimitiv
metrics: Arc<StreamingMetrics>,
/// The maximum size of the chunk produced by executor at a time
chunk_size: usize,

/// watermark column index -> `BufferedWatermarks`
watermark_buffers: BTreeMap<usize, BufferedWatermarks<SideTypePrimitive>>,
}

impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> std::fmt::Debug
Expand Down Expand Up @@ -524,6 +532,11 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
StreamChunkBuilder::get_i2o_mapping(output_indices.iter().cloned(), left_len, right_len)
};

let l2o_indexed = MultiMap::from_iter(left_to_output.iter().copied());
let r2o_indexed = MultiMap::from_iter(right_to_output.iter().copied());

let watermark_buffers = BTreeMap::new();

Self {
ctx: ctx.clone(),
input_l: Some(input_l),
Expand All @@ -548,9 +561,10 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
), // TODO: decide the target cap
join_key_indices: join_key_indices_l,
all_data_types: state_all_data_types_l,
i2o_mapping: left_to_output,
i2o_mapping_indexed: l2o_indexed,
pk_indices: state_pk_indices_l,
start_pos: 0,
i2o_mapping: left_to_output,
need_degree_table: need_degree_table_l,
},
side_r: JoinSide {
Expand All @@ -574,6 +588,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
pk_indices: state_pk_indices_r,
start_pos: side_l_column_n,
i2o_mapping: right_to_output,
i2o_mapping_indexed: r2o_indexed,
need_degree_table: need_degree_table_r,
},
pk_indices,
Expand All @@ -583,6 +598,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
append_only_optimize,
metrics,
chunk_size,
watermark_buffers,
}
}

Expand Down Expand Up @@ -617,8 +633,15 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
.with_label_values(&[&actor_id_str])
.inc_by(start_time.elapsed().as_nanos() as u64);
match msg? {
AlignedMessage::WatermarkLeft(_) | AlignedMessage::WatermarkRight(_) => {
todo!("https://github.com/risingwavelabs/risingwave/issues/6042")
AlignedMessage::WatermarkLeft(watermark) => {
for watermark_to_emit in self.handle_watermark(SideType::Left, watermark)? {
yield Message::Watermark(watermark_to_emit);
}
}
AlignedMessage::WatermarkRight(watermark) => {
for watermark_to_emit in self.handle_watermark(SideType::Right, watermark)? {
yield Message::Watermark(watermark_to_emit);
}
}
AlignedMessage::Left(chunk) => {
let mut left_time = Duration::from_nanos(0);
Expand All @@ -636,13 +659,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
self.chunk_size,
) {
left_time += left_start_time.elapsed();
yield chunk.map(|v| match v {
Message::Watermark(_) => {
todo!("https://github.com/risingwavelabs/risingwave/issues/6042")
}
Message::Chunk(chunk) => Message::Chunk(chunk),
barrier @ Message::Barrier(_) => barrier,
})?;
yield Message::Chunk(chunk?);
left_start_time = minstant::Instant::now();
}
left_time += left_start_time.elapsed();
Expand All @@ -667,13 +684,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
self.chunk_size,
) {
right_time += right_start_time.elapsed();
yield chunk.map(|v| match v {
Message::Watermark(_) => {
todo!("https://github.com/risingwavelabs/risingwave/issues/6042")
}
Message::Chunk(chunk) => Message::Chunk(chunk),
barrier @ Message::Barrier(_) => barrier,
})?;
yield Message::Chunk(chunk?);
right_start_time = minstant::Instant::now();
}
right_time += right_start_time.elapsed();
Expand Down Expand Up @@ -740,6 +751,53 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
Ok(())
}

/// React to a watermark arriving on input `side`.
///
/// Two things happen here:
/// 1. **State cleaning** — when the watermark column is the first join-key
///    column of the receiving side, the watermark value is forwarded to the
///    opposite side's hash table (see `JoinHashMap::update_watermark`).
/// 2. **Watermark selection** — for every join-key position that carries the
///    watermark column, a per-position buffer combines the watermarks seen
///    from both inputs; whenever the combined watermark advances, it is
///    re-emitted once for each output column that join-key column maps to,
///    on both sides of the join.
///
/// Returns the (possibly empty) list of watermarks to emit downstream.
fn handle_watermark(
    &mut self,
    side: SideTypePrimitive,
    watermark: Watermark,
) -> StreamExecutorResult<Vec<Watermark>> {
    // The side the watermark arrived on vs. the side it joins against.
    let (recv_side, opposite_side) = if side == SideType::Left {
        (&mut self.side_l, &mut self.side_r)
    } else {
        (&mut self.side_r, &mut self.side_l)
    };

    // State cleaning is keyed off the first join-key column only.
    if watermark.col_idx == recv_side.join_key_indices[0] {
        opposite_side.ht.update_watermark(watermark.val.clone());
    }

    // Every position in the join key that carries the watermark column.
    let key_positions: Vec<usize> = recv_side
        .join_key_indices
        .iter()
        .positions(|idx| *idx == watermark.col_idx)
        .collect();

    let mut emitted = Vec::new();
    for pos in key_positions {
        // One buffer per join-key position, lazily created; it aligns the
        // watermarks arriving from the two inputs.
        let buffer = self.watermark_buffers.entry(pos).or_insert_with(|| {
            BufferedWatermarks::with_ids(vec![SideType::Left, SideType::Right])
        });
        let selected = match buffer.handle_watermark(side, watermark.clone()) {
            Some(wm) => wm,
            // Combined watermark did not advance; nothing to emit for this
            // position.
            None => continue,
        };
        let no_outputs: Vec<usize> = vec![];
        // The join-key column may be projected to several output columns,
        // possibly from both sides; emit the selected watermark for each.
        let output_columns = recv_side
            .i2o_mapping_indexed
            .get_vec(&recv_side.join_key_indices[pos])
            .unwrap_or(&no_outputs)
            .iter()
            .chain(
                opposite_side
                    .i2o_mapping_indexed
                    .get_vec(&opposite_side.join_key_indices[pos])
                    .unwrap_or(&no_outputs),
            );
        for &output_idx in output_columns {
            emitted.push(selected.clone().with_idx(output_idx));
        }
    }
    Ok(emitted)
}

/// the data the hash table and match the coming
/// data chunk with the executor state
async fn hash_eq_match(
Expand Down Expand Up @@ -770,7 +828,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
OwnedRow::new(new_row)
}

#[try_stream(ok = Message, error = StreamExecutorError)]
#[try_stream(ok = StreamChunk, error = StreamExecutorError)]
#[expect(clippy::too_many_arguments)]
async fn eq_join_oneside<'a, const SIDE: SideTypePrimitive>(
ctx: &'a ActorContextRef,
Expand Down Expand Up @@ -839,7 +897,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
if let Some(chunk) = hashjoin_chunk_builder
.with_match_on_insert(&row, &matched_row)
{
yield Message::Chunk(chunk);
yield chunk;
}
}
if side_match.need_degree_table {
Expand All @@ -860,19 +918,19 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
if let Some(chunk) =
hashjoin_chunk_builder.forward_if_not_matched(op, row)
{
yield Message::Chunk(chunk);
yield chunk;
}
} else if let Some(chunk) =
hashjoin_chunk_builder.forward_exactly_once_if_matched(op, row)
{
yield Message::Chunk(chunk);
yield chunk;
}
// Insert back the state taken from ht.
side_match.ht.update_state(key, matched_rows);
} else if let Some(chunk) =
hashjoin_chunk_builder.forward_if_not_matched(op, row)
{
yield Message::Chunk(chunk);
yield chunk;
}

if append_only_optimize && let Some(row) = append_only_matched_row {
Expand All @@ -899,7 +957,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
if let Some(chunk) = hashjoin_chunk_builder
.with_match_on_delete(&row, &matched_row)
{
yield Message::Chunk(chunk);
yield chunk;
}
}
}
Expand All @@ -908,19 +966,19 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
if let Some(chunk) =
hashjoin_chunk_builder.forward_if_not_matched(op, row)
{
yield Message::Chunk(chunk);
yield chunk;
}
} else if let Some(chunk) =
hashjoin_chunk_builder.forward_exactly_once_if_matched(op, row)
{
yield Message::Chunk(chunk);
yield chunk;
}
// Insert back the state taken from ht.
side_match.ht.update_state(key, matched_rows);
} else if let Some(chunk) =
hashjoin_chunk_builder.forward_if_not_matched(op, row)
{
yield Message::Chunk(chunk);
yield chunk;
}
if append_only_optimize {
unreachable!();
Expand All @@ -933,7 +991,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
}
}
if let Some(chunk) = hashjoin_chunk_builder.take() {
yield Message::Chunk(chunk);
yield chunk;
}
}
}
Expand All @@ -946,6 +1004,7 @@ mod tests {
use risingwave_common::array::*;
use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema, TableId};
use risingwave_common::hash::{Key128, Key64};
use risingwave_common::types::ScalarImpl;
use risingwave_common::util::sort_util::OrderType;
use risingwave_expr::expr::expr_binary_nonnull::new_binary_expr;
use risingwave_expr::expr::InputRefExpression;
Expand Down Expand Up @@ -1048,11 +1107,13 @@ mod tests {
2,
)
.await;

let schema_len = match T {
JoinType::LeftSemi | JoinType::LeftAnti => source_l.schema().len(),
JoinType::RightSemi | JoinType::RightAnti => source_r.schema().len(),
_ => source_l.schema().len() + source_r.schema().len(),
};

let executor = HashJoinExecutor::<Key64, MemoryStateStore, T>::new(
ActorContext::create(123),
Box::new(source_l),
Expand Down Expand Up @@ -1125,6 +1186,7 @@ mod tests {
JoinType::RightSemi | JoinType::RightAnti => source_r.schema().len(),
_ => source_l.schema().len() + source_r.schema().len(),
};

let executor = HashJoinExecutor::<Key128, MemoryStateStore, T>::new(
ActorContext::create(123),
Box::new(source_l),
Expand Down Expand Up @@ -2923,4 +2985,77 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_streaming_hash_join_watermark() -> StreamExecutorResult<()> {
let (mut tx_l, mut tx_r, mut hash_join) =
create_executor::<{ JoinType::Inner }>(true, false).await;

// push the init barrier for left and right
tx_l.push_barrier(1, false);
tx_r.push_barrier(1, false);
hash_join.next_unwrap_ready_barrier()?;

tx_l.push_int64_watermark(0, 100);

tx_l.push_int64_watermark(0, 200);

tx_l.push_barrier(2, false);
tx_r.push_barrier(2, false);
hash_join.next_unwrap_ready_barrier()?;

tx_r.push_int64_watermark(0, 50);

let w1 = hash_join.next().await.unwrap().unwrap();
let w1 = w1.as_watermark().unwrap();

let w2 = hash_join.next().await.unwrap().unwrap();
let w2 = w2.as_watermark().unwrap();

tx_r.push_int64_watermark(0, 100);

let w3 = hash_join.next().await.unwrap().unwrap();
let w3 = w3.as_watermark().unwrap();

let w4 = hash_join.next().await.unwrap().unwrap();
let w4 = w4.as_watermark().unwrap();

assert_eq!(
w1,
&Watermark {
col_idx: 2,
data_type: DataType::Int64,
val: ScalarImpl::Int64(50)
}
);

assert_eq!(
w2,
&Watermark {
col_idx: 0,
data_type: DataType::Int64,
val: ScalarImpl::Int64(50)
}
);

assert_eq!(
w3,
&Watermark {
col_idx: 2,
data_type: DataType::Int64,
val: ScalarImpl::Int64(100)
}
);

assert_eq!(
w4,
&Watermark {
col_idx: 0,
data_type: DataType::Int64,
val: ScalarImpl::Int64(100)
}
);

Ok(())
}
}
6 changes: 6 additions & 0 deletions src/stream/src/executor/managed_state/join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,12 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
}
}

/// Propagate a state-cleaning watermark to both underlying state tables:
/// first the row state, then the degree state.
// TODO: also evict entries below the watermark from the in-memory cache.
pub fn update_watermark(&mut self, watermark: ScalarImpl) {
    let degree_watermark = watermark.clone();
    self.state.table.update_watermark(watermark);
    self.degree_state.table.update_watermark(degree_watermark);
}

/// Take the state for the given `key` out of the hash table and return it. One **MUST** call
/// `update_state` after some operations to put the state back.
///
Expand Down
Loading