Skip to content

Commit 3b02bc4

Browse files
authored
Add fleets property to run configurations and CLI (#2488)
Closes: #1447
1 parent c89f24f commit 3b02bc4

File tree

9 files changed

+63
-13
lines changed

9 files changed

+63
-13
lines changed

src/dstack/_internal/cli/services/profile.py

+9
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ def register_profile_args(parser: argparse.ArgumentParser):
6565
)
6666

6767
fleets_group = parser.add_argument_group("Fleets")
68+
fleets_group.add_argument(
69+
"--fleet",
70+
action="append",
71+
metavar="NAME",
72+
dest="fleets",
73+
help="Consider only instances from the specified fleet(s) for reuse",
74+
)
6875
fleets_group_exc = fleets_group.add_mutually_exclusive_group()
6976
fleets_group_exc.add_argument(
7077
"-R",
@@ -147,6 +154,8 @@ def apply_profile_args(
147154
if args.max_duration is not None:
148155
profile_settings.max_duration = args.max_duration
149156

157+
if args.fleets:
158+
profile_settings.fleets = args.fleets
150159
if args.idle_duration is not None:
151160
profile_settings.idle_duration = args.idle_duration
152161
elif args.dont_destroy:

src/dstack/_internal/core/models/profiles.py

+3
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ class ProfileParams(CoreModel):
240240
Optional[UtilizationPolicy],
241241
Field(description="Run termination policy based on utilization"),
242242
] = None
243+
fleets: Annotated[
244+
Optional[list[str]], Field(description="The fleets considered for reuse")
245+
] = None
243246

244247
# Deprecated and unused. Left for compatibility with 0.18 clients.
245248
pool_name: Annotated[Optional[str], Field(exclude=True)] = None

src/dstack/_internal/server/background/tasks/process_submitted_jobs.py

+1
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
212212
InstanceModel.deleted == False,
213213
InstanceModel.total_blocks > InstanceModel.busy_blocks,
214214
)
215+
.options(joinedload(InstanceModel.fleet))
215216
.execution_options(populate_existing=True)
216217
)
217218
pool_instances = list(res.unique().scalars().all())

