Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/workon.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ func runWorkOn(cmd *cobra.Command, target string, step string, execCmd []string)

credEnvVars := creds.EnvVars()

ptdRoot := helpers.GetTargetsConfigPath()

// Start proxy if needed (non-fatal)
proxyFile := path.Join(internal.DataDir(), "proxy.json")
stopProxy, err := kube.StartProxy(cmd.Context(), t, proxyFile)
Expand Down Expand Up @@ -173,6 +175,7 @@ func runWorkOn(cmd *cobra.Command, target string, step string, execCmd []string)
for k, v := range credEnvVars {
shellCommand.Env = append(shellCommand.Env, k+"="+v)
}
shellCommand.Env = append(shellCommand.Env, "PTD_ROOT="+ptdRoot)
if kubeconfigPath != "" {
shellCommand.Env = append(shellCommand.Env, "KUBECONFIG="+kubeconfigPath)
}
Expand Down Expand Up @@ -210,6 +213,7 @@ func runWorkOn(cmd *cobra.Command, target string, step string, execCmd []string)
for k, v := range credEnvVars {
shellCommand.Env = append(shellCommand.Env, k+"="+v)
}
shellCommand.Env = append(shellCommand.Env, "PTD_ROOT="+ptdRoot)
if kubeconfigPath != "" {
shellCommand.Env = append(shellCommand.Env, "KUBECONFIG="+kubeconfigPath)
}
Expand Down
43 changes: 24 additions & 19 deletions lib/azure/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"os"
"os/exec"
"os/signal"
"regexp"
"time"

