
Commit fa8b5e3

feat: add target_parallelism_per_worker field for scaling (risingwavelabs#11945)
1 parent deac61e

5 files changed: +105 -27 lines

proto/meta.proto

Lines changed: 1 addition & 0 deletions
@@ -431,6 +431,7 @@ message GetReschedulePlanRequest {
     repeated uint32 include_worker_ids = 1;
     repeated uint32 exclude_worker_ids = 2;
     optional uint32 target_parallelism = 3;
+    optional uint32 target_parallelism_per_worker = 4;
   }

   message StableResizePolicy {
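
Note: judging by the Rust changes below, these fields live in a nested WorkerChanges message. The following hand-written struct is only an illustrative mirror of what prost generates for it (optional uint32 maps to Option<u32>, repeated uint32 to Vec<u32>); it is not the generated code itself.

// Illustrative mirror of the prost mapping for the message above; a sketch,
// not the generated code.
#[derive(Debug)]
struct WorkerChanges {
    include_worker_ids: Vec<u32>,       // repeated uint32
    exclude_worker_ids: Vec<u32>,       // repeated uint32
    target_parallelism: Option<u32>,    // optional uint32
    // New in this commit: caps the parallelism applied on each included worker.
    target_parallelism_per_worker: Option<u32>,
}

fn main() {
    let changes = WorkerChanges {
        include_worker_ids: vec![1, 2],
        exclude_worker_ids: vec![],
        target_parallelism: None,
        target_parallelism_per_worker: Some(4),
    };
    println!("{changes:?}");
}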

src/ctl/src/cmd_impl/scale/resize.rs

Lines changed: 11 additions & 2 deletions
@@ -120,6 +120,7 @@ pub async fn resize(context: &CtlContext, resize: ScaleResizeCommands) -> anyhow
         exclude_workers,
         include_workers,
         target_parallelism,
+        target_parallelism_per_worker,
         generate,
         output,
         yes,
@@ -132,8 +133,15 @@ pub async fn resize(context: &CtlContext, resize: ScaleResizeCommands) -> anyhow
     let include_worker_ids =
         worker_input_to_worker_ids(include_workers.unwrap_or_default(), true);

-    if let Some(target) = target_parallelism && target == 0 {
-        fail!("Target parallelism must be greater than 0");
+    match (target_parallelism, target_parallelism_per_worker) {
+        (Some(_), Some(_)) => {
+            fail!("Cannot specify both target parallelism and target parallelism per worker")
+        }
+        (_, Some(_)) if include_worker_ids.is_empty() => {
+            fail!("Cannot specify target parallelism per worker without including any worker")
+        }
+        (Some(target), _) if target == 0 => fail!("Target parallelism must be greater than 0"),
+        _ => {}
     }

     for worker_id in exclude_worker_ids.iter().chain(include_worker_ids.iter()) {
@@ -161,6 +169,7 @@ pub async fn resize(context: &CtlContext, resize: ScaleResizeCommands) -> anyhow
             include_worker_ids,
             exclude_worker_ids,
             target_parallelism,
+            target_parallelism_per_worker,
        }
    };
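
Note: the new validation can be exercised in isolation. The sketch below is a standalone rewrite of the match above, with the project's fail! macro swapped for a plain Result<(), String> so it compiles on its own; the function name and parameter types are made up for illustration.

// Standalone sketch of the validation added above; `fail!` is replaced by
// returning an error string, and worker IDs are passed as a plain slice.
fn validate(
    target_parallelism: Option<u32>,
    target_parallelism_per_worker: Option<u32>,
    include_worker_ids: &[u32],
) -> Result<(), String> {
    match (target_parallelism, target_parallelism_per_worker) {
        (Some(_), Some(_)) => {
            Err("Cannot specify both target parallelism and target parallelism per worker".into())
        }
        (_, Some(_)) if include_worker_ids.is_empty() => {
            Err("Cannot specify target parallelism per worker without including any worker".into())
        }
        (Some(target), _) if target == 0 => Err("Target parallelism must be greater than 0".into()),
        _ => Ok(()),
    }
}

fn main() {
    assert!(validate(Some(4), Some(2), &[1]).is_err()); // both options set
    assert!(validate(None, Some(2), &[]).is_err());     // per-worker without included workers
    assert!(validate(Some(0), None, &[]).is_err());     // zero parallelism
    assert!(validate(None, Some(2), &[1, 2]).is_ok());  // valid per-worker request
}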

src/ctl/src/lib.rs

Lines changed: 9 additions & 0 deletions
@@ -288,6 +288,15 @@ pub struct ScaleResizeCommands {
     #[clap(long)]
     target_parallelism: Option<u32>,

+    /// The target parallelism per worker, conflicts with `target_parallelism`, requires
+    /// `include_workers` to be set.
+    #[clap(
+        long,
+        requires = "include_workers",
+        conflicts_with = "target_parallelism"
+    )]
+    target_parallelism_per_worker: Option<u32>,
+
     /// Will generate a plan supported by the `reschedule` command and save it to the provided path
     /// by the `--output`.
     #[clap(long, default_value_t = false)]
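
Note: clap enforces the requires/conflicts_with constraints before the handler runs. Below is a minimal, self-contained sketch assuming clap v4 with the derive feature; the struct is a trimmed stand-in for ScaleResizeCommands, not the real one, and the kebab-case flag names come from clap's default renaming of the field names.

// Minimal clap-derive sketch of the constraints added above: the per-worker
// flag requires --include-workers and conflicts with --target-parallelism.
// Assumes clap v4 with the `derive` feature; trimmed stand-in, not the real struct.
use clap::Parser;

#[derive(Parser, Debug)]
struct Resize {
    #[clap(long)]
    include_workers: Option<String>,

    #[clap(long)]
    target_parallelism: Option<u32>,

    #[clap(long, requires = "include_workers", conflicts_with = "target_parallelism")]
    target_parallelism_per_worker: Option<u32>,
}

fn main() {
    // Accepted: per-worker parallelism together with an include list.
    assert!(Resize::try_parse_from([
        "resize",
        "--include-workers", "1,2",
        "--target-parallelism-per-worker", "4",
    ])
    .is_ok());

    // Rejected by clap: conflicts with --target-parallelism.
    assert!(Resize::try_parse_from([
        "resize",
        "--target-parallelism", "8",
        "--include-workers", "1",
        "--target-parallelism-per-worker", "4",
    ])
    .is_err());

    // Rejected by clap: per-worker parallelism without --include-workers.
    assert!(Resize::try_parse_from(["resize", "--target-parallelism-per-worker", "4"]).is_err());
}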

src/meta/src/stream/scale.rs

Lines changed: 78 additions & 20 deletions
@@ -1689,6 +1689,7 @@ where
             include_worker_ids: BTreeSet<WorkerId>,
             exclude_worker_ids: BTreeSet<WorkerId>,
             target_parallelism: Option<usize>,
+            target_parallelism_per_worker: Option<usize>,
         }

         let mut fragment_worker_changes: HashMap<_, _> = fragment_worker_changes
@@ -1700,6 +1701,9 @@ where
                         include_worker_ids: changes.include_worker_ids.into_iter().collect(),
                         exclude_worker_ids: changes.exclude_worker_ids.into_iter().collect(),
                         target_parallelism: changes.target_parallelism.map(|p| p as usize),
+                        target_parallelism_per_worker: changes
+                            .target_parallelism_per_worker
+                            .map(|p| p as usize),
                     },
                 )
             })
@@ -1718,6 +1722,7 @@ where
                 include_worker_ids,
                 exclude_worker_ids,
                 target_parallelism,
+                target_parallelism_per_worker,
             },
         ) in fragment_worker_changes
         {
@@ -1757,19 +1762,69 @@ where
                 })
                 .collect();

+            let include_worker_parallel_unit_ids = include_worker_ids
+                .iter()
+                .flat_map(|worker_id| worker_parallel_units.get(worker_id).unwrap())
+                .cloned()
+                .collect_vec();
+
+            let exclude_worker_parallel_unit_ids = exclude_worker_ids
+                .iter()
+                .flat_map(|worker_id| worker_parallel_units.get(worker_id).unwrap())
+                .cloned()
+                .collect_vec();
+
+            fn refilter_parallel_unit_id_by_target_parallelism(
+                worker_parallel_units: &HashMap<u32, HashSet<ParallelUnitId>>,
+                include_worker_ids: &BTreeSet<WorkerId>,
+                include_worker_parallel_unit_ids: &[ParallelUnitId],
+                target_parallel_unit_ids: &mut BTreeSet<ParallelUnitId>,
+                target_parallelism_per_worker: usize,
+            ) {
+                let limited_worker_parallel_unit_ids = include_worker_ids
+                    .iter()
+                    .flat_map(|worker_id| {
+                        worker_parallel_units
+                            .get(worker_id)
+                            .cloned()
+                            .unwrap()
+                            .into_iter()
+                            .sorted()
+                            .take(target_parallelism_per_worker)
+                    })
+                    .collect_vec();
+
+                // remove all the parallel units in the limited workers
+                target_parallel_unit_ids
+                    .retain(|id| !include_worker_parallel_unit_ids.contains(id));
+
+                // then we re-add the limited parallel units from the limited workers
+                target_parallel_unit_ids.extend(limited_worker_parallel_unit_ids.into_iter());
+            }
+
             match fragment.get_distribution_type().unwrap() {
                 FragmentDistributionType::Unspecified => unreachable!(),
                 FragmentDistributionType::Single => {
                     let single_parallel_unit_id =
                         fragment_parallel_unit_ids.iter().exactly_one().unwrap();

-                    let target_parallel_unit_ids: BTreeSet<_> = worker_parallel_units
+                    let mut target_parallel_unit_ids: BTreeSet<_> = worker_parallel_units
                         .keys()
                         .filter(|id| !unschedulable_worker_ids.contains(*id))
                         .filter(|id| !exclude_worker_ids.contains(*id))
                         .flat_map(|id| worker_parallel_units.get(id).cloned().unwrap())
                         .collect();

+                    if let Some(target_parallelism_per_worker) = target_parallelism_per_worker {
+                        refilter_parallel_unit_id_by_target_parallelism(
+                            &worker_parallel_units,
+                            &include_worker_ids,
+                            &include_worker_parallel_unit_ids,
+                            &mut target_parallel_unit_ids,
+                            target_parallelism_per_worker,
+                        );
+                    }
+
                     if target_parallel_unit_ids.is_empty() {
                         bail!(
                             "No schedulable ParallelUnits available for single distribution fragment {}",
@@ -1796,18 +1851,6 @@ where
                     }
                 }
                 FragmentDistributionType::Hash => {
-                    let include_worker_parallel_unit_ids = include_worker_ids
-                        .iter()
-                        .flat_map(|worker_id| worker_parallel_units.get(worker_id).unwrap())
-                        .cloned()
-                        .collect_vec();
-
-                    let exclude_worker_parallel_unit_ids = exclude_worker_ids
-                        .iter()
-                        .flat_map(|worker_id| worker_parallel_units.get(worker_id).unwrap())
-                        .cloned()
-                        .collect_vec();
-
                     let mut target_parallel_unit_ids: BTreeSet<_> =
                         fragment_parallel_unit_ids.clone();
                     target_parallel_unit_ids.extend(include_worker_parallel_unit_ids.iter());
@@ -1821,15 +1864,30 @@ where
                         );
                     }

-                    if let Some(target_parallelism) = target_parallelism {
-                        if target_parallel_unit_ids.len() < target_parallelism {
-                            bail!("Target parallelism {} is greater than schedulable ParallelUnits {}", target_parallelism, target_parallel_unit_ids.len());
+                    match (target_parallelism, target_parallelism_per_worker) {
+                        (Some(_), Some(_)) => {
+                            bail!("Cannot specify both target parallelism and target parallelism per worker");
                         }
+                        (Some(target_parallelism), _) => {
+                            if target_parallel_unit_ids.len() < target_parallelism {
+                                bail!("Target parallelism {} is greater than schedulable ParallelUnits {}", target_parallelism, target_parallel_unit_ids.len());
+                            }

-                        target_parallel_unit_ids = target_parallel_unit_ids
-                            .into_iter()
-                            .take(target_parallelism)
-                            .collect();
+                            target_parallel_unit_ids = target_parallel_unit_ids
+                                .into_iter()
+                                .take(target_parallelism)
+                                .collect();
+                        }
+                        (_, Some(target_parallelism_per_worker)) => {
+                            refilter_parallel_unit_id_by_target_parallelism(
+                                &worker_parallel_units,
+                                &include_worker_ids,
+                                &include_worker_parallel_unit_ids,
+                                &mut target_parallel_unit_ids,
+                                target_parallelism_per_worker,
+                            );
+                        }
+                        _ => {}
                     }

                     let to_expand_parallel_units = target_parallel_unit_ids
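
Note: the include/exclude parallel-unit lists are hoisted out of the Hash branch so the new refiltering helper can serve both the Single and Hash distributions. To make that step concrete, here is a standalone toy reproduction using only std (no itertools), run on made-up worker and parallel-unit IDs; the WorkerId/ParallelUnitId aliases and function name are assumptions for the sketch.

// Toy reproduction of the per-worker refiltering above, using plain std collections.
// For each included worker, only its lowest `target_parallelism_per_worker`
// parallel-unit IDs stay in the target set; units of other workers are untouched.
use std::collections::{BTreeSet, HashMap, HashSet};

type WorkerId = u32;        // assumed alias for the sketch
type ParallelUnitId = u32;  // assumed alias for the sketch

fn refilter_per_worker(
    worker_parallel_units: &HashMap<WorkerId, HashSet<ParallelUnitId>>,
    include_worker_ids: &BTreeSet<WorkerId>,
    include_worker_parallel_unit_ids: &[ParallelUnitId],
    target_parallel_unit_ids: &mut BTreeSet<ParallelUnitId>,
    target_parallelism_per_worker: usize,
) {
    // Collect the lowest N parallel-unit IDs of every included worker.
    let mut limited = Vec::new();
    for worker_id in include_worker_ids {
        let mut ids: Vec<_> = worker_parallel_units[worker_id].iter().copied().collect();
        ids.sort_unstable();
        limited.extend(ids.into_iter().take(target_parallelism_per_worker));
    }

    // Drop everything owned by included workers, then re-add only the limited subset.
    target_parallel_unit_ids.retain(|id| !include_worker_parallel_unit_ids.contains(id));
    target_parallel_unit_ids.extend(limited);
}

fn main() {
    let worker_parallel_units = HashMap::from([
        (1, HashSet::from([10, 11, 12])),
        (2, HashSet::from([20, 21, 22])),
    ]);
    let include_worker_ids = BTreeSet::from([1, 2]);
    let include_units = vec![10, 11, 12, 20, 21, 22];
    // Unit 30 belongs to a worker that is not in the include list.
    let mut target = BTreeSet::from([10, 11, 12, 20, 21, 22, 30]);

    refilter_per_worker(&worker_parallel_units, &include_worker_ids, &include_units, &mut target, 2);

    // Worker 1 keeps 10 and 11, worker 2 keeps 20 and 21, unit 30 is untouched.
    assert_eq!(target, BTreeSet::from([10, 11, 20, 21, 30]));
}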

src/tests/simulation/tests/integration_tests/scale/plan.rs

Lines changed: 6 additions & 5 deletions
@@ -13,6 +13,7 @@
 // limitations under the License.

 use std::collections::HashMap;
+use std::default::Default;

 use anyhow::Result;
 use itertools::Itertools;
@@ -65,7 +66,7 @@ async fn test_resize_normal() -> Result<()> {
             WorkerChanges {
                 include_worker_ids: vec![],
                 exclude_worker_ids: removed_workers,
-                target_parallelism: None,
+                ..Default::default()
             },
         )]),
     }))
@@ -148,7 +149,7 @@ async fn test_resize_single() -> Result<()> {
             WorkerChanges {
                 include_worker_ids: vec![],
                 exclude_worker_ids: vec![prev_worker.id],
-                target_parallelism: None,
+                ..Default::default()
             },
         )]),
     }))
@@ -223,15 +224,15 @@ async fn test_resize_single_failed() -> Result<()> {
                 WorkerChanges {
                     include_worker_ids: vec![],
                     exclude_worker_ids: vec![worker_a.id],
-                    target_parallelism: None,
+                    ..Default::default()
                 },
             ),
             (
                 downstream_fragment_id,
                 WorkerChanges {
                     include_worker_ids: vec![],
                     exclude_worker_ids: vec![worker_b.id],
-                    target_parallelism: None,
+                    ..Default::default()
                 },
             ),
         ]),
@@ -302,7 +303,7 @@ join mv5 on mv1.v = mv5.v;",
             WorkerChanges {
                 include_worker_ids: vec![],
                 exclude_worker_ids: removed_worker_ids,
-                target_parallelism: None,
+                ..Default::default()
             },
         )]),
     }))
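
Note: the tests can switch to ..Default::default() because the prost-generated WorkerChanges implements Default, so the newly added optional field simply falls back to None and existing tests stay unchanged. The struct below is a toy mirror of that pattern, not the generated type.

// Toy mirror showing why `..Default::default()` keeps existing tests unchanged
// when a new optional field is added; illustrative only.
#[derive(Debug, Default, PartialEq)]
struct WorkerChanges {
    include_worker_ids: Vec<u32>,
    exclude_worker_ids: Vec<u32>,
    target_parallelism: Option<u32>,
    target_parallelism_per_worker: Option<u32>,
}

fn main() {
    // Only the fields a test cares about are spelled out; everything else,
    // including the new `target_parallelism_per_worker`, takes its default.
    let changes = WorkerChanges {
        exclude_worker_ids: vec![3],
        ..Default::default()
    };
    assert_eq!(changes.target_parallelism_per_worker, None);
    assert_eq!(changes.include_worker_ids, Vec::<u32>::new());
}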
