diff --git a/extensions/vscode/src/api/types/credentials.ts b/extensions/vscode/src/api/types/credentials.ts index 107507682f..254e2c8e6f 100644 --- a/extensions/vscode/src/api/types/credentials.ts +++ b/extensions/vscode/src/api/types/credentials.ts @@ -27,5 +27,8 @@ export type TestResult = { user: CredentialUser | null; url: string | null; serverType: ServerType | null; + // When true, Snowflake connections are configured on the system and Token + // Authentication should be hidden (it won't work from within Snowflake). + hasSnowflakeConnections: boolean; error: AgentError | null; }; diff --git a/extensions/vscode/src/multiStepInputs/newConnectCredential.test.ts b/extensions/vscode/src/multiStepInputs/newConnectCredential.test.ts index a4502cc1f9..6175c525cd 100644 --- a/extensions/vscode/src/multiStepInputs/newConnectCredential.test.ts +++ b/extensions/vscode/src/multiStepInputs/newConnectCredential.test.ts @@ -2,7 +2,12 @@ import { describe, expect, test, vi, beforeEach, afterEach } from "vitest"; import { ServerType } from "src/api/types/contentRecords"; -import { newConnectCredential } from "./newConnectCredential"; +import { + newConnectCredential, + getAuthMethod, + AuthMethod, + AuthMethodName, +} from "./newConnectCredential"; // Mock the MultiStepInput module vi.mock("./multiStepHelper", () => { @@ -200,3 +205,35 @@ describe("newConnectCredential API calls", () => { ); }); }); + +describe("getAuthMethod", () => { + test("returns API_KEY for AuthMethodName.API_KEY", () => { + expect(getAuthMethod(AuthMethodName.API_KEY)).toBe(AuthMethod.API_KEY); + }); + + test("returns TOKEN for AuthMethodName.TOKEN", () => { + expect(getAuthMethod(AuthMethodName.TOKEN)).toBe(AuthMethod.TOKEN); + }); + + test("returns SNOWFLAKE_CONN for AuthMethodName.SNOWFLAKE_CONN", () => { + expect(getAuthMethod(AuthMethodName.SNOWFLAKE_CONN)).toBe( + AuthMethod.SNOWFLAKE_CONN, + ); + }); +}); + +describe("AuthMethod enum", () => { + test("has correct values", () => { + expect(AuthMethod.API_KEY).toBe("apiKey"); + expect(AuthMethod.TOKEN).toBe("token"); + expect(AuthMethod.SNOWFLAKE_CONN).toBe("snowflakeConnection"); + }); +}); + +describe("AuthMethodName enum", () => { + test("has correct display names", () => { + expect(AuthMethodName.API_KEY).toBe("API Key"); + expect(AuthMethodName.TOKEN).toBe("Token Authentication"); + expect(AuthMethodName.SNOWFLAKE_CONN).toBe("Snowflake Connection"); + }); +}); diff --git a/extensions/vscode/src/multiStepInputs/newConnectCredential.ts b/extensions/vscode/src/multiStepInputs/newConnectCredential.ts index d2d201ddb0..4e6a9d0c34 100644 --- a/extensions/vscode/src/multiStepInputs/newConnectCredential.ts +++ b/extensions/vscode/src/multiStepInputs/newConnectCredential.ts @@ -41,22 +41,26 @@ import { TokenAuthResult, } from "src/auth/ConnectAuthTokenActivator"; -enum AuthMethod { +export enum AuthMethod { API_KEY = "apiKey", TOKEN = "token", + SNOWFLAKE_CONN = "snowflakeConnection", } -enum AuthMethodName { +export enum AuthMethodName { API_KEY = "API Key", TOKEN = "Token Authentication", + SNOWFLAKE_CONN = "Snowflake Connection", } -const getAuthMethod = (authMethodName: AuthMethodName) => { +export const getAuthMethod = (authMethodName: AuthMethodName) => { switch (authMethodName) { case AuthMethodName.API_KEY: return AuthMethod.API_KEY; case AuthMethodName.TOKEN: return AuthMethod.TOKEN; + case AuthMethodName.SNOWFLAKE_CONN: + return AuthMethod.SNOWFLAKE_CONN; } }; @@ -76,6 +80,9 @@ export async function newConnectCredential( let serverType: ServerType = ServerType.CONNECT; const productName: ProductName = ProductName.CONNECT; let authMethod: AuthMethod = AuthMethod.TOKEN; + // When true, Snowflake connections are available on the system (we're inside Snowflake) + // and Token Authentication should be hidden (browser can't reach internal URLs). + let hasSnowflakeConnections: boolean = false; enum step { INPUT_SERVER_URL = "inputServerUrl", @@ -106,6 +113,10 @@ export async function newConnectCredential( return authMethod === AuthMethod.API_KEY; }; + const isSnowflakeConn = (authMethod: AuthMethod) => { + return authMethod === AuthMethod.SNOWFLAKE_CONN; + }; + const isValidTokenAuth = () => { // for token authentication, require token and privateKey return ( @@ -117,17 +128,17 @@ export async function newConnectCredential( }; const isValidApiKeyAuth = () => { - // for API key authentication, require apiKey - return ( - isConnect(serverType) && - isApiKey(authMethod) && - isString(state.data.apiKey) - ); + // for API key authentication, require apiKey (works for both Connect and Snowflake) + return isApiKey(authMethod) && isString(state.data.apiKey); }; const isValidSnowflakeAuth = () => { - // for Snowflake, require snowflakeConnection - return isSnowflake(serverType) && isString(state.data.snowflakeConnection); + // for Snowflake Connection authentication, require snowflakeConnection + return ( + isSnowflake(serverType) && + isSnowflakeConn(authMethod) && + isString(state.data.snowflakeConnection) + ); }; // *************************************************************** @@ -272,6 +283,8 @@ export async function newConnectCredential( // serverType will be overwritten if it is snowflake serverType = testResult.data.serverType; } + // Capture whether we're inside a Snowflake environment + hasSnowflakeConnections = testResult.data.hasSnowflakeConnections; } catch (e) { return Promise.resolve({ message: `Error: Invalid URL (unable to validate connectivity with Server URL - ${getMessageFromError(e)}).`, @@ -286,14 +299,6 @@ export async function newConnectCredential( state.data.url = formatURL(resp.trim()); - if (isSnowflake(serverType)) { - return { - name: step.INPUT_SNOWFLAKE_CONN, - step: (input: MultiStepInput) => - steps[step.INPUT_SNOWFLAKE_CONN](input, state), - }; - } - return { name: step.INPUT_AUTH_METHOD, step: (input: MultiStepInput) => @@ -302,19 +307,39 @@ export async function newConnectCredential( } // *************************************************************** - // Step: Select authentication method (Connect only) + // Step: Select authentication method + // For Connect (not in Snowflake): Token Authentication (Recommended) or API Key + // For Snowflake (detected by URL or by hasSnowflakeConnections): Snowflake Connection or API Key (no Token Auth) // *************************************************************** async function inputAuthMethod(input: MultiStepInput, state: MultiStepState) { - const authMethods = [ - { - label: AuthMethodName.TOKEN, - description: "Recommended - one click connection", - }, - { - label: AuthMethodName.API_KEY, - description: "Manually enter an API key", - }, - ]; + // Hide Token Auth when: + // - URL is detected as Snowflake (isSnowflake(serverType)) + // - OR Snowflake connections are available on the system (hasSnowflakeConnections) + // This handles the case where user enters internal URL like https://connect/ + const shouldHideTokenAuth = + isSnowflake(serverType) || hasSnowflakeConnections; + + const authMethods = shouldHideTokenAuth + ? [ + { + label: AuthMethodName.SNOWFLAKE_CONN, + description: "Use Snowflake connection for authentication", + }, + { + label: AuthMethodName.API_KEY, + description: "Manually enter an API key", + }, + ] + : [ + { + label: AuthMethodName.TOKEN, + description: "Recommended - one click connection", + }, + { + label: AuthMethodName.API_KEY, + description: "Manually enter an API key", + }, + ]; const pick = await input.showQuickPick({ title: state.title, @@ -322,7 +347,7 @@ export async function newConnectCredential( totalSteps: 0, placeholder: "Select authentication method", items: authMethods, - activeItem: authMethods[0], // Token authentication is default + activeItem: authMethods[0], buttons: [], shouldResume: () => Promise.resolve(false), ignoreFocusOut: true, @@ -330,6 +355,14 @@ export async function newConnectCredential( authMethod = getAuthMethod(pick.label as AuthMethodName); + if (isSnowflakeConn(authMethod)) { + return { + name: step.INPUT_SNOWFLAKE_CONN, + step: (input: MultiStepInput) => + steps[step.INPUT_SNOWFLAKE_CONN](input, state), + }; + } + if (isApiKey(authMethod)) { return { name: step.INPUT_API_KEY, diff --git a/extensions/vscode/src/utils/multiStepHelpers.test.ts b/extensions/vscode/src/utils/multiStepHelpers.test.ts new file mode 100644 index 0000000000..1f3b371bce --- /dev/null +++ b/extensions/vscode/src/utils/multiStepHelpers.test.ts @@ -0,0 +1,127 @@ +// Copyright (C) 2026 by Posit Software, PBC. + +import { describe, expect, test } from "vitest"; +import { + isConnect, + isConnectCloud, + isSnowflake, + isConnectProduct, + isConnectCloudProduct, + getProductType, + getProductName, + getServerType, +} from "./multiStepHelpers"; +import { + ServerType, + ProductType, + ProductName, +} from "../api/types/contentRecords"; + +describe("Server Type helpers", () => { + describe("isConnect", () => { + test("returns true for ServerType.CONNECT", () => { + expect(isConnect(ServerType.CONNECT)).toBe(true); + }); + + test("returns false for ServerType.SNOWFLAKE", () => { + expect(isConnect(ServerType.SNOWFLAKE)).toBe(false); + }); + + test("returns false for ServerType.CONNECT_CLOUD", () => { + expect(isConnect(ServerType.CONNECT_CLOUD)).toBe(false); + }); + }); + + describe("isSnowflake", () => { + test("returns true for ServerType.SNOWFLAKE", () => { + expect(isSnowflake(ServerType.SNOWFLAKE)).toBe(true); + }); + + test("returns false for ServerType.CONNECT", () => { + expect(isSnowflake(ServerType.CONNECT)).toBe(false); + }); + + test("returns false for ServerType.CONNECT_CLOUD", () => { + expect(isSnowflake(ServerType.CONNECT_CLOUD)).toBe(false); + }); + }); + + describe("isConnectCloud", () => { + test("returns true for ServerType.CONNECT_CLOUD", () => { + expect(isConnectCloud(ServerType.CONNECT_CLOUD)).toBe(true); + }); + + test("returns false for ServerType.CONNECT", () => { + expect(isConnectCloud(ServerType.CONNECT)).toBe(false); + }); + + test("returns false for ServerType.SNOWFLAKE", () => { + expect(isConnectCloud(ServerType.SNOWFLAKE)).toBe(false); + }); + }); +}); + +describe("Product Type helpers", () => { + describe("isConnectProduct", () => { + test("returns true for ProductType.CONNECT", () => { + expect(isConnectProduct(ProductType.CONNECT)).toBe(true); + }); + + test("returns false for ProductType.CONNECT_CLOUD", () => { + expect(isConnectProduct(ProductType.CONNECT_CLOUD)).toBe(false); + }); + }); + + describe("isConnectCloudProduct", () => { + test("returns true for ProductType.CONNECT_CLOUD", () => { + expect(isConnectCloudProduct(ProductType.CONNECT_CLOUD)).toBe(true); + }); + + test("returns false for ProductType.CONNECT", () => { + expect(isConnectCloudProduct(ProductType.CONNECT)).toBe(false); + }); + }); +}); + +describe("Type conversion helpers", () => { + describe("getProductType", () => { + test("returns ProductType.CONNECT for ServerType.CONNECT", () => { + expect(getProductType(ServerType.CONNECT)).toBe(ProductType.CONNECT); + }); + + test("returns ProductType.CONNECT for ServerType.SNOWFLAKE", () => { + // Snowflake is a Connect product (Connect running inside Snowflake) + expect(getProductType(ServerType.SNOWFLAKE)).toBe(ProductType.CONNECT); + }); + + test("returns ProductType.CONNECT_CLOUD for ServerType.CONNECT_CLOUD", () => { + expect(getProductType(ServerType.CONNECT_CLOUD)).toBe( + ProductType.CONNECT_CLOUD, + ); + }); + }); + + describe("getProductName", () => { + test("returns ProductName.CONNECT for ProductType.CONNECT", () => { + expect(getProductName(ProductType.CONNECT)).toBe(ProductName.CONNECT); + }); + + test("returns ProductName.CONNECT_CLOUD for ProductType.CONNECT_CLOUD", () => { + expect(getProductName(ProductType.CONNECT_CLOUD)).toBe( + ProductName.CONNECT_CLOUD, + ); + }); + }); + + describe("getServerType", () => { + test("returns ServerType.CONNECT for ProductName.CONNECT", () => { + expect(getServerType(ProductName.CONNECT)).toBe(ServerType.CONNECT); + }); + + test("returns ServerType.CONNECT_CLOUD for ProductName.CONNECT_CLOUD", () => { + expect(getServerType(ProductName.CONNECT_CLOUD)).toBe( + ServerType.CONNECT_CLOUD, + ); + }); + }); +}); diff --git a/internal/services/api/api_service.go b/internal/services/api/api_service.go index e4d510dd46..0b4c203337 100644 --- a/internal/services/api/api_service.go +++ b/internal/services/api/api_service.go @@ -109,7 +109,7 @@ func RouterHandlerFunc(base util.AbsolutePath, lister accounts.AccountList, log })).Methods(http.MethodDelete) // POST /api/test-credentials - r.Handle(ToPath("test-credentials"), PostTestCredentialsHandlerFunc(log)). + r.Handle(ToPath("test-credentials"), PostTestCredentialsHandlerFunc(log, snowflake.NewConnections())). Methods(http.MethodPost) // POST /api/connect/open-content diff --git a/internal/services/api/post_test_credentials.go b/internal/services/api/post_test_credentials.go index 369b1fa95b..0f45ff3049 100644 --- a/internal/services/api/post_test_credentials.go +++ b/internal/services/api/post_test_credentials.go @@ -4,13 +4,14 @@ package api import ( "encoding/json" - "github.com/posit-dev/publisher/internal/server_type" "net/http" "time" "github.com/posit-dev/publisher/internal/accounts" + "github.com/posit-dev/publisher/internal/api_client/auth/snowflake" "github.com/posit-dev/publisher/internal/clients/connect" "github.com/posit-dev/publisher/internal/logging" + "github.com/posit-dev/publisher/internal/server_type" "github.com/posit-dev/publisher/internal/types" "github.com/posit-dev/publisher/internal/util" ) @@ -28,12 +29,17 @@ type PostTestCredentialsResponseBody struct { ServerType server_type.ServerType `json:"serverType"` + // HasSnowflakeConnections indicates if Snowflake connections are configured + // on the system. When true, Token Authentication should be hidden since it + // won't work from within a Snowflake environment. + HasSnowflakeConnections bool `json:"hasSnowflakeConnections"` + Error *types.AgentError `json:"error"` } var connectClientFactory = connect.NewConnectClient -func PostTestCredentialsHandlerFunc(log logging.Logger) http.HandlerFunc { +func PostTestCredentialsHandlerFunc(log logging.Logger, connections snowflake.Connections) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { dec := json.NewDecoder(req.Body) dec.DisallowUnknownFields() @@ -51,6 +57,13 @@ func PostTestCredentialsHandlerFunc(log logging.Logger) http.HandlerFunc { return } + // Check if Snowflake connections are configured on the system. + // If so, Token Authentication won't work (browser can't reach internal URLs). + hasSnowflakeConnections := false + if conns, err := connections.List(); err == nil && len(conns) > 0 { + hasSnowflakeConnections = true + } + var user *connect.User var lastTestError error @@ -78,10 +91,11 @@ func PostTestCredentialsHandlerFunc(log logging.Logger) http.HandlerFunc { if err == nil { // If we succeeded, pass back what URL succeeded response := &PostTestCredentialsResponseBody{ - User: user, - Error: nil, - URL: discoveredURL, - ServerType: serverType, + User: user, + Error: nil, + URL: discoveredURL, + ServerType: serverType, + HasSnowflakeConnections: hasSnowflakeConnections, } w.Header().Set("content-type", "application/json") w.WriteHeader(http.StatusOK) @@ -91,10 +105,11 @@ func PostTestCredentialsHandlerFunc(log logging.Logger) http.HandlerFunc { // failure after all attempts, return last error response := &PostTestCredentialsResponseBody{ - User: user, - Error: types.AsAgentError(lastTestError), - URL: b.URL, // pass back original URL - ServerType: serverType, + User: user, + Error: types.AsAgentError(lastTestError), + URL: b.URL, // pass back original URL + ServerType: serverType, + HasSnowflakeConnections: hasSnowflakeConnections, } w.Header().Set("content-type", "application/json") w.WriteHeader(http.StatusOK) diff --git a/internal/services/api/post_test_credentials_test.go b/internal/services/api/post_test_credentials_test.go index 180a5a9ad3..1213e2901a 100644 --- a/internal/services/api/post_test_credentials_test.go +++ b/internal/services/api/post_test_credentials_test.go @@ -13,6 +13,7 @@ import ( "time" "github.com/posit-dev/publisher/internal/accounts" + "github.com/posit-dev/publisher/internal/api_client/auth/snowflake" "github.com/posit-dev/publisher/internal/clients/connect" "github.com/posit-dev/publisher/internal/events" "github.com/posit-dev/publisher/internal/logging" @@ -33,6 +34,22 @@ func (s *PostTestCredentialsHandlerSuite) SetupTest() { connectClientFactory = connect.NewConnectClient } +// mockNoSnowflakeConnections returns a mock that simulates no Snowflake connections +func mockNoSnowflakeConnections() *snowflake.MockConnections { + connections := &snowflake.MockConnections{} + connections.On("List").Return(map[string]*snowflake.Connection{}, nil) + return connections +} + +// mockWithSnowflakeConnections returns a mock that simulates Snowflake connections present +func mockWithSnowflakeConnections() *snowflake.MockConnections { + connections := &snowflake.MockConnections{} + connections.On("List").Return(map[string]*snowflake.Connection{ + "default": {}, + }, nil) + return connections +} + func (s *PostTestCredentialsHandlerSuite) TestPostTestCredentialsHandlerFunc() { log := logging.New() @@ -54,7 +71,7 @@ func (s *PostTestCredentialsHandlerSuite) TestPostTestCredentialsHandlerFunc() { connectClientFactory = func(account *accounts.Account, timeout time.Duration, emitter events.Emitter, log logging.Logger) (connect.APIClient, error) { return client, nil } - handler := PostTestCredentialsHandlerFunc(log) + handler := PostTestCredentialsHandlerFunc(log, mockNoSnowflakeConnections()) handler(rec, req) s.Equal(http.StatusOK, rec.Result().StatusCode) @@ -65,6 +82,7 @@ func (s *PostTestCredentialsHandlerSuite) TestPostTestCredentialsHandlerFunc() { s.Equal(user, response.User) s.Equal("https://connect.example.com", response.URL) s.Nil(response.Error) + s.False(response.HasSnowflakeConnections) } func (s *PostTestCredentialsHandlerSuite) TestPostTestCredentialsHandlerFuncWithConnectCopiedURL() { @@ -93,7 +111,7 @@ func (s *PostTestCredentialsHandlerSuite) TestPostTestCredentialsHandlerFuncWith connectClientFactory = func(account *accounts.Account, timeout time.Duration, emitter events.Emitter, log logging.Logger) (connect.APIClient, error) { return client, nil } - handler := PostTestCredentialsHandlerFunc(log) + handler := PostTestCredentialsHandlerFunc(log, mockNoSnowflakeConnections()) handler(rec, req) s.Equal(http.StatusOK, rec.Result().StatusCode) @@ -133,7 +151,7 @@ func (s *PostTestCredentialsHandlerSuite) TestPostTestCredentialsHandlerFuncWith connectClientFactory = func(account *accounts.Account, timeout time.Duration, emitter events.Emitter, log logging.Logger) (connect.APIClient, error) { return client, nil } - handler := PostTestCredentialsHandlerFunc(log) + handler := PostTestCredentialsHandlerFunc(log, mockNoSnowflakeConnections()) handler(rec, req) s.Equal(http.StatusOK, rec.Result().StatusCode) @@ -163,7 +181,7 @@ func (s *PostTestCredentialsHandlerSuite) TestPostTestCredentialsHandlerFuncNoAp connectClientFactory = func(account *accounts.Account, timeout time.Duration, emitter events.Emitter, log logging.Logger) (connect.APIClient, error) { return client, nil } - handler := PostTestCredentialsHandlerFunc(log) + handler := PostTestCredentialsHandlerFunc(log, mockNoSnowflakeConnections()) handler(rec, req) s.Equal(http.StatusOK, rec.Result().StatusCode) @@ -194,7 +212,7 @@ func (s *PostTestCredentialsHandlerSuite) TestPostTestCredentialsHandlerFuncBadA connectClientFactory = func(account *accounts.Account, timeout time.Duration, emitter events.Emitter, log logging.Logger) (connect.APIClient, error) { return client, nil } - handler := PostTestCredentialsHandlerFunc(log) + handler := PostTestCredentialsHandlerFunc(log, mockNoSnowflakeConnections()) handler(rec, req) s.Equal(http.StatusOK, rec.Result().StatusCode) @@ -206,3 +224,40 @@ func (s *PostTestCredentialsHandlerSuite) TestPostTestCredentialsHandlerFuncBadA s.NotNil(response.Error) s.Equal("Test error from TestAuthentication.", response.Error.Message) } + +func (s *PostTestCredentialsHandlerSuite) TestPostTestCredentialsHandlerFuncWithSnowflakeConnections() { + log := logging.New() + + rec := httptest.NewRecorder() + req, err := http.NewRequest("POST", "/api/test-credentials", nil) + s.NoError(err) + + // Test with internal URL (like https://connect/) - should still return hasSnowflakeConnections=true + req.Body = io.NopCloser(strings.NewReader( + `{ + "url": "https://connect.example.com", + "apiKey": "0123456789abcdef0123456789abcdef" + }`)) + + client := connect.NewMockClient() + user := &connect.User{ + Email: "user@example.com", + } + client.On("TestAuthentication", mock.Anything).Return(user, nil) + connectClientFactory = func(account *accounts.Account, timeout time.Duration, emitter events.Emitter, log logging.Logger) (connect.APIClient, error) { + return client, nil + } + // Use mock with Snowflake connections present + handler := PostTestCredentialsHandlerFunc(log, mockWithSnowflakeConnections()) + handler(rec, req) + + s.Equal(http.StatusOK, rec.Result().StatusCode) + + var response PostTestCredentialsResponseBody + err = json.Unmarshal(rec.Body.Bytes(), &response) + s.NoError(err) + s.Equal(user, response.User) + s.Equal("https://connect.example.com", response.URL) + s.Nil(response.Error) + s.True(response.HasSnowflakeConnections) +} diff --git a/test/e2e/tests/credentials.cy.js b/test/e2e/tests/credentials.cy.js index 3b019239ef..8eeb969611 100644 --- a/test/e2e/tests/credentials.cy.js +++ b/test/e2e/tests/credentials.cy.js @@ -47,6 +47,23 @@ describe("Credentials Section", () => { "Select authentication method", ); + // Verify auth method options for Connect server: + // - First option should be Token Authentication (Recommended) + // - Second option should be API Key + // Note: Snowflake endpoints show different options (Snowflake Connection + API Key, no Token Auth) + // but we cannot E2E test that without Snowflake infrastructure + cy.get(".quick-input-list .monaco-list-row").should("have.length", 2); + + cy.get(".quick-input-list .monaco-list-row") + .eq(0) + .should("contain.text", "Token Authentication") + .and("contain.text", "Recommended"); + + cy.get(".quick-input-list .monaco-list-row") + .eq(1) + .should("contain.text", "API Key"); + + // Select API Key option to continue with the flow cy.get(".quick-input-list .monaco-list-row").eq(1).click(); cy.get(".quick-input-message").should(