From 699ddb07ec24fd1404e50170f0e44bf5acbc5720 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Fri, 22 May 2026 21:52:57 +0200 Subject: [PATCH] Fix Vast.ai offer order in `dstack offer --fleet` Order by score rather than by price, the same way offers are already ordered in apply plans and `dstack offer` without `--fleet`. --- .../server/services/backends/__init__.py | 10 +- .../_internal/server/services/offers.py | 18 ++- .../_internal/server/services/runs/plan.py | 25 +-- src/dstack/_internal/server/testing/common.py | 2 + .../_internal/server/routers/test_runs.py | 148 ++++++++++++++++++ 5 files changed, 184 insertions(+), 19 deletions(-) diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py index c37bbd88e7..d3cbcaebaa 100644 --- a/src/dstack/_internal/server/services/backends/__init__.py +++ b/src/dstack/_internal/server/services/backends/__init__.py @@ -1,5 +1,4 @@ import asyncio -import heapq import json import time from collections.abc import Iterable, Iterator @@ -43,6 +42,7 @@ from dstack._internal.core.models.runs import Requirements from dstack._internal.server import settings from dstack._internal.server.models import BackendModel, DecryptedString, ProjectModel +from dstack._internal.server.services.offers import merge_offer_iterables from dstack._internal.settings import LOCAL_BACKEND_ENABLED from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger @@ -459,7 +459,7 @@ async def get_backend_offers( backends: List[Backend], requirements: Requirements, exclude_not_available: bool = False, -) -> Iterator[Tuple[Backend, InstanceOfferWithAvailability]]: +) -> Iterable[Tuple[Backend, InstanceOfferWithAvailability]]: """ Yields backend offers satisfying `requirements` sorted by price. """ @@ -474,7 +474,7 @@ def get_filtered_offers_with_backends( logger.debug("Requesting instance offers from backends: %s", [b.TYPE.value for b in backends]) tasks = [run_async(get_offers_tracked, backend, requirements) for backend in backends] - offers_by_backend = [] + offers_by_backend: list[Iterable[tuple[Backend, InstanceOfferWithAvailability]]] = [] for backend, result in zip(backends, await asyncio.gather(*tasks, return_exceptions=True)): if isinstance(result, BackendError): logger.warning( @@ -491,9 +491,7 @@ def get_filtered_offers_with_backends( ) continue offers_by_backend.append(get_filtered_offers_with_backends(backend, result)) - # Merge preserving order for every backend. - offers = heapq.merge(*offers_by_backend, key=lambda i: i[1].price) - return offers + return merge_offer_iterables(*offers_by_backend) def check_backend_type_available(backend_type: BackendType): diff --git a/src/dstack/_internal/server/services/offers.py b/src/dstack/_internal/server/services/offers.py index 3ac8b7ed60..569bc07240 100644 --- a/src/dstack/_internal/server/services/offers.py +++ b/src/dstack/_internal/server/services/offers.py @@ -1,6 +1,7 @@ +import heapq import itertools from collections.abc import Container, Iterable, Iterator -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, TypeVar, Union import gpuhunt @@ -116,6 +117,21 @@ async def get_offers_by_requirements( return sorted(offers, key=lambda i: not i[1].availability.is_available()) +T = TypeVar("T") + + +def merge_offer_iterables( + *iterables: Iterable[tuple[T, InstanceOfferWithAvailability]], +) -> Iterable[tuple[T, InstanceOfferWithAvailability]]: + """ + Merge offers from different sources (e.g., different backends, different fleets). + + Some backends produce offers that are not sorted by price (e.g., `vastai` sorts by pod score). + That backend-specific order is preserved. + """ + return heapq.merge(*iterables, key=lambda i: i[1].price) + + def is_divisible_into_blocks( cpu_count: int, gpu_count: int, blocks: Union[int, Literal["auto"]] ) -> tuple[bool, int]: diff --git a/src/dstack/_internal/server/services/runs/plan.py b/src/dstack/_internal/server/services/runs/plan.py index 1fcf3e7bd4..b49c9f8a3b 100644 --- a/src/dstack/_internal/server/services/runs/plan.py +++ b/src/dstack/_internal/server/services/runs/plan.py @@ -52,7 +52,10 @@ is_multinode_job, remove_job_spec_sensitive_info, ) -from dstack._internal.server.services.offers import get_offers_by_requirements +from dstack._internal.server.services.offers import ( + get_offers_by_requirements, + merge_offer_iterables, +) from dstack._internal.server.services.requirements.combine import ( combine_fleet_and_run_profiles, combine_fleet_and_run_requirements, @@ -711,11 +714,10 @@ async def get_backend_offers_in_run_candidate_fleets( run_model=None, run_spec=run_spec, ) - deduplicated_backend_offers: dict[ - Hashable, - tuple[Backend, InstanceOfferWithAvailability], - ] = {} + seen_offer_identities = set() + offers: list[tuple[Backend, InstanceOfferWithAvailability]] = [] for candidate_fleet_model in candidate_fleet_models: + offers_from_fleet = [] for backend, offer in await _get_backend_offers_in_fleet( project=project, fleet_model=candidate_fleet_model, @@ -724,13 +726,12 @@ async def get_backend_offers_in_run_candidate_fleets( volumes=volumes, max_offers=max_offers_per_fleet, ): - deduplicated_backend_offers.setdefault( - _get_backend_offer_identity(offer), - (backend, offer), - ) - backend_offers = list(deduplicated_backend_offers.values()) - backend_offers.sort(key=lambda offer: offer[1].price) - return backend_offers + offer_identity = _get_backend_offer_identity(offer) + if offer_identity not in seen_offer_identities: + offers_from_fleet.append((backend, offer)) + seen_offer_identities.add(offer_identity) + offers = list(merge_offer_iterables(offers, offers_from_fleet)) + return offers async def _get_offers_in_run_candidate_fleets( diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 249780fcd8..6c8b7233f6 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -751,11 +751,13 @@ def get_fleet_configuration( name: str = "test-fleet", nodes: FleetNodesSpec = FleetNodesSpec(min=1, target=1, max=1), placement: Optional[InstanceGroupPlacement] = None, + backends: Optional[list[BackendType]] = None, ) -> FleetConfiguration: return FleetConfiguration( name=name, nodes=nodes, placement=placement, + backends=backends, ) diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index e13e20853e..0ba685ea54 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -21,6 +21,7 @@ ScalingSpec, ServiceConfiguration, TaskConfiguration, + parse_run_configuration, ) from dstack._internal.core.models.fleets import FleetNodesSpec from dstack._internal.core.models.gateways import GatewayStatus @@ -66,6 +67,7 @@ create_run, create_user, get_auth_headers, + get_fleet_configuration, get_fleet_spec, get_instance_offer_with_availability, get_job_provisioning_data, @@ -1916,6 +1918,152 @@ async def test_returns_no_offers_if_imported_fleet_specified_without_project_pre assert response_json["project_name"] == "importer" assert len(response_json["job_plans"][0]["offers"]) == 0 + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + "configuration", + [ + pytest.param({"type": "dev-environment"}, id="regular-configuration"), + pytest.param( + {"type": "task", "commands": [":"], "image": "scratch"}, + id="special-configuration-used-by-dstack-offer-cli-command", + ), + pytest.param( + {"type": "task", "commands": [":"], "image": "scratch", "fleets": ["test-fleet"]}, + id="special-configuration-used-by-dstack-offer-cli-command-with-fleets", # --fleet + ), + ], + ) + async def test_preserves_backend_specific_offer_order( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + configuration: dict, + ) -> None: + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, + project=project, + user=user, + project_role=ProjectRole.USER, + ) + repo = await create_repo(session=session, project_id=project.id) + await create_fleet( + session=session, + project=project, + spec=get_fleet_spec(conf=get_fleet_configuration(name="test-fleet")), + ) + + run_spec = get_run_spec( + repo_id=repo.name, configuration=parse_run_configuration(configuration) + ) + body = {"run_spec": run_spec.dict()} + + backend_mock_aws = Mock() + backend_mock_aws.TYPE = BackendType.AWS + backend_mock_aws.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.AWS, price=1.0), + get_instance_offer_with_availability(backend=BackendType.AWS, price=4.0), + ] + backend_mock_vastai = Mock() + backend_mock_vastai.TYPE = BackendType.VASTAI + backend_mock_vastai.compute.return_value.get_offers.return_value = [ + # not ordered by price - custom order should be preserved + get_instance_offer_with_availability(backend=BackendType.VASTAI, price=3.0), + get_instance_offer_with_availability(backend=BackendType.VASTAI, price=2.0), + ] + + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock_aws, backend_mock_vastai] + response = await client.post( + f"/api/project/{project.name}/runs/get_plan", + headers=get_auth_headers(user.token), + json=body, + ) + + assert response.status_code == 200, response.json() + offers = [(o["backend"], o["price"]) for o in response.json()["job_plans"][0]["offers"]] + expected_offers = [ + (BackendType.AWS.value, 1.0), + (BackendType.VASTAI.value, 3.0), + (BackendType.VASTAI.value, 2.0), + (BackendType.AWS.value, 4.0), + ] + assert offers == expected_offers + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_offer_cli_preserves_backend_specific_offer_order_across_fleets( + self, test_db, session: AsyncSession, client: AsyncClient + ) -> None: + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, + project=project, + user=user, + project_role=ProjectRole.USER, + ) + repo = await create_repo(session=session, project_id=project.id) + await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration(name="fleet-aws", backends=[BackendType.AWS]) + ), + ) + await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration(name="fleet-vastai", backends=[BackendType.VASTAI]) + ), + ) + + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration( + commands=[":"], + image="scratch", + fleets=["fleet-aws", "fleet-vastai"], + ), + ) + body = {"run_spec": run_spec.dict()} + + backend_mock_aws = Mock() + backend_mock_aws.TYPE = BackendType.AWS + backend_mock_aws.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.AWS, price=1.0), + get_instance_offer_with_availability(backend=BackendType.AWS, price=4.0), + ] + backend_mock_vastai = Mock() + backend_mock_vastai.TYPE = BackendType.VASTAI + backend_mock_vastai.compute.return_value.get_offers.return_value = [ + # not ordered by price - custom order should be preserved + get_instance_offer_with_availability(backend=BackendType.VASTAI, price=3.0), + get_instance_offer_with_availability(backend=BackendType.VASTAI, price=2.0), + ] + + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock_aws, backend_mock_vastai] + response = await client.post( + f"/api/project/{project.name}/runs/get_plan", + headers=get_auth_headers(user.token), + json=body, + ) + + assert response.status_code == 200, response.json() + offers = [(o["backend"], o["price"]) for o in response.json()["job_plans"][0]["offers"]] + expected_offers = [ + (BackendType.AWS.value, 1.0), + (BackendType.VASTAI.value, 3.0), + (BackendType.VASTAI.value, 2.0), + (BackendType.AWS.value, 4.0), + ] + assert offers == expected_offers + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_offer_cli_returns_offers_from_all_specified_fleets(