Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ __pycache__/
.DS_Store
.idea/
.vscode/
.lingma/


# C extensions
Expand Down
1 change: 0 additions & 1 deletion pai/api/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ def list_ecs_specs(
page_size=10,
sort_by="Gpu",
) -> PaginatedResult:

"""List EcsSpecs that DLC service provided."""

request = ListEcsSpecsRequest(
Expand Down
8 changes: 5 additions & 3 deletions pai/api/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,11 @@ def get_default_workspace(self) -> Dict[str, Any]:
def list_configs(self, workspace_id, config_keys: Union[List[str], str]) -> Dict:
"""List configs used in the Workspace."""
request = ListConfigsRequest(
config_keys=",".join(config_keys)
if isinstance(config_keys, (tuple, list))
else config_keys,
config_keys=(
",".join(config_keys)
if isinstance(config_keys, (tuple, list))
else config_keys
),
)

resp: ListConfigsResponseBody = self._do_request(
Expand Down
20 changes: 19 additions & 1 deletion pai/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,33 @@ def is_dataset_id(item: str) -> bool:
"""Return True if given input is a dataset ID.

Args:
item (str): user input dataset ID.
item (str): user input dataset ID or Dataset ID and dataset version, separated by a slash.

Examples:
>>> is_dataset_id('d-ybko3rap60c4gs9flc')
True
>>> is_dataset_id('d-ybko3rap60c4gs9flc/v1')
True
"""
return item.startswith("d-")


def is_nas_uri(uri: Union[str, bytes]) -> bool:
"""Determines whether the given uri is a NAS uri.

Args:
uri (Union[str, bytes]): A string in NAS URI schema:
nas://29**d-b12****446.cn-hangzhou.nas.aliyuncs.com/data/path/
nas://29****123-y**r.cn-hangzhou.extreme.nas.aliyuncs.com/data/path/


Returns:
bool: True if the given uri is a NAS uri, else False.

"""
return bool(uri and isinstance(uri, (str, bytes)) and str(uri).startswith("nas://"))


@lru_cache()
def is_domain_connectable(domain: str, port: int = 80, timeout: int = 1) -> bool:
"""Check if the domain is connectable."""
Expand Down
59 changes: 38 additions & 21 deletions pai/dsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

from .common.logging import get_logger
from .common.oss_utils import OssUriObj, is_oss_uri
from .common.utils import is_dataset_id
from .common.utils import is_dataset_id, is_nas_uri
from .libs.alibabacloud_pai_dsw20220101.models import (
GetInstanceRequest,
GetInstanceResponse,
GetInstanceResponseBody,
UpdateInstanceRequest,
Expand Down Expand Up @@ -176,7 +177,8 @@ def __init__(self, instance_id: str):
def _get_instance_info(self):
session = get_default_session()
resp: GetInstanceResponse = session._acs_dsw_client.get_instance(
self.instance_id
self.instance_id,
request=GetInstanceRequest(),
)
return resp.body

Expand All @@ -194,11 +196,10 @@ def default_dynamic_mount_path(self) -> Optional[str]:
Returns:
str: The default dynamic mount path of the DSW Instance.
"""
if (
not self._instance_info.dynamic_mount.enable
or not self._instance_info.dynamic_mount.mount_points
):
return
if not self._instance_info.dynamic_mount.enable:
return None
if not self._instance_info.dynamic_mount.mount_points:
return "/mnt/dynamic"
return self._instance_info.dynamic_mount.mount_points[0].root_path

def mount(
Expand All @@ -212,9 +213,9 @@ def mount(
Dynamic mount a data source to the DSW Instance.

Args:
source (str): The source to be mounted, can be a dataset id or an OSS uri.
source (str): The source to be mounted, can be a dataset id or OSS/NAS uri.
mount_point (str): Target mount point in the instance, if not specified, the
mount point be generate with given source under the default mount point.
mount point be generated with given source under the default mount point.
options (str): Options that apply to when mount a data source, can not be
specified with option_type.
option_type(str): Preset data source mount options, can not be specified with
Expand All @@ -233,12 +234,6 @@ def mount(
self.instance_id
)
)
if not self._instance_info.dynamic_mount.mount_points:
raise RuntimeError(
"No dynamic mount points found for the DSW instance: {}".format(
self.instance_id
)
)

sess = get_default_session()
default_root_path = self.default_dynamic_mount_path()
Expand All @@ -251,30 +246,48 @@ def mount(
_, dir_path, _ = obj.parse_object_key()
uri = f"oss://{obj.bucket_name}.{obj.endpoint}{dir_path}"
dataset_id = None
dataset_version = None
elif is_nas_uri(source):
uri = source
dataset_id = None
dataset_version = None
else:
dataset_id = source
uri = None
if "/" in dataset_id:
dataset_id, dataset_version = dataset_id.split("/", 1)
else:
dataset_version = "v1"

if not is_oss_uri(source) and not is_dataset_id(source):
raise ValueError("Source must be oss uri or dataset id")
if (
not is_oss_uri(source)
and not is_nas_uri(source)
and not is_dataset_id(source)
):
raise ValueError("Source must be oss uri or nas uri or dataset id")

if not mount_point:
if is_oss_uri(source):
obj = OssUriObj(source)
mount_point = f"{obj.bucket_name}/{obj.object_key}"
elif is_nas_uri(source):
raise ValueError("Mount point is required for nas url.")
else:
mount_point = source
if not posixpath.isabs(mount_point):
mount_point = posixpath.join(default_root_path, mount_point)

resp: GetInstanceResponse = sess._acs_dsw_client.get_instance(self.instance_id)
resp: GetInstanceResponse = sess._acs_dsw_client.get_instance(
self.instance_id, request=GetInstanceRequest()
)
datasets = [
UpdateInstanceRequestDatasets().from_map(ds.to_map())
for ds in resp.body.datasets
]
datasets.append(
UpdateInstanceRequestDatasets(
dataset_id=dataset_id,
dataset_version=dataset_version,
dynamic=True,
mount_path=mount_point,
option_type=option_type,
Expand All @@ -285,9 +298,10 @@ def mount(
request = UpdateInstanceRequest(
datasets=datasets,
)
sess._acs_dsw_client.update_instance(
update_resp = sess._acs_dsw_client.update_instance(
instance_id=self.instance_id, request=request
)
print("Mount succeed, request id: {}".format(update_resp.body.request_id))
return mount_point

def unmount(self, mount_point: str) -> None:
Expand All @@ -302,7 +316,9 @@ def unmount(self, mount_point: str) -> None:
"""
sess = get_default_session()

resp: GetInstanceResponse = sess._acs_dsw_client.get_instance(self.instance_id)
resp: GetInstanceResponse = sess._acs_dsw_client.get_instance(
self.instance_id, request=GetInstanceRequest()
)
datasets = [
UpdateInstanceRequestDatasets().from_map(ds.to_map())
for ds in resp.body.datasets
Expand All @@ -328,6 +344,7 @@ def unmount(self, mount_point: str) -> None:
)
if not request.datasets:
request.disassociate_datasets = True
sess._acs_dsw_client.update_instance(
update_resp = sess._acs_dsw_client.update_instance(
instance_id=self.instance_id, request=request
)
print("Unmount succeed, request id: {}".format(update_resp.body.request_id))
2 changes: 1 addition & 1 deletion pai/libs/alibabacloud_pai_dsw20220101/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.5.4'
__version__ = '1.5.11'
Loading