Skip to content

Commit 632423a

Browse files
authored
fix(optimizer): fix hash join distribution (risingwavelabs#8598)
1 parent 582307d commit 632423a

File tree

4 files changed

+37
-2
lines changed

4 files changed

+37
-2
lines changed

src/frontend/planner_test/tests/testdata/join.yaml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,3 +677,27 @@
677677
Table 6 { columns: [t_src, t_dst, t__row_id], primary key: [$0 ASC, $2 ASC], value indices: [0, 1, 2], distribution key: [0], read pk prefix len hint: 1 }
678678
Table 7 { columns: [t_src, t__row_id, _degree], primary key: [$0 ASC, $1 ASC], value indices: [2], distribution key: [0], read pk prefix len hint: 1 }
679679
Table 4294967294 { columns: [p1, p2, p3, t._row_id, t._row_id#1, t.src, t._row_id#2], primary key: [$3 ASC, $4 ASC, $1 ASC, $6 ASC, $5 ASC, $0 ASC], value indices: [0, 1, 2, 3, 4, 5, 6], distribution key: [0], read pk prefix len hint: 6 }
680+
- name: Fix hash join distribution key (https://github.com/risingwavelabs/risingwave/issues/8537)
681+
sql: |
682+
CREATE TABLE part (
683+
p INTEGER,
684+
c VARCHAR,
685+
PRIMARY KEY (p)
686+
);
687+
CREATE TABLE B (
688+
b INTEGER,
689+
d VARCHAR,
690+
PRIMARY KEY (b)
691+
);
692+
select B.* from part join B on part.c = B.d join part p1 on p1.p = part.p and p1.p = B.b;
693+
stream_plan: |
694+
StreamMaterialize { columns: [b, d, part.p(hidden), part.c(hidden), part.p#1(hidden)], pk_columns: [part.p, b, part.c, part.p#1], pk_conflict: "no check" }
695+
└─StreamHashJoin { type: Inner, predicate: part.p = part.p AND b.b = part.p, output: [b.b, b.d, part.p, part.c, part.p] }
696+
├─StreamExchange { dist: HashShard(part.p, b.b) }
697+
| └─StreamHashJoin { type: Inner, predicate: part.c = b.d, output: [part.p, b.b, b.d, part.c] }
698+
| ├─StreamExchange { dist: HashShard(part.c) }
699+
| | └─StreamTableScan { table: part, columns: [part.p, part.c], pk: [part.p], dist: UpstreamHashShard(part.p) }
700+
| └─StreamExchange { dist: HashShard(b.d) }
701+
| └─StreamTableScan { table: b, columns: [b.b, b.d], pk: [b.b], dist: UpstreamHashShard(b.b) }
702+
└─StreamExchange { dist: HashShard(part.p, part.p) }
703+
└─StreamTableScan { table: part, columns: [part.p], pk: [part.p], dist: UpstreamHashShard(part.p) }

src/frontend/src/optimizer/plan_node/batch_hash_join.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ impl ToDistributedBatch for BatchHashJoin {
173173
let r2l = self
174174
.eq_join_predicate()
175175
.r2l_eq_columns_mapping(left.schema().len(), right.schema().len());
176-
let l2r = r2l.inverse();
176+
let l2r = self
177+
.eq_join_predicate()
178+
.l2r_eq_columns_mapping(left.schema().len());
177179

178180
let right_dist = right.distribution();
179181
match right_dist {

src/frontend/src/optimizer/plan_node/eq_join_predicate.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,15 @@ impl EqJoinPredicate {
227227
ColIndexMapping::new(map)
228228
}
229229

230+
/// return the eq columns index mapping from left inputs to right inputs
231+
pub fn l2r_eq_columns_mapping(&self, left_cols_num: usize) -> ColIndexMapping {
232+
let mut map = vec![None; left_cols_num];
233+
for (left, right, _) in self.eq_keys() {
234+
map[left.index] = Some(right.index - left_cols_num);
235+
}
236+
ColIndexMapping::new(map)
237+
}
238+
230239
/// Reorder the `eq_keys` according to the `reorder_idx`.
231240
pub fn reorder(self, reorder_idx: &[usize]) -> Self {
232241
assert!(reorder_idx.len() <= self.eq_keys.len());

src/frontend/src/optimizer/plan_node/logical_join.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,7 @@ impl LogicalJoin {
891891
let mut left = self.left();
892892

893893
let r2l = predicate.r2l_eq_columns_mapping(left.schema().len(), right.schema().len());
894-
let l2r = r2l.inverse();
894+
let l2r = predicate.l2r_eq_columns_mapping(left.schema().len());
895895

896896
let right_dist = right.distribution();
897897
match right_dist {

0 commit comments

Comments
 (0)