diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py index 3cb53322a..184287b31 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -2,7 +2,7 @@ import datetime import logging from datetime import timedelta -from typing import Any, Dict, Optional, cast +from typing import Any, Dict, Optional import gpuhunt import requests @@ -86,8 +86,10 @@ get_instance_configuration, get_instance_profile, get_instance_provisioning_data, + get_instance_remote_connection_info, get_instance_requirements, get_instance_ssh_private_keys, + is_ssh_instance, remove_dangling_tasks_from_instance, switch_instance_status, ) @@ -244,7 +246,7 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel): instance = res.unique().scalar_one() if instance.status == InstanceStatus.PENDING: - if instance.remote_connection_info is not None: + if is_ssh_instance(instance): await _add_remote(session, instance) else: await _create_instance( @@ -323,7 +325,8 @@ async def _add_remote(session: AsyncSession, instance: InstanceModel) -> None: return try: - remote_details = RemoteConnectionInfo.parse_raw(cast(str, instance.remote_connection_info)) + remote_details = get_instance_remote_connection_info(instance) + assert remote_details is not None # Prepare connection key try: pkeys = _ssh_keys_to_pkeys(remote_details.ssh_keys) diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index bcb35a089..7275106ce 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -18,7 +18,6 @@ from dstack._internal.core.models.files import FileArchiveMapping from dstack._internal.core.models.instances import ( InstanceStatus, - RemoteConnectionInfo, SSHConnectionParams, ) from dstack._internal.core.models.metrics import Metric @@ -54,7 +53,10 @@ from dstack._internal.server.services import events, services from dstack._internal.server.services import files as files_services from dstack._internal.server.services import logs as logs_services -from dstack._internal.server.services.instances import get_instance_ssh_private_keys +from dstack._internal.server.services.instances import ( + get_instance_remote_connection_info, + get_instance_ssh_private_keys, +) from dstack._internal.server.services.jobs import ( find_job, get_job_attached_volumes, @@ -870,14 +872,11 @@ async def _maybe_register_replica( ssh_head_proxy: Optional[SSHConnectionParams] = None ssh_head_proxy_private_key: Optional[str] = None instance = common_utils.get_or_error(job_model.instance) - if instance.remote_connection_info is not None: - rci: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw( - instance.remote_connection_info - ) - if rci.ssh_proxy is not None: - ssh_head_proxy = rci.ssh_proxy - ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys) - ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private + rci = get_instance_remote_connection_info(instance) + if rci is not None and rci.ssh_proxy is not None: + ssh_head_proxy = rci.ssh_proxy + ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys) + ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private try: await services.register_replica( session, @@ -1090,9 +1089,8 @@ def _submit_job_to_runner( None if repo_credentials is None else repo_credentials.clone_url, ) instance = job_model.instance - if instance is not None and instance.remote_connection_info is not None: - remote_info = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info) - instance_env = remote_info.env + if instance is not None and (rci := get_instance_remote_connection_info(instance)) is not None: + instance_env = rci.env else: instance_env = None diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 9b877475f..8febbd126 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -2,7 +2,7 @@ from collections.abc import Callable from datetime import datetime from functools import wraps -from typing import List, Literal, Optional, Tuple, TypeVar, Union, cast +from typing import List, Literal, Optional, Tuple, TypeVar, Union from sqlalchemy import and_, func, or_, select from sqlalchemy.ext.asyncio import AsyncSession @@ -32,7 +32,6 @@ InstanceOfferWithAvailability, InstanceStatus, InstanceTerminationReason, - RemoteConnectionInfo, SSHConnectionParams, SSHKey, ) @@ -1106,9 +1105,8 @@ async def _check_ssh_hosts_not_yet_added( # ignore instances belonging to the same fleet -- in-place update/recreate if current_fleet_id is not None and instance.fleet_id == current_fleet_id: continue - instance_conn_info = RemoteConnectionInfo.parse_raw( - cast(str, instance.remote_connection_info) - ) + instance_conn_info = get_instance_remote_connection_info(instance) + assert instance_conn_info is not None existing_hosts.add(instance_conn_info.host) instances_already_in_fleet = [] diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index c311df7db..8506ad273 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -286,6 +286,10 @@ def get_instance_requirements(instance_model: InstanceModel) -> Requirements: return Requirements.__response__.parse_raw(instance_model.requirements) +def is_ssh_instance(instance_model: InstanceModel) -> bool: + return instance_model.remote_connection_info is not None + + def get_instance_remote_connection_info( instance_model: InstanceModel, ) -> Optional[RemoteConnectionInfo]: @@ -299,11 +303,11 @@ def get_instance_ssh_private_keys(instance_model: InstanceModel) -> tuple[str, O Returns a pair of SSH private keys: host key and optional proxy jump key. """ host_private_key = instance_model.project.ssh_private_key - if instance_model.remote_connection_info is None: + rci = get_instance_remote_connection_info(instance_model) + if rci is None: # Cloud instance return host_private_key, None # SSH instance - rci = RemoteConnectionInfo.__response__.parse_raw(instance_model.remote_connection_info) if rci.ssh_proxy is None: return host_private_key, None if rci.ssh_proxy_keys is None: diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py index f8c8d882c..7f1564fe6 100644 --- a/src/dstack/_internal/server/services/proxy/repo.py +++ b/src/dstack/_internal/server/services/proxy/repo.py @@ -9,7 +9,7 @@ from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import ServiceConfiguration -from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams +from dstack._internal.core.models.instances import SSHConnectionParams from dstack._internal.core.models.runs import ( JobProvisioningData, JobSpec, @@ -31,6 +31,7 @@ ) from dstack._internal.proxy.lib.repo import BaseProxyRepo from dstack._internal.server.models import JobModel, ProjectModel, RunModel +from dstack._internal.server.services.instances import get_instance_remote_connection_info from dstack._internal.server.settings import DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE from dstack._internal.utils.common import get_or_error @@ -97,11 +98,10 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic ssh_head_proxy: Optional[SSHConnectionParams] = None ssh_head_proxy_private_key: Optional[str] = None instance = get_or_error(job.instance) - if instance.remote_connection_info is not None: - rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info) - if rci.ssh_proxy is not None: - ssh_head_proxy = rci.ssh_proxy - ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private + rci = get_instance_remote_connection_info(instance) + if rci is not None and rci.ssh_proxy is not None: + ssh_head_proxy = rci.ssh_proxy + ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private job_spec: JobSpec = JobSpec.__response__.parse_raw(job.job_spec_data) replica = Replica( id=job.id.hex, diff --git a/src/dstack/_internal/server/services/ssh.py b/src/dstack/_internal/server/services/ssh.py index d1ba8ffc8..0fa7c189e 100644 --- a/src/dstack/_internal/server/services/ssh.py +++ b/src/dstack/_internal/server/services/ssh.py @@ -4,10 +4,11 @@ import dstack._internal.server.services.jobs as jobs_services from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams +from dstack._internal.core.models.instances import SSHConnectionParams from dstack._internal.core.models.runs import JobProvisioningData from dstack._internal.core.services.ssh.tunnel import SSH_DEFAULT_OPTIONS, SocketPair, SSHTunnel from dstack._internal.server.models import JobModel +from dstack._internal.server.services.instances import get_instance_remote_connection_info from dstack._internal.utils.common import get_or_error from dstack._internal.utils.path import FileContent @@ -46,11 +47,10 @@ def container_ssh_tunnel( ssh_head_proxy: Optional[SSHConnectionParams] = None ssh_head_proxy_private_key: Optional[str] = None instance = get_or_error(job.instance) - if instance.remote_connection_info is not None: - rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info) - if rci.ssh_proxy is not None: - ssh_head_proxy = rci.ssh_proxy - ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private + rci = get_instance_remote_connection_info(instance) + if rci is not None and rci.ssh_proxy is not None: + ssh_head_proxy = rci.ssh_proxy + ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private ssh_proxies = [] if ssh_head_proxy is not None: ssh_head_proxy_private_key = get_or_error(ssh_head_proxy_private_key)