src/dstack/_internal/server/services/instances.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ def filter_pool_instances(
181181
continue
182182
if instance.unreachable:
183183
continue
184+
fleet = instance.fleet
185+
if profile.fleets is not None and (fleet is None or fleet.name not in profile.fleets):
186+
continue
184187
if status is not None and instance.status != status:
185188
continue
186189
jpd = get_instance_provisioning_data(instance)
@@ -268,10 +271,12 @@ async def get_pool_instances(
268271
project: ProjectModel,
269272
) -> List[InstanceModel]:
270273
res = await session.execute(
271-
select(InstanceModel).where(
274+
select(InstanceModel)
275+
.where(
272276
InstanceModel.project_id == project.id,
273277
InstanceModel.deleted == False,
274278
)
279+
.options(joinedload(InstanceModel.fleet))
275280
)
276281
instance_models = list(res.unique().scalars().all())
277282
return instance_models

src/dstack/api/server/_fleets.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,9 @@ def _get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[Dict]:
6464
spec_excludes: Dict[str, Any] = {}
6565
configuration_excludes: Dict[str, Any] = {}
6666
profile_excludes: set[str] = set()
67-
# Fields can be excluded like this:
68-
# if fleet_spec.configuration.availability_zones is None:
69-
# configuration_excludes["availability_zones"] = True
70-
# if fleet_spec.profile is not None and fleet_spec.profile.availability_zones is None:
71-
# profile_excludes.add("availability_zones")
67+
profile = fleet_spec.profile
68+
if profile.fleets is None:
69+
profile_excludes.add("fleet")
7270
if configuration_excludes:
7371
spec_excludes["configuration"] = configuration_excludes
7472
if profile_excludes:

src/dstack/api/server/_runs.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,12 @@ def _get_run_spec_excludes(run_spec: RunSpec) -> Optional[Dict]:
104104
spec_excludes: dict[str, Any] = {}
105105
configuration_excludes: dict[str, Any] = {}
106106
profile_excludes: set[str] = set()
107-
# configuration = run_spec.configuration
108-
# profile = run_spec.profile
109-
# Fields can be excluded like this:
110-
# if configuration.availability_zones is None:
111-
# configuration_excludes["availability_zones"] = True
112-
# if profile is not None and profile.availability_zones is None:
113-
# profile_excludes.add("availability_zones")
107+
configuration = run_spec.configuration
108+
profile = run_spec.profile
109+
if configuration.fleets is None:
110+
configuration_excludes["fleet"] = True
111+
if profile is not None and profile.fleets is None:
112+
profile_excludes.add("fleet")
114113
if configuration_excludes:
115114
spec_excludes["configuration"] = configuration_excludes
116115
if profile_excludes:

src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py

+29
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,35 @@ async def test_assigns_job_to_shared_instance(self, test_db, session: AsyncSessi
536536
assert instance.total_blocks == 4
537537
assert instance.busy_blocks == 2
538538

539+
@pytest.mark.asyncio
540+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
541+
async def test_assigns_job_to_specific_fleet(self, test_db, session: AsyncSession):
542+
project = await create_project(session)
543+
user = await create_user(session)
544+
repo = await create_repo(session=session, project_id=project.id)
545+
fleet_a = await create_fleet(session=session, project=project, name="a")
546+
await create_instance(session=session, project=project, fleet=fleet_a, price=1.0)
547+
fleet_b = await create_fleet(session=session, project=project, name="b")
548+
await create_instance(session=session, project=project, fleet=fleet_b, price=2.0)
549+
fleet_c = await create_fleet(session=session, project=project, name="c")
550+
await create_instance(session=session, project=project, fleet=fleet_c, price=3.0)
551+
run_spec = get_run_spec(run_name="test-run", repo_id=repo.name)
552+
# When more than one fleet is requested, the cheapest one is selected
553+
run_spec.configuration.fleets = ["c", "b"]
554+
run = await create_run(
555+
session=session, project=project, repo=repo, user=user, run_spec=run_spec
556+
)
557+
job = await create_job(session=session, run=run)
558+
559+
await process_submitted_jobs()
560+
561+
await session.refresh(job)
562+
res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
563+
job = res.unique().scalar_one()
564+
assert job.status == JobStatus.SUBMITTED
565+
assert job.instance is not None
566+
assert job.instance.fleet == fleet_b
567+
539568
@pytest.mark.asyncio
540569
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
541570
async def test_creates_new_instance_in_existing_fleet(self, test_db, session: AsyncSession):

src/tests/_internal/server/routers/test_fleets.py

+2
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ async def test_creates_fleet(self, test_db, session: AsyncSession, client: Async
366366
"name": "",
367367
"default": False,
368368
"reservation": None,
369+
"fleets": None,
369370
},
370371
"autocreated": False,
371372
},
@@ -484,6 +485,7 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A
484485
"name": "",
485486
"default": False,
486487
"reservation": None,
488+
"fleets": None,
487489
},
488490
"autocreated": False,
489491
},

src/tests/_internal/server/routers/test_runs.py

+4
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def get_dev_env_run_plan_dict(
124124
"idle_duration": None,
125125
"utilization_policy": None,
126126
"reservation": None,
127+
"fleets": None,
127128
},
128129
"configuration_path": "dstack.yaml",
129130
"profile": {
@@ -142,6 +143,7 @@ def get_dev_env_run_plan_dict(
142143
"idle_duration": None,
143144
"utilization_policy": None,
144145
"reservation": None,
146+
"fleets": None,
145147
},
146148
"repo_code_hash": None,
147149
"repo_data": {"repo_dir": "/repo", "repo_type": "local"},
@@ -274,6 +276,7 @@ def get_dev_env_run_dict(
274276
"idle_duration": None,
275277
"utilization_policy": None,
276278
"reservation": None,
279+
"fleets": None,
277280
},
278281
"configuration_path": "dstack.yaml",
279282
"profile": {
@@ -292,6 +295,7 @@ def get_dev_env_run_dict(
292295
"idle_duration": None,
293296
"utilization_policy": None,
294297
"reservation": None,
298+
"fleets": None,
295299
},
296300
"repo_code_hash": None,
297301
"repo_data": {"repo_dir": "/repo", "repo_type": "local"},

0 commit comments

Comments
 (0)