Skip to content

Commit 403d377

Browse files
authored
Merge branch 'main' into mnv-drafts-url-update
2 parents b178e31 + 7775740 commit 403d377

File tree

6 files changed

+213
-87
lines changed

6 files changed

+213
-87
lines changed

docs/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212
- Added support for the `--draft` option when deploying content,
1313
this allows to deploy a new bundle for the content without exposing
1414
it as a the activated one.
15+
- Improved support for Posit Connect deployments
16+
hosted in Snowpark Container Services.
1517

1618
### Fixed
1719

docs/overrides/partials/header.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
{% endif %}
8181
<div class="md-flex__cell md-flex__cell--shrink left-nav">
8282
<ul class="md-tabs__list">
83-
<li class="md-tabs__item"><a href="{{ base_url }}/changelog/" title="Release Notes" class="md-tabs__link md-source">Release Notes</a></li>
83+
<li class="md-tabs__item"><a href="{{ base_url }}/CHANGELOG/" title="Release Notes" class="md-tabs__link md-source">Release Notes</a></li>
8484
<li class="md-tabs__item"><a href="https://support.posit.co/hc/en-us" title="Posit Support" class="md-tabs__link md-source">Help</a></li>
8585
</ul>
8686
</div>

rsconnect/api.py

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
TaskStatusV1,
7777
UserRecord,
7878
)
79-
from .snowflake import generate_jwt, get_connection_parameters
79+
from .snowflake import generate_jwt, get_parameters
8080
from .timeouts import get_task_timeout, get_task_timeout_help_message
8181

