Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions experimental/ssh/cmd/setup.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package ssh

import (
"fmt"
"time"

"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/experimental/ssh/internal/client"
"github.com/databricks/cli/experimental/ssh/internal/setup"
"github.com/databricks/cli/libs/cmdctx"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -57,17 +55,6 @@ an SSH host configuration to your SSH config file.
Profile: wsClient.Config.Profile,
AutoApprove: autoApprove,
}
clientOpts := client.ClientOptions{
ClusterID: setupOpts.ClusterID,
AutoStartCluster: setupOpts.AutoStartCluster,
ShutdownDelay: setupOpts.ShutdownDelay,
Profile: setupOpts.Profile,
}
proxyCommand, err := clientOpts.ToProxyCommand()
if err != nil {
return fmt.Errorf("failed to generate ProxyCommand: %w", err)
}
setupOpts.ProxyCommand = proxyCommand
return setup.Setup(ctx, wsClient, setupOpts)
}

Expand Down
28 changes: 22 additions & 6 deletions experimental/ssh/internal/setup/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"time"

sshclient "github.com/databricks/cli/experimental/ssh/internal/client"
"github.com/databricks/cli/experimental/ssh/internal/keys"
"github.com/databricks/cli/experimental/ssh/internal/sshconfig"
"github.com/databricks/cli/libs/cmdio"
Expand All @@ -28,8 +29,6 @@ type SetupOptions struct {
SSHKeysDir string
// Optional auth profile name. If present, will be added as --profile flag to the ProxyCommand
Profile string
// Proxy command to use for the SSH connection
ProxyCommand string
// Skip confirmation prompts (e.g. recreate existing host config without asking)
AutoApprove bool
}
Expand All @@ -45,17 +44,20 @@ func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClie
return nil
}

