Add pd disaggregated inference #3558
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Test2 Internal IP Test Add worker with internal_ip Check status and register Add Status Ready Log Add Prefill-Decode Add PD to dstack Test register worker without poll Add router config in service config Update remove worker Clean Up router code Clean Up Further Cleanup
            )
        ),
    ] = None
    router_config: Annotated[
(nit) router_config -> router for brevity and consistency with gateway configurations?
        Optional[AnyRouterConfig],
        Field(
            description=(
                "Router configuration for the service. Currently supports routing policy and pd_disaggregation. "
(nit) The supported properties (routing policy and pd_disaggregation) should already be visible in the AnyRouterConfig reference. Duplicating them here may lead to inconsistencies when properties are added or removed in the future.
            )
        ),
    ] = None
    router_config: Annotated[
Any new properties should be excluded from client requests for compatibility with older servers.
See get_run_spec_excludes.
            description="The routing policy. Options: `random`, `round_robin`, `cache_aware`, `power_of_two`"
        ),
    ] = "cache_aware"
    pd_disaggregation: Annotated[
Any new properties should be excluded from client requests for compatibility with older servers.
See get_run_spec_excludes and _get_gateway_configuration_excludes.
            description="The routing policy. Options: `random`, `round_robin`, `cache_aware`, `power_of_two`"
        ),
    ] = "cache_aware"
    pd_disaggregation: Annotated[
This adds pd_disaggregation to both service configurations and gateway configurations. I'd advocate for adding it to service configurations only.
- Whether or not the service is configured to use PD disaggregation is clearly a service property, because it depends on the replica groups configuration. I don't think many users would want to configure that property at the gateway level, making assumptions about what services are going to run on that gateway in the future.
- Having two places for the same property complicates the interface: you'd have to explain in the docs how these places are related to each other, when and how one setting overrides the other, etc.
- Having the property at the gateway level can potentially complicate further development: that way, you can only tell whether a service is using PD disaggregation if the service is already registered, and you need to fetch the GatewayModel object from the database to do so. For example, this would complicate adding the default probe for services with PD disaggregation.
Currently, it seems to be possible to run a service that defines router_config.type == "sglang" on a non-SGLang gateway or even without a gateway. I assume this either won't work (if the service requires PD disaggregation) or will just lead to unexpected behavior, since router_config will be ignored.
I'd suggest enforcing that services with router_config.type == "sglang" only run on SGLang gateways by raising a relevant exception in _register_service_in_server and _register_service_in_gateway.
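A minimal sketch of such a check, under stated assumptions: `RouterConfig`, `GatewayInfo`, `ServiceConfigurationError`, and `validate_router_compat` are hypothetical stand-ins invented for illustration, not dstack's actual types or functions. The real check would live in the two registration functions named above and use whatever exception type the codebase already has.

```python
from dataclasses import dataclass
from typing import Optional


class ServiceConfigurationError(Exception):
    """Hypothetical error type; the real exception class may differ."""


@dataclass
class RouterConfig:  # hypothetical stand-in for AnyRouterConfig
    type: str
    pd_disaggregation: bool = False


@dataclass
class GatewayInfo:  # hypothetical stand-in for the gateway's configuration
    router_type: Optional[str] = None


def validate_router_compat(
    router: Optional[RouterConfig], gateway: Optional[GatewayInfo]
) -> None:
    """Reject services that request an SGLang router without an SGLang gateway."""
    if router is None or router.type != "sglang":
        return  # nothing to enforce
    if gateway is None:
        raise ServiceConfigurationError(
            "Services with router type 'sglang' require a gateway"
        )
    if gateway.router_type != "sglang":
        raise ServiceConfigurationError(
            "Services with router type 'sglang' can only run on SGLang gateways"
        )
```

Failing at registration time surfaces the mismatch to the user immediately, instead of leaving a running service whose router settings are silently ignored.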
    async def add_worker_to_router(
        self,
        url: str,
        worker_type: str = "regular",
        bootstrap_port: Optional[int] = None,
    ) -> bool:
        """Add a worker to the router.

        Args:
            url: Worker URL (e.g. http://10.0.5.134:8000).
            worker_type: Type of worker ("regular", "prefill", or "decode").
            bootstrap_port: Bootstrap port for prefill workers (optional).

        Returns:
            True if the worker was accepted, False otherwise.
        """
        raise NotImplementedError

    async def register_worker(self, url: str) -> bool:
        """Register worker with one attempt (no polling). Returns True if ready and added."""
        raise NotImplementedError
(nit) These methods don't need to be in the Router base class: they are only ever called by SglangRouter.update_replicas and never by external callers. Consider removing them from Router and making them private in SglangRouter (prefixed with _).
             self._remove_worker_from_router(replica_url)

-    def update_replicas(self, replica_urls: List[str]) -> None:
+    async def update_replicas(self, replica_urls: List[str]) -> None:
The method is now async, but it still calls _get_router_workers and _remove_worker_from_router, which are synchronous and perform blocking I/O using httpx.Client. This can block the gateway's event loop for a long time, during which the gateway will be inoperable.
I'd suggest reverting the method to synchronous and calling it outside of the event loop using run_async, as was done before.
Alternatively, make all SglangRouter methods async and ensure there are no blocking calls (no httpx.Client, subprocess.Popen, etc.). That would be too many unrelated changes, though, so I'd suggest sticking to the synchronous interface in this PR.
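To illustrate the first option, here is a minimal sketch of keeping a blocking method synchronous and running it off the event loop. It uses the standard library's `asyncio.to_thread` as a stand-in for the project's `run_async` helper (an assumption: the two are presumed to serve the same purpose, but `run_async`'s exact signature is not shown here). `SyncRouter` and `update_replicas_off_loop` are hypothetical names, and `time.sleep` stands in for blocking `httpx.Client` calls.

```python
import asyncio
import time
from typing import List


class SyncRouter:
    """Hypothetical router whose methods perform blocking I/O."""

    def __init__(self) -> None:
        self.replicas: List[str] = []

    def update_replicas(self, replica_urls: List[str]) -> None:
        time.sleep(0.05)  # stands in for blocking httpx.Client calls
        self.replicas = list(replica_urls)


async def update_replicas_off_loop(router: SyncRouter, urls: List[str]) -> None:
    # asyncio.to_thread runs the blocking call in a worker thread, so the
    # event loop keeps serving other coroutines in the meantime.
    await asyncio.to_thread(router.update_replicas, urls)


async def demo() -> int:
    """Return how many times a concurrent task ran during the blocking call."""
    router = SyncRouter()
    ticks = 0

    async def heartbeat() -> None:
        nonlocal ticks
        while True:
            ticks += 1
            await asyncio.sleep(0.005)

    task = asyncio.create_task(heartbeat())
    await update_replicas_off_loop(router, ["http://10.0.0.1:8000"])
    task.cancel()
    return ticks
```

If `update_replicas` were called directly inside a coroutine instead, the heartbeat task would not run at all until the blocking call returned, which is the inoperable-gateway scenario described above.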
         current_workers = self._get_router_workers()
         worker_id = None
         for worker in current_workers:
             url = worker.get("url")
             if url and isinstance(url, str) and url == worker_url:
                 worker_id = worker.get("id")
                 if worker_id and isinstance(worker_id, str):
                     break
         if not worker_id:
             logger.exception("No worker id found for url %s", worker_url)
             return False
         with httpx.Client(timeout=5.0) as client:
             response = client.delete(
-                f"http://{self.context.host}:{self.context.port}/workers/{encoded_url}"
+                f"http://{self.context.host}:{self.context.port}/workers/{worker_id}"
What is this change for?
                 if worker_id and isinstance(worker_id, str):
                     break
         if not worker_id:
             logger.exception("No worker id found for url %s", worker_url)
(nit) logger.exception should only be called from an exception handler. You can use logger.error
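A small sketch of the difference, assuming a hypothetical standalone helper `find_worker_id` that mirrors the lookup loop in the hunk above (it is not the PR's actual method):

```python
import logging
from typing import Dict, List, Optional

logger = logging.getLogger("router_sketch")


def find_worker_id(
    workers: List[Dict[str, object]], worker_url: str
) -> Optional[str]:
    """Return the id of the worker registered under worker_url, if any."""
    for worker in workers:
        if worker.get("url") == worker_url:
            worker_id = worker.get("id")
            if isinstance(worker_id, str) and worker_id:
                return worker_id
    # No exception is being handled here, so logger.error is the right call.
    # logger.exception would attach exc_info for a non-existent exception and
    # append a meaningless "NoneType: None" line to the log record.
    logger.error("No worker id found for url %s", worker_url)
    return None
```

`logger.exception` remains appropriate inside an `except` block, where it records the traceback of the exception currently being handled.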
    replica_urls = [
        f"http://{replica.internal_ip}:{replica.port}"
        for replica in conf.replicas
        if replica.internal_ip
This ignores replicas that don't have an internal IP (most backends don't set the internal IP). I assume such replicas will be shown as running and registered in dstack, but the service won't work, which will be difficult to troubleshoot.
Raise ProxyError if pd_disaggregation is enabled and some replicas don't have an internal IP?
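One way this could look, as a sketch only: the `ProxyError` and `Replica` classes below are simplified stand-ins defined locally (the gateway's real `ProxyError` may take different arguments), and `build_replica_urls` is a hypothetical helper name.

```python
from dataclasses import dataclass
from typing import List, Optional


class ProxyError(Exception):
    """Stand-in for the gateway's ProxyError; the real class may differ."""


@dataclass
class Replica:  # hypothetical, reduced to the fields used here
    id: str
    port: int
    internal_ip: Optional[str] = None


def build_replica_urls(replicas: List[Replica], pd_disaggregation: bool) -> List[str]:
    """Build internal-IP URLs, failing loudly when PD disaggregation needs them."""
    missing = [r.id for r in replicas if not r.internal_ip]
    if pd_disaggregation and missing:
        raise ProxyError(
            "PD disaggregation requires an internal IP for every replica; "
            f"missing for: {', '.join(missing)}"
        )
    # Without PD disaggregation, keep the filtering behavior from the hunk above.
    return [f"http://{r.internal_ip}:{r.port}" for r in replicas if r.internal_ip]
```

Naming the affected replicas in the error message makes the failure easier to troubleshoot than a service that appears registered but never receives traffic.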
    server_info_url = f"{url}/server_info"
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(server_info_url)
(nit) Avoid this request and assume worker_type = "regular" if pd_disaggregation is disabled?
    if domain in self._domain_to_worker_urls:
        worker_urls = self._domain_to_worker_urls[domain]
        await run_async(router.remove_replicas, worker_urls)
        self._discard_ports(worker_urls)
No ports are allocated for services with pd_disaggregation, but _discard_ports is still called unconditionally. This can potentially deallocate ports that are actually allocated to some other unrelated service
    replica_conns, replica_failures = await get_or_add_replica_connections(
        service, repo, service_conn_pool
    )
(nit) This opens an SSH tunnel unconditionally, but for services with pd_disaggregation the gateway doesn't actually need an SSH tunnel, because it communicates with replicas directly over the internal IP.
Testing Steps
Create a CPU node in the K8s cluster.
Create a gateway on the CPU node using the config below.
Create a GPU node with 3 instances (1 prefill, 1 decode, and 1 for testing scaling) in the same K8s cluster as the gateway node.
Note: See the design doc for details on why the gateway and workers are required to be on the same network.
Apply the prefill-decode service configuration below.
With rps >= 3, the prefill replica scales to 2.
Note: For testing, you need to assign wheel to
https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl