Skip to content

Commit 530965e

Browse files
committed
add order-preserving probes
1 parent 5095612 commit 530965e

File tree

6 files changed

+172
-55
lines changed

6 files changed

+172
-55
lines changed

crates/polars-expr/src/chunked_idx_table/row_encoded.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,6 @@ impl RowEncodedChunkedIdxTable {
7070
probe_match: &mut Vec<IdxSize>,
7171
limit: IdxSize,
7272
) -> IdxSize {
73-
table_match.clear();
74-
probe_match.clear();
75-
7673
let mut keys_processed = 0;
7774
for (key_idx, hash, key) in hash_keys {
7875
let found_match = if let Some(key) = key {

crates/polars-expr/src/hash_keys.rs

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,25 @@ impl HashKeys {
7474
self.len() == 0
7575
}
7676

77+
/// After this call partitions will be extended with the partition of each
78+
/// hash. Nulls are assigned a specific partition if partition_nulls is true,
79+
/// and IdxSize::MAX otherwise.
80+
pub fn gen_partitions(
81+
&self,
82+
partitioner: &HashPartitioner,
83+
partitions: &mut Vec<IdxSize>,
84+
partition_nulls: bool,
85+
) {
86+
match self {
87+
Self::RowEncoded(s) => s.gen_partitions(partitioner, partitions, partition_nulls),
88+
Self::Single(s) => s.gen_partitions(partitioner, partitions, partition_nulls),
89+
}
90+
}
91+
7792
/// After this call partition_idxs[p] will be extended with the indices of
7893
/// hashes that belong to partition p, and the cardinality sketches are
7994
/// updated accordingly.
80-
pub fn gen_partition_idxs(
95+
pub fn gen_idxs_per_partition(
8196
&self,
8297
partitioner: &HashPartitioner,
8398
partition_idxs: &mut [Vec<IdxSize>],
@@ -86,13 +101,13 @@ impl HashKeys {
86101
) {
87102
if sketches.is_empty() {
88103
match self {
89-
Self::RowEncoded(s) => s.gen_partition_idxs::<false>(
104+
Self::RowEncoded(s) => s.gen_idxs_per_partition::<false>(
90105
partitioner,
91106
partition_idxs,
92107
sketches,
93108
partition_nulls,
94109
),
95-
Self::Single(s) => s.gen_partition_idxs::<false>(
110+
Self::Single(s) => s.gen_idxs_per_partition::<false>(
96111
partitioner,
97112
partition_idxs,
98113
sketches,
@@ -101,13 +116,13 @@ impl HashKeys {
101116
}
102117
} else {
103118
match self {
104-
Self::RowEncoded(s) => s.gen_partition_idxs::<true>(
119+
Self::RowEncoded(s) => s.gen_idxs_per_partition::<true>(
105120
partitioner,
106121
partition_idxs,
107122
sketches,
108123
partition_nulls,
109124
),
110-
Self::Single(s) => s.gen_partition_idxs::<true>(
125+
Self::Single(s) => s.gen_idxs_per_partition::<true>(
111126
partitioner,
112127
partition_idxs,
113128
sketches,
@@ -159,7 +174,33 @@ pub struct RowEncodedKeys {
159174
}
160175

161176
impl RowEncodedKeys {
162-
pub fn gen_partition_idxs<const BUILD_SKETCHES: bool>(
177+
pub fn gen_partitions(
178+
&self,
179+
partitioner: &HashPartitioner,
180+
partitions: &mut Vec<IdxSize>,
181+
partition_nulls: bool,
182+
) {
183+
partitions.reserve(self.hashes.len());
184+
if let Some(validity) = self.keys.validity() {
185+
// Arbitrarily put nulls in partition 0.
186+
let null_p = if partition_nulls { 0 } else { IdxSize::MAX };
187+
partitions.extend(self.hashes.values_iter().zip(validity).map(|(h, is_v)| {
188+
if is_v {
189+
partitioner.hash_to_partition(*h) as IdxSize
190+
} else {
191+
null_p
192+
}
193+
}))
194+
} else {
195+
partitions.extend(
196+
self.hashes
197+
.values_iter()
198+
.map(|h| partitioner.hash_to_partition(*h) as IdxSize),
199+
)
200+
}
201+
}
202+
203+
pub fn gen_idxs_per_partition<const BUILD_SKETCHES: bool>(
163204
&self,
164205
partitioner: &HashPartitioner,
165206
partition_idxs: &mut [Vec<IdxSize>],
@@ -261,7 +302,16 @@ pub struct SingleKeys {
261302
}
262303

263304
impl SingleKeys {
264-
pub fn gen_partition_idxs<const BUILD_SKETCHES: bool>(
305+
pub fn gen_partitions(
306+
&self,
307+
_partitioner: &HashPartitioner,
308+
_partitions: &mut Vec<IdxSize>,
309+
_partition_nulls: bool,
310+
) {
311+
todo!()
312+
}
313+
314+
pub fn gen_idxs_per_partition<const BUILD_SKETCHES: bool>(
265315
&self,
266316
partitioner: &HashPartitioner,
267317
partition_idxs: &mut [Vec<IdxSize>],

crates/polars-expr/src/idx_table/row_encoded.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,6 @@ impl RowEncodedIdxTable {
6868
probe_match: &mut Vec<IdxSize>,
6969
limit: IdxSize,
7070
) -> IdxSize {
71-
table_match.clear();
72-
probe_match.clear();
73-
7471
let mut keys_processed = 0;
7572
for (key_idx, hash, key) in hash_keys {
7673
let found_match = if let Some(key) = key {

crates/polars-stream/src/nodes/joins/equi_join.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ impl BuildState {
503503
for p in partition_idxs.iter_mut() {
504504
p.clear();
505505
}
506-
hash_keys.gen_partition_idxs(
506+
hash_keys.gen_idxs_per_partition(
507507
&partitioner,
508508
&mut partition_idxs,
509509
&mut sketches,
@@ -678,7 +678,7 @@ impl ProbeState {
678678
for p in partition_idxs.iter_mut() {
679679
p.clear();
680680
}
681-
hash_keys.gen_partition_idxs(
681+
hash_keys.gen_idxs_per_partition(
682682
&partitioner,
683683
&mut partition_idxs,
684684
&mut [],
@@ -690,6 +690,8 @@ impl ProbeState {
690690
let mut out_per_partition = Vec::with_capacity(partitioner.num_partitions());
691691
let name = PlSmallStr::from_static("__POLARS_PROBE_PRESERVE_ORDER_IDX");
692692
for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) {
693+
table_match.clear();
694+
probe_match.clear();
693695
p.hash_table.probe_subset(
694696
&hash_keys,
695697
idxs_in_p,
@@ -759,6 +761,8 @@ impl ProbeState {
759761
for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) {
760762
let mut offset = 0;
761763
while offset < idxs_in_p.len() {
764+
table_match.clear();
765+
probe_match.clear();
762766
offset += p.hash_table.probe_subset(
763767
&hash_keys,
764768
&idxs_in_p[offset..],

crates/polars-stream/src/nodes/joins/new_equi_join.rs

Lines changed: 103 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ impl BuildState {
516516
let mut payload = select_payload(morsel.df().clone(), payload_selector);
517517
payload.rechunk_mut();
518518

519-
hash_keys.gen_partition_idxs(
519+
hash_keys.gen_idxs_per_partition(
520520
&partitioner,
521521
&mut local.morsel_idxs_values_per_p,
522522
&mut local.sketch_per_p,
@@ -644,6 +644,8 @@ impl ProbeState {
644644
) -> PolarsResult<MorselSeq> {
645645
// TODO: shuffle after partitioning and keep probe tables thread-local.
646646
let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()];
647+
let mut probe_partitions = Vec::new();
648+
let mut materialized_idxsize_range = Vec::new();
647649
let mut table_match = Vec::new();
648650
let mut probe_match = Vec::new();
649651
let mut max_seq = MorselSeq::default();
@@ -690,79 +692,142 @@ impl ProbeState {
690692
let max_match_per_key_est = selectivity_estimate as usize + 16;
691693
let out_est_size = ((selectivity_estimate * 1.2 * df_height as f64) as usize).min(probe_limit as usize);
692694
build_out.reserve(out_est_size + max_match_per_key_est);
693-
probe_out.reserve(out_est_size + max_match_per_key_est);
694695

695696
unsafe {
696-
// Partition and probe the tables.
697-
for p in partition_idxs.iter_mut() {
698-
p.clear();
699-
}
700-
hash_keys.gen_partition_idxs(
701-
&partitioner,
702-
&mut partition_idxs,
703-
&mut [],
704-
emit_unmatched,
705-
);
697+
let new_morsel = |build: &mut DataFrameBuilder, probe: &mut DataFrameBuilder| {
698+
let mut build_df = build.freeze_reset();
699+
let mut probe_df = probe.freeze_reset();
700+
let out_df = if params.left_is_build.unwrap() {
701+
build_df.hstack_mut_unchecked(probe_df.get_columns());
702+
build_df
703+
} else {
704+
probe_df.hstack_mut_unchecked(build_df.get_columns());
705+
probe_df
706+
};
707+
let out_df = postprocess_join(out_df, params);
708+
Morsel::new(out_df, seq, src_token.clone())
709+
};
710+
706711
if params.preserve_order_probe {
707-
todo!()
708-
} else {
709-
let new_morsel = |mut build_df: DataFrame, mut probe_df: DataFrame| {
710-
let out_df = if params.left_is_build.unwrap() {
711-
build_df.hstack_mut_unchecked(probe_df.get_columns());
712-
build_df
713-
} else {
714-
probe_df.hstack_mut_unchecked(build_df.get_columns());
715-
probe_df
712+
// To preserve the order we can't do bulk probes per partition and must follow
713+
// the order of the probe morsel. We can still group consecutive probes
714+
// that fall in the same partition.
715+
hash_keys.gen_partitions(&partitioner, &mut probe_partitions, emit_unmatched);
716+
let mut probe_group_start = 0;
717+
while probe_group_start < probe_partitions.len() {
718+
let p_idx = probe_partitions[probe_group_start];
719+
let mut probe_group_end = probe_group_start + 1;
720+
while probe_partitions.get(probe_group_end) == Some(&p_idx) {
721+
probe_group_end += 1;
722+
}
723+
let Some(p) = partitions.get(p_idx as usize) else {
724+
probe_group_start = probe_group_end;
725+
continue;
716726
};
717-
let out_df = postprocess_join(out_df, params);
718-
Morsel::new(out_df, seq, src_token.clone())
719-
};
727+
728+
materialized_idxsize_range.extend(materialized_idxsize_range.len() as IdxSize..probe_group_end as IdxSize);
729+
730+
while probe_group_start < probe_group_end {
731+
let matches_before_limit = probe_limit - probe_match.len() as IdxSize;
732+
table_match.clear();
733+
probe_group_start += p.hash_table.probe_subset(
734+
&hash_keys,
735+
&materialized_idxsize_range[probe_group_start..probe_group_end],
736+
&mut table_match,
737+
&mut probe_match,
738+
mark_matches,
739+
emit_unmatched,
740+
matches_before_limit,
741+
) as usize;
742+
743+
if emit_unmatched {
744+
build_out.opt_gather_extend(&p.payload, &table_match, ShareStrategy::Always);
745+
} else {
746+
build_out.gather_extend(&p.payload, &table_match, ShareStrategy::Always);
747+
};
748+
749+
if probe_match.len() >= probe_limit as usize || probe_group_start == probe_partitions.len() {
750+
if !payload_rechunked {
751+
payload.rechunk_mut();
752+
payload_rechunked = true;
753+
}
754+
probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always);
755+
probe_match.clear();
756+
let out_morsel = new_morsel(&mut build_out, &mut probe_out);
757+
if send.send(out_morsel).await.is_err() {
758+
return Ok(max_seq);
759+
}
760+
if probe_group_end != probe_partitions.len() {
761+
// We had enough matches to need a mid-partition flush; let's assume there are a lot of
762+
// matches and just do a large reserve.
763+
build_out.reserve(probe_limit as usize + max_match_per_key_est);
764+
}
765+
}
766+
}
767+
}
768+
} else {
769+
// Partition and probe the tables.
770+
for p in partition_idxs.iter_mut() {
771+
p.clear();
772+
}
773+
hash_keys.gen_idxs_per_partition(
774+
&partitioner,
775+
&mut partition_idxs,
776+
&mut [],
777+
emit_unmatched,
778+
);
720779

721780
for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) {
722781
let mut offset = 0;
723782
while offset < idxs_in_p.len() {
783+
let matches_before_limit = probe_limit - probe_match.len() as IdxSize;
784+
table_match.clear();
724785
offset += p.hash_table.probe_subset(
725786
&hash_keys,
726787
&idxs_in_p[offset..],
727788
&mut table_match,
728789
&mut probe_match,
729790
mark_matches,
730791
emit_unmatched,
731-
probe_limit - probe_out.len() as IdxSize,
792+
matches_before_limit,
732793
) as usize;
733794

734-
if probe_match.is_empty() {
795+
if table_match.is_empty() {
735796
continue;
736797
}
737-
total_matches += probe_match.len();
798+
total_matches += table_match.len();
738799

739-
// Gather output and send.
740800
if emit_unmatched {
741801
build_out.opt_gather_extend(&p.payload, &table_match, ShareStrategy::Always);
742802
} else {
743803
build_out.gather_extend(&p.payload, &table_match, ShareStrategy::Always);
744804
};
745-
if !payload_rechunked {
746-
payload.rechunk_mut();
747-
payload_rechunked = true;
748-
}
749-
probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always);
750805

751-
if probe_out.len() >= probe_limit as usize {
752-
let out_morsel = new_morsel(build_out.freeze_reset(), probe_out.freeze_reset());
806+
if probe_match.len() >= probe_limit as usize {
807+
if !payload_rechunked {
808+
payload.rechunk_mut();
809+
payload_rechunked = true;
810+
}
811+
probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always);
812+
probe_match.clear();
813+
let out_morsel = new_morsel(&mut build_out, &mut probe_out);
753814
if send.send(out_morsel).await.is_err() {
754815
return Ok(max_seq);
755816
}
756817
// We had enough matches to need a mid-partition flush; let's assume there are a lot of
757818
// matches and just do a large reserve.
758819
build_out.reserve(probe_limit as usize + max_match_per_key_est);
759-
probe_out.reserve(probe_limit as usize + max_match_per_key_est);
760820
}
761821
}
762822
}
763823

764-
if !probe_out.is_empty() {
765-
let out_morsel = new_morsel(build_out.freeze_reset(), probe_out.freeze_reset());
824+
if !probe_match.is_empty() {
825+
if !payload_rechunked {
826+
payload.rechunk_mut();
827+
}
828+
probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always);
829+
probe_match.clear();
830+
let out_morsel = new_morsel(&mut build_out, &mut probe_out);
766831
if send.send(out_morsel).await.is_err() {
767832
return Ok(max_seq);
768833
}

crates/polars-stream/src/physical_plan/to_graph.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -809,8 +809,12 @@ fn to_graph_rec<'a>(
809809
.map(|e| create_stream_expr(e, ctx, &right_input_schema))
810810
.try_collect_vec()?;
811811

812-
// TODO: implement order-maintaining join in new join impl.
813-
if args.maintain_order == MaintainOrderJoin::None {
812+
// TODO: implement build-side order-maintaining join in new join impl.
813+
let preserve_order_build = matches!(
814+
args.maintain_order,
815+
MaintainOrderJoin::LeftRight | MaintainOrderJoin::RightLeft
816+
);
817+
if !preserve_order_build {
814818
ctx.graph.add_node(
815819
nodes::joins::new_equi_join::EquiJoinNode::new(
816820
left_input_schema,

0 commit comments

Comments
 (0)