func generateHostConfig(ctx context.Context, opts SetupOptions) (string, error) {
func generateHostConfig(ctx context.Context, opts SetupOptions, proxyCommand string) (string, error) {
identityFilePath, err := keys.GetLocalSSHKeyPath(ctx, opts.ClusterID, opts.SSHKeysDir)
if err != nil {
return "", fmt.Errorf("failed to get local keys folder: %w", err)
}

hostConfig := sshconfig.GenerateHostConfig(opts.HostName, "root", identityFilePath, opts.ProxyCommand)
hostConfig := sshconfig.GenerateHostConfig(opts.HostName, "root", identityFilePath, proxyCommand)
return hostConfig, nil
}

func clusterSelectionPrompt(ctx context.Context, client *databricks.WorkspaceClient) (string, error) {
// clusterSelectionPrompt is a package-level var so tests can replace it with a mock.
var clusterSelectionPrompt = defaultClusterSelectionPrompt

func defaultClusterSelectionPrompt(ctx context.Context, client *databricks.WorkspaceClient) (string, error) {
sp := cmdio.NewSpinner(ctx)
sp.Update("Loading clusters.")
clusters, err := client.Clusters.ClusterDetailsClusterNameToClusterIdMap(ctx, compute.ListClustersRequest{
Expand Down Expand Up @@ -92,6 +94,20 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp
return err
}

// Build the ProxyCommand after the cluster ID is resolved. When the user
// omits --cluster, the ID is only known after the interactive picker above,
// so building it earlier would serialize an empty --cluster= flag.
clientOpts := sshclient.ClientOptions{
ClusterID: opts.ClusterID,
AutoStartCluster: opts.AutoStartCluster,
ShutdownDelay: opts.ShutdownDelay,
Profile: opts.Profile,
}
proxyCommand, err := clientOpts.ToProxyCommand()
if err != nil {
return fmt.Errorf("failed to generate ProxyCommand: %w", err)
}

configPath, err := sshconfig.GetMainConfigPathOrDefault(ctx, opts.SSHConfigPath)
if err != nil {
return err
Expand All @@ -102,7 +118,7 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp
return err
}

hostConfig, err := generateHostConfig(ctx, opts)
hostConfig, err := generateHostConfig(ctx, opts, proxyCommand)
if err != nil {
return err
}
Expand Down
86 changes: 51 additions & 35 deletions experimental/ssh/internal/setup/setup_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package setup

import (
"context"
"errors"
"fmt"
"os"
Expand All @@ -10,6 +11,7 @@ import (

"github.com/databricks/cli/experimental/ssh/internal/client"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -134,10 +136,9 @@ func TestGenerateHostConfig_Valid(t *testing.T) {
SSHKeysDir: tmpDir,
ShutdownDelay: 30 * time.Second,
Profile: "test-profile",
ProxyCommand: proxyCommand,
}

result, err := generateHostConfig(t.Context(), opts)
result, err := generateHostConfig(t.Context(), opts, proxyCommand)
assert.NoError(t, err)

assert.Contains(t, result, "Host test-host")
Expand Down Expand Up @@ -169,10 +170,9 @@ func TestGenerateHostConfig_WithoutProfile(t *testing.T) {
SSHKeysDir: tmpDir,
ShutdownDelay: 30 * time.Second,
Profile: "",
ProxyCommand: proxyCommand,
}

result, err := generateHostConfig(t.Context(), opts)
result, err := generateHostConfig(t.Context(), opts, proxyCommand)
assert.NoError(t, err)

assert.NotContains(t, result, "--profile=")
Expand All @@ -193,7 +193,7 @@ func TestGenerateHostConfig_PathEscaping(t *testing.T) {
ShutdownDelay: 30 * time.Second,
}

result, err := generateHostConfig(t.Context(), opts)
result, err := generateHostConfig(t.Context(), opts, "")
assert.NoError(t, err)

// Check that quotes are properly escaped
Expand Down Expand Up @@ -225,17 +225,7 @@ func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) {
Profile: "test-profile",
}

clientOpts := client.ClientOptions{
ClusterID: opts.ClusterID,
AutoStartCluster: opts.AutoStartCluster,
ShutdownDelay: opts.ShutdownDelay,
Profile: opts.Profile,
}
proxyCommand, err := clientOpts.ToProxyCommand()
require.NoError(t, err)
opts.ProxyCommand = proxyCommand

err = Setup(ctx, m.WorkspaceClient, opts)
err := Setup(ctx, m.WorkspaceClient, opts)
assert.NoError(t, err)

// Check that main config has Include directive
Expand Down Expand Up @@ -285,15 +275,7 @@ func TestSetup_AutoApproveRecreatesExistingHost(t *testing.T) {
AutoApprove: true,
}

clientOpts := client.ClientOptions{
ClusterID: opts.ClusterID,
ShutdownDelay: opts.ShutdownDelay,
}
proxyCommand, err := clientOpts.ToProxyCommand()
require.NoError(t, err)
opts.ProxyCommand = proxyCommand

err = Setup(ctx, m.WorkspaceClient, opts)
err := Setup(ctx, m.WorkspaceClient, opts)
assert.NoError(t, err)

// Host config should be recreated (no longer contains the stale User).
Expand All @@ -304,6 +286,50 @@ func TestSetup_AutoApproveRecreatesExistingHost(t *testing.T) {
assert.Contains(t, s, "--cluster=cluster-123")
}

func TestSetup_PromptsForClusterWhenNotProvided(t *testing.T) {
ctx := cmdio.MockDiscard(t.Context())
tmpDir := t.TempDir()
t.Setenv("HOME", tmpDir)
t.Setenv("USERPROFILE", tmpDir)

configPath := filepath.Join(tmpDir, "ssh_config")

// Replace the cluster picker with a stub returning a fixed ID. This lets the
// test exercise the empty-ClusterID path of Setup without driving promptui.
origPrompt := clusterSelectionPrompt
t.Cleanup(func() { clusterSelectionPrompt = origPrompt })
promptCalled := false
clusterSelectionPrompt = func(_ context.Context, _ *databricks.WorkspaceClient) (string, error) {
promptCalled = true
return "picked-cluster", nil
}

m := mocks.NewMockWorkspaceClient(t)
clustersAPI := m.GetMockClustersAPI()
clustersAPI.EXPECT().Get(ctx, compute.GetClusterRequest{ClusterId: "picked-cluster"}).Return(&compute.ClusterDetails{
DataSecurityMode: compute.DataSecurityModeSingleUser,
}, nil)

opts := SetupOptions{
HostName: "test-host",
SSHConfigPath: configPath,
SSHKeysDir: tmpDir,
ShutdownDelay: 30 * time.Second,
}

err := Setup(ctx, m.WorkspaceClient, opts)
require.NoError(t, err)
assert.True(t, promptCalled, "cluster picker should run when ClusterID is empty")

// The picked ID must be serialized into the ProxyCommand's --cluster= flag.
hostConfigPath := filepath.Join(tmpDir, ".databricks", "ssh-tunnel-configs", "test-host")
hostContent, err := os.ReadFile(hostConfigPath)
require.NoError(t, err)
hostConfigStr := string(hostContent)
assert.Contains(t, hostConfigStr, "--cluster=picked-cluster")
assert.NotContains(t, hostConfigStr, "--cluster= ")
}

func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) {
ctx := cmdio.MockDiscard(t.Context())
tmpDir := t.TempDir()
Expand Down Expand Up @@ -332,16 +358,6 @@ func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) {
ShutdownDelay: 60 * time.Second,
}

clientOpts := client.ClientOptions{
ClusterID: opts.ClusterID,
AutoStartCluster: opts.AutoStartCluster,
ShutdownDelay: opts.ShutdownDelay,
Profile: opts.Profile,
}
proxyCommand, err := clientOpts.ToProxyCommand()
require.NoError(t, err)
opts.ProxyCommand = proxyCommand

err = Setup(ctx, m.WorkspaceClient, opts)
assert.NoError(t, err)

Expand Down
Loading