@@ -1689,6 +1689,7 @@ where
1689
1689
include_worker_ids : BTreeSet < WorkerId > ,
1690
1690
exclude_worker_ids : BTreeSet < WorkerId > ,
1691
1691
target_parallelism : Option < usize > ,
1692
+ target_parallelism_per_worker : Option < usize > ,
1692
1693
}
1693
1694
1694
1695
let mut fragment_worker_changes: HashMap < _ , _ > = fragment_worker_changes
@@ -1700,6 +1701,9 @@ where
1700
1701
include_worker_ids : changes. include_worker_ids . into_iter ( ) . collect ( ) ,
1701
1702
exclude_worker_ids : changes. exclude_worker_ids . into_iter ( ) . collect ( ) ,
1702
1703
target_parallelism : changes. target_parallelism . map ( |p| p as usize ) ,
1704
+ target_parallelism_per_worker : changes
1705
+ . target_parallelism_per_worker
1706
+ . map ( |p| p as usize ) ,
1703
1707
} ,
1704
1708
)
1705
1709
} )
@@ -1718,6 +1722,7 @@ where
1718
1722
include_worker_ids,
1719
1723
exclude_worker_ids,
1720
1724
target_parallelism,
1725
+ target_parallelism_per_worker,
1721
1726
} ,
1722
1727
) in fragment_worker_changes
1723
1728
{
@@ -1757,19 +1762,69 @@ where
1757
1762
} )
1758
1763
. collect ( ) ;
1759
1764
1765
+ let include_worker_parallel_unit_ids = include_worker_ids
1766
+ . iter ( )
1767
+ . flat_map ( |worker_id| worker_parallel_units. get ( worker_id) . unwrap ( ) )
1768
+ . cloned ( )
1769
+ . collect_vec ( ) ;
1770
+
1771
+ let exclude_worker_parallel_unit_ids = exclude_worker_ids
1772
+ . iter ( )
1773
+ . flat_map ( |worker_id| worker_parallel_units. get ( worker_id) . unwrap ( ) )
1774
+ . cloned ( )
1775
+ . collect_vec ( ) ;
1776
+
1777
+ fn refilter_parallel_unit_id_by_target_parallelism (
1778
+ worker_parallel_units : & HashMap < u32 , HashSet < ParallelUnitId > > ,
1779
+ include_worker_ids : & BTreeSet < WorkerId > ,
1780
+ include_worker_parallel_unit_ids : & [ ParallelUnitId ] ,
1781
+ target_parallel_unit_ids : & mut BTreeSet < ParallelUnitId > ,
1782
+ target_parallelism_per_worker : usize ,
1783
+ ) {
1784
+ let limited_worker_parallel_unit_ids = include_worker_ids
1785
+ . iter ( )
1786
+ . flat_map ( |worker_id| {
1787
+ worker_parallel_units
1788
+ . get ( worker_id)
1789
+ . cloned ( )
1790
+ . unwrap ( )
1791
+ . into_iter ( )
1792
+ . sorted ( )
1793
+ . take ( target_parallelism_per_worker)
1794
+ } )
1795
+ . collect_vec ( ) ;
1796
+
1797
+ // remove all the parallel units in the limited workers
1798
+ target_parallel_unit_ids
1799
+ . retain ( |id| !include_worker_parallel_unit_ids. contains ( id) ) ;
1800
+
1801
+ // then we re-add the limited parallel units from the limited workers
1802
+ target_parallel_unit_ids. extend ( limited_worker_parallel_unit_ids. into_iter ( ) ) ;
1803
+ }
1804
+
1760
1805
match fragment. get_distribution_type ( ) . unwrap ( ) {
1761
1806
FragmentDistributionType :: Unspecified => unreachable ! ( ) ,
1762
1807
FragmentDistributionType :: Single => {
1763
1808
let single_parallel_unit_id =
1764
1809
fragment_parallel_unit_ids. iter ( ) . exactly_one ( ) . unwrap ( ) ;
1765
1810
1766
- let target_parallel_unit_ids: BTreeSet < _ > = worker_parallel_units
1811
+ let mut target_parallel_unit_ids: BTreeSet < _ > = worker_parallel_units
1767
1812
. keys ( )
1768
1813
. filter ( |id| !unschedulable_worker_ids. contains ( * id) )
1769
1814
. filter ( |id| !exclude_worker_ids. contains ( * id) )
1770
1815
. flat_map ( |id| worker_parallel_units. get ( id) . cloned ( ) . unwrap ( ) )
1771
1816
. collect ( ) ;
1772
1817
1818
+ if let Some ( target_parallelism_per_worker) = target_parallelism_per_worker {
1819
+ refilter_parallel_unit_id_by_target_parallelism (
1820
+ & worker_parallel_units,
1821
+ & include_worker_ids,
1822
+ & include_worker_parallel_unit_ids,
1823
+ & mut target_parallel_unit_ids,
1824
+ target_parallelism_per_worker,
1825
+ ) ;
1826
+ }
1827
+
1773
1828
if target_parallel_unit_ids. is_empty ( ) {
1774
1829
bail ! (
1775
1830
"No schedulable ParallelUnits available for single distribution fragment {}" ,
@@ -1796,18 +1851,6 @@ where
1796
1851
}
1797
1852
}
1798
1853
FragmentDistributionType :: Hash => {
1799
- let include_worker_parallel_unit_ids = include_worker_ids
1800
- . iter ( )
1801
- . flat_map ( |worker_id| worker_parallel_units. get ( worker_id) . unwrap ( ) )
1802
- . cloned ( )
1803
- . collect_vec ( ) ;
1804
-
1805
- let exclude_worker_parallel_unit_ids = exclude_worker_ids
1806
- . iter ( )
1807
- . flat_map ( |worker_id| worker_parallel_units. get ( worker_id) . unwrap ( ) )
1808
- . cloned ( )
1809
- . collect_vec ( ) ;
1810
-
1811
1854
let mut target_parallel_unit_ids: BTreeSet < _ > =
1812
1855
fragment_parallel_unit_ids. clone ( ) ;
1813
1856
target_parallel_unit_ids. extend ( include_worker_parallel_unit_ids. iter ( ) ) ;
@@ -1821,15 +1864,30 @@ where
1821
1864
) ;
1822
1865
}
1823
1866
1824
- if let Some ( target_parallelism) = target_parallelism {
1825
- if target_parallel_unit_ids . len ( ) < target_parallelism {
1826
- bail ! ( "Target parallelism {} is greater than schedulable ParallelUnits {}" , target_parallelism , target_parallel_unit_ids . len ( ) ) ;
1867
+ match ( target_parallelism, target_parallelism_per_worker ) {
1868
+ ( Some ( _ ) , Some ( _ ) ) => {
1869
+ bail ! ( "Cannot specify both target parallelism and target parallelism per worker" ) ;
1827
1870
}
1871
+ ( Some ( target_parallelism) , _) => {
1872
+ if target_parallel_unit_ids. len ( ) < target_parallelism {
1873
+ bail ! ( "Target parallelism {} is greater than schedulable ParallelUnits {}" , target_parallelism, target_parallel_unit_ids. len( ) ) ;
1874
+ }
1828
1875
1829
- target_parallel_unit_ids = target_parallel_unit_ids
1830
- . into_iter ( )
1831
- . take ( target_parallelism)
1832
- . collect ( ) ;
1876
+ target_parallel_unit_ids = target_parallel_unit_ids
1877
+ . into_iter ( )
1878
+ . take ( target_parallelism)
1879
+ . collect ( ) ;
1880
+ }
1881
+ ( _, Some ( target_parallelism_per_worker) ) => {
1882
+ refilter_parallel_unit_id_by_target_parallelism (
1883
+ & worker_parallel_units,
1884
+ & include_worker_ids,
1885
+ & include_worker_parallel_unit_ids,
1886
+ & mut target_parallel_unit_ids,
1887
+ target_parallelism_per_worker,
1888
+ ) ;
1889
+ }
1890
+ _ => { }
1833
1891
}
1834
1892
1835
1893
let to_expand_parallel_units = target_parallel_unit_ids
0 commit comments