8282
if TYPE_CHECKING:
@@ -260,40 +260,62 @@ def __init__(
260260
self.bootstrap_jwt = None
261261

262262
def token_endpoint(self) -> str:
263-
params = get_connection_parameters(self.snowflake_connection_name)
263+
params = get_parameters(self.snowflake_connection_name)
264264

265265
if params is None:
266266
raise RSConnectException("No Snowflake connection found.")
267267

268268
return "https://{}.snowflakecomputing.com/".format(params["account"])
269269

270-
def fmt_payload(self) -> str:
271-
params = get_connection_parameters(self.snowflake_connection_name)
270+
def fmt_payload(self):
271+
params = get_parameters(self.snowflake_connection_name)
272272

273273
if params is None:
274274
raise RSConnectException("No Snowflake connection found.")
275275

276-
spcs_url = urlparse(self.url)
277-
scope = "session:role:{} {}".format(params["role"], spcs_url.netloc)
278-
jwt = generate_jwt(self.snowflake_connection_name)
279-
grant_type = "urn:ietf:params:oauth:grant-type:jwt-bearer"
280-
281-
payload = {"scope": scope, "assertion": jwt, "grant_type": grant_type}
282-
payload = urlencode(payload)
283-
return payload
276+
authenticator = params.get("authenticator")
277+
if authenticator == "SNOWFLAKE_JWT":
278+
spcs_url = urlparse(self.url)
279+
scope = (
280+
"session:role:{} {}".format(params["role"], spcs_url.netloc) if params.get("role") else spcs_url.netloc
281+
)
282+
jwt = generate_jwt(self.snowflake_connection_name)
283+
grant_type = "urn:ietf:params:oauth:grant-type:jwt-bearer"
284+
285+
payload = {"scope": scope, "assertion": jwt, "grant_type": grant_type}
286+
payload = urlencode(payload)
287+
return {
288+
"body": payload,
289+
"headers": {"Content-Type": "application/x-www-form-urlencoded"},
290+
"path": "/oauth/token",
291+
}
292+
elif authenticator == "oauth":
293+
payload = {
294+
"data": {
295+
"AUTHENTICATOR": "OAUTH",
296+
"TOKEN": params["token"],
297+
}
298+
}
299+
return {
300+
"body": payload,
301+
"headers": {
302+
"Content-Type": "application/json",
303+
"Authorization": "Bearer %s" % params["token"],
304+
"X-Snowflake-Authorization-Token-Type": "OAUTH",
305+
},
306+
"path": "/session/v1/login-request",
307+
}
308+
else:
309+
raise NotImplementedError("Unsupported authenticator for SPCS Connect: %s" % authenticator)
284310

285311
def exchange_token(self) -> str:
286312
try:
287313
server = HTTPServer(url=self.token_endpoint())
288314
payload = self.fmt_payload()
289315

290316
response = server.request(
291-
method="POST",
292-
path="/oauth/token",
293-
body=payload,
294-
headers={"Content-Type": "application/x-www-form-urlencoded"},
317+
method="POST", **payload # type: ignore[arg-type] # fmt_payload returns a dict with body and headers
295318
)
296-
297319
response = cast(HTTPResponse, response)
298320

299321
# borrowed from AbstractRemoteServer.handle_bad_response
@@ -313,10 +335,24 @@ def exchange_token(self) -> str:
313335
if not response.response_body:
314336
raise RSConnectException("Token exchange returned empty response")
315337

316-
# Ensure we return a string
338+
# Ensure response body is decoded to string on the object
317339
if isinstance(response.response_body, bytes):
318-
return response.response_body.decode("utf-8")
319-
return response.response_body
340+
response.response_body = response.response_body.decode("utf-8")
341+
342+
# Try to parse as JSON first
343+
try:
344+
import json
345+
346+
json_data = json.loads(response.response_body)
347+
# If it's JSON, extract the token from data.token
348+
if isinstance(json_data, dict) and "data" in json_data and "token" in json_data["data"]:
349+
return json_data["data"]["token"]
350+
else:
351+
# JSON format doesn't match expected structure, return raw response
352+
return response.response_body
353+
except (json.JSONDecodeError, ValueError):
354+
# Not JSON, return the raw response body
355+
return response.response_body
320356

321357
except RSConnectException as e:
322358
raise RSConnectException(f"Failed to exchange Snowflake token: {str(e)}") from e

rsconnect/snowflake.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,27 +40,43 @@ def list_connections() -> List[Dict[str, Any]]:
4040
raise RSConnectException("Could not list snowflake connections.")
4141

4242

43-
def get_connection_parameters(name: Optional[str] = None) -> Optional[Dict[str, Any]]:
43+
def get_parameters(name: Optional[str] = None) -> Dict[str, Any]:
44+
"""Get Snowflake connection parameters.
45+
Args:
46+
name: The name of the connection to retrieve. If None, returns the default connection.
47+
48+
Returns:
49+
A dictionary of connection parameters.
50+
"""
51+
try:
52+
from snowflake.connector.config_manager import CONFIG_MANAGER
53+
except ImportError:
54+
raise RSConnectException("snowflake-cli is not installed.")
55+
try:
56+
connections = CONFIG_MANAGER["connections"]
57+
if not isinstance(connections, dict):
58+
raise TypeError("connections is not a dictionary")
59+
60+
if name is None:
61+
def_connection_name = CONFIG_MANAGER["default_connection_name"]
62+
if not isinstance(def_connection_name, str):
63+
raise TypeError("default_connection_name is not a string")
64+
params = connections[def_connection_name]
65+
else:
66+
params = connections[name]
4467

45-
connection_list = list_connections()
46-
# return parameters for default connection if configured
47-
# otherwise return named connection
68+
if not isinstance(params, dict):
69+
raise TypeError("connection parameters is not a dictionary")
4870

49-
if not connection_list:
50-
raise RSConnectException("No Snowflake connections found.")
71+
return {str(k): v for k, v in params.items()}
5172

52-
try:
53-
if not name:
54-
return next((x["parameters"] for x in connection_list if x.get("is_default")), None)
55-
else:
56-
return next((x["parameters"] for x in connection_list if x.get("connection_name") == name))
57-
except StopIteration:
58-
raise RSConnectException(f"No Snowflake connection found with name '{name}'.")
73+
except (KeyError, AttributeError) as e:
74+
raise RSConnectException(f"Could not get Snowflake connection: {e}")
5975

6076

6177
def generate_jwt(name: Optional[str] = None) -> str:
6278

63-
_ = get_connection_parameters(name)
79+
_ = get_parameters(name)
6480
connection_name = "" if name is None else name
6581

6682
try:

tests/test_api.py

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -526,41 +526,48 @@ def test_token_endpoint(self, mock_token_endpoint):
526526
endpoint = server.token_endpoint()
527527
assert endpoint == "https://example.snowflakecomputing.com/"
528528

529-
@patch("rsconnect.api.get_connection_parameters")
530-
def test_token_endpoint_with_account(self, mock_get_connection_parameters):
529+
@patch("rsconnect.api.get_parameters")
530+
def test_token_endpoint_with_account(self, mock_get_parameters):
531531
server = SPCSConnectServer("https://spcs.example.com", "example_connection")
532-
mock_get_connection_parameters.return_value = {"account": "test_account"}
532+
mock_get_parameters.return_value = {"account": "test_account"}
533533
endpoint = server.token_endpoint()
534534
assert endpoint == "https://test_account.snowflakecomputing.com/"
535-
mock_get_connection_parameters.assert_called_once_with("example_connection")
535+
mock_get_parameters.assert_called_once_with("example_connection")
536536

537-
@patch("rsconnect.api.get_connection_parameters")
538-
def test_token_endpoint_with_none_params(self, mock_get_connection_parameters):
537+
@patch("rsconnect.api.get_parameters")
538+
def test_token_endpoint_with_none_params(self, mock_get_parameters):
539539
server = SPCSConnectServer("https://spcs.example.com", "example_connection")
540-
mock_get_connection_parameters.return_value = None
540+
mock_get_parameters.return_value = None
541541
with pytest.raises(RSConnectException, match="No Snowflake connection found."):
542542
server.token_endpoint()
543543

544-
@patch("rsconnect.api.get_connection_parameters")
545-
def test_fmt_payload(self, mock_get_connection_parameters):
544+
@patch("rsconnect.api.get_parameters")
545+
def test_fmt_payload(self, mock_get_parameters):
546546
server = SPCSConnectServer("https://spcs.example.com", "example_connection")
547-
mock_get_connection_parameters.return_value = {"account": "test_account", "role": "test_role"}
547+
mock_get_parameters.return_value = {
548+
"account": "test_account",
549+
"role": "test_role",
550+
"authenticator": "SNOWFLAKE_JWT",
551+
}
548552

549553
with patch("rsconnect.api.generate_jwt") as mock_generate_jwt:
550554
mock_generate_jwt.return_value = "mocked_jwt"
551555
payload = server.fmt_payload()
552556

553-
assert "scope=session%3Arole%3Atest_role+spcs.example.com" in payload
554-
assert "assertion=mocked_jwt" in payload
555-
assert "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer" in payload
557+
assert (
558+
payload["body"]
559+
== "scope=session%3Arole%3Atest_role+spcs.example.com&assertion=mocked_jwt&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer" # noqa
560+
)
561+
assert payload["headers"] == {"Content-Type": "application/x-www-form-urlencoded"}
562+
assert payload["path"] == "/oauth/token"
556563

557-
mock_get_connection_parameters.assert_called_once_with("example_connection")
564+
mock_get_parameters.assert_called_once_with("example_connection")
558565
mock_generate_jwt.assert_called_once_with("example_connection")
559566

560-
@patch("rsconnect.api.get_connection_parameters")
561-
def test_fmt_payload_with_none_params(self, mock_get_connection_parameters):
567+
@patch("rsconnect.api.get_parameters")
568+
def test_fmt_payload_with_none_params(self, mock_get_parameters):
562569
server = SPCSConnectServer("https://spcs.example.com", "example_connection")
563-
mock_get_connection_parameters.return_value = None
570+
mock_get_parameters.return_value = None
564571
with pytest.raises(RSConnectException, match="No Snowflake connection found."):
565572
server.fmt_payload()
566573

@@ -579,7 +586,11 @@ def test_exchange_token_success(self, mock_fmt_payload, mock_token_endpoint, moc
579586

580587
# Mock the token endpoint and payload
581588
mock_token_endpoint.return_value = "https://example.snowflakecomputing.com/"
582-
mock_fmt_payload.return_value = "mocked_payload"
589+
mock_fmt_payload.return_value = {
590+
"body": "mocked_payload_body",
591+
"headers": {"Content-Type": "application/x-www-form-urlencoded"},
592+
"path": "/oauth/token",
593+
}
583594

584595
# Call the method
585596
result = server.exchange_token()
@@ -589,9 +600,9 @@ def test_exchange_token_success(self, mock_fmt_payload, mock_token_endpoint, moc
589600
mock_http_server.assert_called_once_with(url="https://example.snowflakecomputing.com/")
590601
mock_server_instance.request.assert_called_once_with(
591602
method="POST",
592-
path="/oauth/token",
593-
body="mocked_payload",
603+
body="mocked_payload_body",
594604
headers={"Content-Type": "application/x-www-form-urlencoded"},
605+
path="/oauth/token",
595606
)
596607

597608
@patch("rsconnect.api.HTTPServer")
@@ -610,7 +621,11 @@ def test_exchange_token_error_status(self, mock_fmt_payload, mock_token_endpoint
610621

611622
# Mock the token endpoint and payload
612623
mock_token_endpoint.return_value = "https://example.snowflakecomputing.com/"
613-
mock_fmt_payload.return_value = "mocked_payload"
624+
mock_fmt_payload.return_value = {
625+
"body": "mocked_payload_body",
626+
"headers": {"Content-Type": "application/x-www-form-urlencoded"},
627+
"path": "/oauth/token",
628+
}
614629

615630
# Call the method and verify it raises the expected exception
616631
with pytest.raises(RSConnectException, match="Failed to exchange Snowflake token"):
@@ -631,7 +646,11 @@ def test_exchange_token_empty_response(self, mock_fmt_payload, mock_token_endpoi
631646

632647
# Mock the token endpoint and payload
633648
mock_token_endpoint.return_value = "https://example.snowflakecomputing.com/"
634-
mock_fmt_payload.return_value = "mocked_payload"
649+
mock_fmt_payload.return_value = {
650+
"body": "mocked_payload_body",
651+
"headers": {"Content-Type": "application/x-www-form-urlencoded"},
652+
"path": "/oauth/token",
653+
}
635654

636655
# Call the method and verify it raises the expected exception
637656
with pytest.raises(

0 commit comments

Comments
 (0)