"github.com/posit-dev/ptd/lib/helpers"
Expand All @@ -22,6 +21,7 @@ type ProxySession struct {
tunnelCommand *exec.Cmd
socksCommand *exec.Cmd
localPort string
sshKeyPath string // temp file for bastion SSH key, cleaned up on Stop

runningProxy *proxy.RunningProxy
isReused bool // indicates if the session is reused from an existing running proxy
Expand Down Expand Up @@ -88,16 +88,10 @@ func (p *ProxySession) Start(ctx context.Context) error {
return err
}

bastionName, err := p.target.BastionName(ctx)

if err != nil {
slog.Error("Error getting bastion name", "error", err)
}

jumpBoxId, err := p.target.JumpBoxId(ctx)

bastionInfo, err := p.target.BastionInfo(ctx)
if err != nil {
slog.Error("Error getting jump box ID", "error", err)
slog.Error("Error getting bastion info", "error", err)
return fmt.Errorf("failed to get bastion info: %w", err)
}

// Determine which resource group to use for the bastion tunnel
Expand All @@ -114,25 +108,32 @@ func (p *ProxySession) Start(ctx context.Context) error {
return fmt.Errorf("Resource Group name is empty, cannot continue.")
}

// HACK: at the moment, the ssh key is written to a path and named based on the bastion name.
// This is a temporary workaround to remove the "-host" suffix from the bastion name, since that isn't in the key name
r := regexp.MustCompile(`-host.*`)
bastionSshKeyName := r.ReplaceAllString(bastionName, "")
// Write the SSH private key from Pulumi state to a temp file (cleaned up in Stop)
sshKeyFile, err := os.CreateTemp("", "ptd-bastion-ssh-*")
if err != nil {
return fmt.Errorf("failed to create temp file for SSH key: %w", err)
}
if _, err := sshKeyFile.WriteString(bastionInfo.SSHPrivateKey); err != nil {
sshKeyFile.Close()
os.Remove(sshKeyFile.Name())
return fmt.Errorf("failed to write SSH key: %w", err)
}
sshKeyFile.Close()
p.sshKeyPath = sshKeyFile.Name()

// build the command to start the bastion tunnel, this will connect jumpbox:22 to localhost:22001 (enabling SSH connection via separate command)
p.tunnelCommand = exec.CommandContext(
ctx,
p.azCliPath,
"network", "bastion", "tunnel",
"--name", bastionName,
"--name", bastionInfo.Name,
"--resource-group", resourceGroupName,
"--target-resource-id", jumpBoxId,
"--target-resource-id", bastionInfo.JumpBoxID,
"--resource-port", "22",
"--port", "22001",
)

// build the command to start the SOCKS proxy via SSH, using the jumpbox tunnel from above
// ssh -ND 1080 ptd-admin@localhost -p 22001 -i ~/.ssh/bas-ptd-madrigal01-production-bastion
p.socksCommand = exec.CommandContext(
ctx,
"ssh",
Expand All @@ -141,7 +142,7 @@ func (p *ProxySession) Start(ctx context.Context) error {
"-p", "22001",
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-i", fmt.Sprintf("%s/.ssh/%s", os.Getenv("HOME"), bastionSshKeyName))
"-i", p.sshKeyPath)

// set the environment variables for the command
// add each az env var to command
Expand All @@ -150,7 +151,7 @@ func (p *ProxySession) Start(ctx context.Context) error {
p.socksCommand.Env = append(p.socksCommand.Env, fmt.Sprintf("%s=%s", k, v))
}

slog.Debug("Starting Azure bastion tunnel", "bastion_name", bastionName, "resource_group", resourceGroupName, "tunnel_port", "22001", "target_port", "22")
slog.Debug("Starting Azure bastion tunnel", "bastion_name", bastionInfo.Name, "resource_group", resourceGroupName, "tunnel_port", "22001", "target_port", "22")
if ctx.Value("verbose") != nil && ctx.Value("verbose").(bool) {
slog.Debug("Verbose turned on, attaching command output to stdout and stderr")
p.tunnelCommand.Stdout = os.Stdout
Expand Down Expand Up @@ -201,6 +202,10 @@ func (p *ProxySession) Start(ctx context.Context) error {
}

func (p *ProxySession) Stop() error {
if p.sshKeyPath != "" {
os.Remove(p.sshKeyPath)
}

if p.isReused {
slog.Debug("Proxy session was reused, not stopping", "target", p.target.Name(), "local_port", p.localPort)
return nil
Expand Down
68 changes: 27 additions & 41 deletions lib/azure/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,18 @@ func (t Target) fullPulumiEnvVars(ctx context.Context) (map[string]string, error
return creds.EnvVars(), nil
}

func (t Target) BastionName(ctx context.Context) (string, error) {
// BastionInfo holds the bastion connection details from the persistent stack.
type BastionInfo struct {
Name string
JumpBoxID string
SSHPrivateKey string
}

// BastionInfo retrieves bastion connection details from the persistent stack outputs.
func (t Target) BastionInfo(ctx context.Context) (*BastionInfo, error) {
envVars, err := t.fullPulumiEnvVars(ctx)
if err != nil {
return "", err
return nil, err
}

persistentStack, err := pulumi.NewPythonPulumiStack(
Expand All @@ -185,57 +193,35 @@ func (t Target) BastionName(ctx context.Context) (string, error) {
false,
)
if err != nil {
return "", err
return nil, err
}

persistentOutputs, err := persistentStack.Outputs(ctx)
outputs, err := persistentStack.Outputs(ctx)
if err != nil {
return "", err
}

if _, ok := persistentOutputs["bastion_name"]; !ok {
return "", fmt.Errorf("bastion_name output not found in persistent stack outputs")
return nil, err
}

bastionName := persistentOutputs["bastion_name"].Value.(string)
info := &BastionInfo{}

return bastionName, nil
}

func (t Target) JumpBoxId(ctx context.Context) (string, error) {
envVars, err := t.fullPulumiEnvVars(ctx)
if err != nil {
return "", err
}

persistentStack, err := pulumi.NewPythonPulumiStack(
ctx,
"azure",
"workload",
"persistent",
t.Name(),
t.Region(),
t.PulumiBackendUrl(),
t.PulumiSecretsProviderKey(),
envVars,
false,
)
if err != nil {
return "", err
if v, ok := outputs["bastion_name"]; ok {
info.Name = v.Value.(string)
} else {
return nil, fmt.Errorf("bastion_name output not found in persistent stack outputs")
}

persistentOutputs, err := persistentStack.Outputs(ctx)
if err != nil {
return "", err
if v, ok := outputs["bastion_jumpbox_id"]; ok {
info.JumpBoxID = v.Value.(string)
} else {
return nil, fmt.Errorf("bastion_jumpbox_id output not found in persistent stack outputs")
}

if _, ok := persistentOutputs["bastion_jumpbox_id"]; !ok {
return "", fmt.Errorf("bastion_jumpbox_id output not found in persistent stack outputs")
if v, ok := outputs["bastion_ssh_private_key"]; ok {
info.SSHPrivateKey = v.Value.(string)
} else {
return nil, fmt.Errorf("bastion_ssh_private_key output not found in persistent stack outputs")
}

jumpBoxId := persistentOutputs["bastion_jumpbox_id"].Value.(string)

return jumpBoxId, nil
return info, nil
}

// HashName returns an obfuscated name for the target that can be used as a unique identifier.
Expand Down
14 changes: 0 additions & 14 deletions python-pulumi/src/ptd/pulumi_resources/azure_bastion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pulumi
import pulumi_tls as tls
from pulumi_azure_native import compute, network
from pulumi_command import local


class AzureBastion(pulumi.ComponentResource):
Expand Down Expand Up @@ -50,19 +49,6 @@ def __init__(
algorithm="ED25519",
)

# write the private key to a file on the local machine
# this needs to be repeated by any engineer who wants to access the jumpbox
local.run_output(
command=pulumi.Output.format(
"FILE=~/.ssh/{1}; "
'if [ ! -f "$FILE" ]; then '
'echo \'{0}\' > "$FILE" && chmod 600 "$FILE"; '
'else echo "File $FILE already exists, skipping."; fi',
self.jumpbox_ssh_key.private_key_openssh,
name,
),
)

# Create a Public IP for Bastion
self.public_ip = network.PublicIPAddress(
f"{name}-pip",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
"app_gateway_subnet_id": self.app_gateway_subnet.id,
"bastion_name": self.bastion.bastion_host.name,
"bastion_jumpbox_id": self.bastion.jumpbox_host.id,
"bastion_ssh_private_key": self.bastion.jumpbox_ssh_key.private_key_openssh,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs to be explicitly marked as a secret if we're going to put it in the output?

This probably makes more sense to fetch from azure secret vault, right?

"mimir_password": self.mimir_password.result,
"private_subnet_name": self.private_subnet.name,
"private_subnet_cidr": self.private_subnet.address_prefix,
Expand Down
Loading