Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import datetime
import logging
from datetime import timedelta
from typing import Any, Dict, Optional, cast
from typing import Any, Dict, Optional

import gpuhunt
import requests
Expand Down Expand Up @@ -86,8 +86,10 @@
get_instance_configuration,
get_instance_profile,
get_instance_provisioning_data,
get_instance_remote_connection_info,
get_instance_requirements,
get_instance_ssh_private_keys,
is_ssh_instance,
remove_dangling_tasks_from_instance,
switch_instance_status,
)
Expand Down Expand Up @@ -244,7 +246,7 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
instance = res.unique().scalar_one()

if instance.status == InstanceStatus.PENDING:
if instance.remote_connection_info is not None:
if is_ssh_instance(instance):
await _add_remote(session, instance)
else:
await _create_instance(
Expand Down Expand Up @@ -323,7 +325,8 @@ async def _add_remote(session: AsyncSession, instance: InstanceModel) -> None:
return

try:
remote_details = RemoteConnectionInfo.parse_raw(cast(str, instance.remote_connection_info))
remote_details = get_instance_remote_connection_info(instance)
assert remote_details is not None
# Prepare connection key
try:
pkeys = _ssh_keys_to_pkeys(remote_details.ssh_keys)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from dstack._internal.core.models.files import FileArchiveMapping
from dstack._internal.core.models.instances import (
InstanceStatus,
RemoteConnectionInfo,
SSHConnectionParams,
)
from dstack._internal.core.models.metrics import Metric
Expand Down Expand Up @@ -54,7 +53,10 @@
from dstack._internal.server.services import events, services
from dstack._internal.server.services import files as files_services
from dstack._internal.server.services import logs as logs_services
from dstack._internal.server.services.instances import get_instance_ssh_private_keys
from dstack._internal.server.services.instances import (
get_instance_remote_connection_info,
get_instance_ssh_private_keys,
)
from dstack._internal.server.services.jobs import (
find_job,
get_job_attached_volumes,
Expand Down Expand Up @@ -870,14 +872,11 @@ async def _maybe_register_replica(
ssh_head_proxy: Optional[SSHConnectionParams] = None
ssh_head_proxy_private_key: Optional[str] = None
instance = common_utils.get_or_error(job_model.instance)
if instance.remote_connection_info is not None:
rci: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw(
instance.remote_connection_info
)
if rci.ssh_proxy is not None:
ssh_head_proxy = rci.ssh_proxy
ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys)
ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private
rci = get_instance_remote_connection_info(instance)
if rci is not None and rci.ssh_proxy is not None:
ssh_head_proxy = rci.ssh_proxy
ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys)
ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private
try:
await services.register_replica(
session,
Expand Down Expand Up @@ -1090,9 +1089,8 @@ def _submit_job_to_runner(
None if repo_credentials is None else repo_credentials.clone_url,
)
instance = job_model.instance
if instance is not None and instance.remote_connection_info is not None:
remote_info = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
instance_env = remote_info.env
if instance is not None and (rci := get_instance_remote_connection_info(instance)) is not None:
instance_env = rci.env
else:
instance_env = None

Expand Down
8 changes: 3 additions & 5 deletions src/dstack/_internal/server/services/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Callable
from datetime import datetime
from functools import wraps
from typing import List, Literal, Optional, Tuple, TypeVar, Union, cast
from typing import List, Literal, Optional, Tuple, TypeVar, Union

from sqlalchemy import and_, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -32,7 +32,6 @@
InstanceOfferWithAvailability,
InstanceStatus,
InstanceTerminationReason,
RemoteConnectionInfo,
SSHConnectionParams,
SSHKey,
)
Expand Down Expand Up @@ -1106,9 +1105,8 @@ async def _check_ssh_hosts_not_yet_added(
# ignore instances belonging to the same fleet -- in-place update/recreate
if current_fleet_id is not None and instance.fleet_id == current_fleet_id:
continue
instance_conn_info = RemoteConnectionInfo.parse_raw(
cast(str, instance.remote_connection_info)
)
instance_conn_info = get_instance_remote_connection_info(instance)
assert instance_conn_info is not None
existing_hosts.add(instance_conn_info.host)

instances_already_in_fleet = []
Expand Down
8 changes: 6 additions & 2 deletions src/dstack/_internal/server/services/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,10 @@ def get_instance_requirements(instance_model: InstanceModel) -> Requirements:
return Requirements.__response__.parse_raw(instance_model.requirements)


def is_ssh_instance(instance_model: InstanceModel) -> bool:
    """Report whether the instance has remote connection info attached.

    Args:
        instance_model: The instance record to inspect.

    Returns:
        ``True`` when ``instance_model.remote_connection_info`` is set to
        anything other than ``None`` (including an empty value), ``False``
        when it is ``None``.
    """
    connection_info = instance_model.remote_connection_info
    # Only the presence of the field matters here; its content is not parsed.
    if connection_info is None:
        return False
    return True


def get_instance_remote_connection_info(
instance_model: InstanceModel,
) -> Optional[RemoteConnectionInfo]:
Expand All @@ -299,11 +303,11 @@ def get_instance_ssh_private_keys(instance_model: InstanceModel) -> tuple[str, O
Returns a pair of SSH private keys: host key and optional proxy jump key.
"""
host_private_key = instance_model.project.ssh_private_key
if instance_model.remote_connection_info is None:
rci = get_instance_remote_connection_info(instance_model)
if rci is None:
# Cloud instance
return host_private_key, None
# SSH instance
rci = RemoteConnectionInfo.__response__.parse_raw(instance_model.remote_connection_info)
if rci.ssh_proxy is None:
return host_private_key, None
if rci.ssh_proxy_keys is None:
Expand Down
12 changes: 6 additions & 6 deletions src/dstack/_internal/server/services/proxy/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.configurations import ServiceConfiguration
from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams
from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.models.runs import (
JobProvisioningData,
JobSpec,
Expand All @@ -31,6 +31,7 @@
)
from dstack._internal.proxy.lib.repo import BaseProxyRepo
from dstack._internal.server.models import JobModel, ProjectModel, RunModel
from dstack._internal.server.services.instances import get_instance_remote_connection_info
from dstack._internal.server.settings import DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE
from dstack._internal.utils.common import get_or_error

Expand Down Expand Up @@ -97,11 +98,10 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
ssh_head_proxy: Optional[SSHConnectionParams] = None
ssh_head_proxy_private_key: Optional[str] = None
instance = get_or_error(job.instance)
if instance.remote_connection_info is not None:
rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
if rci.ssh_proxy is not None:
ssh_head_proxy = rci.ssh_proxy
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
rci = get_instance_remote_connection_info(instance)
if rci is not None and rci.ssh_proxy is not None:
ssh_head_proxy = rci.ssh_proxy
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
job_spec: JobSpec = JobSpec.__response__.parse_raw(job.job_spec_data)
replica = Replica(
id=job.id.hex,
Expand Down
12 changes: 6 additions & 6 deletions src/dstack/_internal/server/services/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import dstack._internal.server.services.jobs as jobs_services
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams
from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.models.runs import JobProvisioningData
from dstack._internal.core.services.ssh.tunnel import SSH_DEFAULT_OPTIONS, SocketPair, SSHTunnel
from dstack._internal.server.models import JobModel
from dstack._internal.server.services.instances import get_instance_remote_connection_info
from dstack._internal.utils.common import get_or_error
from dstack._internal.utils.path import FileContent

Expand Down Expand Up @@ -46,11 +47,10 @@ def container_ssh_tunnel(
ssh_head_proxy: Optional[SSHConnectionParams] = None
ssh_head_proxy_private_key: Optional[str] = None
instance = get_or_error(job.instance)
if instance.remote_connection_info is not None:
rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
if rci.ssh_proxy is not None:
ssh_head_proxy = rci.ssh_proxy
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
rci = get_instance_remote_connection_info(instance)
if rci is not None and rci.ssh_proxy is not None:
ssh_head_proxy = rci.ssh_proxy
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
ssh_proxies = []
if ssh_head_proxy is not None:
ssh_head_proxy_private_key = get_or_error(ssh_head_proxy_private_key)
Expand Down