diff --git a/tosfs/core.py b/tosfs/core.py index 9a7ac25..f36bb83 100644 --- a/tosfs/core.py +++ b/tosfs/core.py @@ -227,6 +227,7 @@ def __init__( Additional arguments. """ + self.endpoint = endpoint if endpoint_url is not None: warnings.warn( "The 'endpoint_url' parameter is deprecated and will be removed" @@ -234,7 +235,7 @@ def __init__( DeprecationWarning, stacklevel=2, ) - endpoint = endpoint_url + self.endpoint = endpoint_url self.tos_client = tos.TosClientV2( key, @@ -2268,7 +2269,7 @@ def _split_path(self, path: str) -> Tuple[str, str, Optional[str]]: key, _, version_id = keypart.partition("?versionId=") if self.tag_enabled: - self.bucket_tag_mgr.add_bucket_tag(bucket) + self.bucket_tag_mgr.add_bucket_tag(bucket, self.endpoint) return ( bucket, diff --git a/tosfs/tag.py b/tosfs/tag.py index 051e294..c5a5446 100644 --- a/tosfs/tag.py +++ b/tosfs/tag.py @@ -163,9 +163,12 @@ def __init__( secret: Optional[str], session_token: Optional[str], region: str, + endpoint: Optional[str] = None, ) -> None: """Init BucketTagAction.""" - super().__init__(self.get_service_info(region), self.get_api_info()) + self._region = region + self._endpoint = endpoint + super().__init__(self.get_service_info(), self.get_api_info()) if key: self.set_ak(key) @@ -180,17 +183,38 @@ def get_api_info() -> dict: """Get api info.""" return api_info - @staticmethod - def get_service_info(region: str) -> ServiceInfo: + def get_service_info(self) -> ServiceInfo: """Get service info.""" + region = self._region service_info = service_info_map.get(region) - if service_info: - from tosfs.core import logger + if not service_info: + raise Exception(f"Do not support region: {region}") - logger.debug(f"The service name is : {service_info.credentials.service}") - return service_info + from tosfs.core import logger - raise Exception(f"Do not support region: {region}") + credentials = service_info.credentials + logger.debug(f"The service name is : {credentials.service}") + + endpoint = self._endpoint or "" + host = OPEN_API_HOST + scheme = "https" + + if endpoint: + if endpoint.endswith(("ivolces.com", "volces.com")): + host = f"emr.{region}.volcengineapi.com" + elif endpoint.endswith(("ibytepluses.com", "bytepluses.com")): + host = f"emr.{region}.byteplusapi.com" + elif endpoint.endswith(("ibytepluses.com.cn", "byteplus.com.cn")): + host = f"emr.{region}.byteplusapi.com.cn" + + return ServiceInfo( + host, + {ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE}, + credentials, + CONNECTION_TIMEOUT_DEFAULT_SECONDS, + SOCKET_TIMEOUT_DEFAULT_SECONDS, + scheme, + ) def put_bucket_tag(self, bucket: str) -> tuple[str, bool]: """Put tag for bucket.""" @@ -269,7 +293,7 @@ def __init__( self.key, self.secret, self.session_token, actual_region ) - def add_bucket_tag(self, bucket: str) -> None: + def add_bucket_tag(self, bucket: str, endpoint: Optional[str] = None) -> None: """Add tag for bucket.""" from tosfs.core import logger @@ -292,9 +316,18 @@ def add_bucket_tag(self, bucket: str) -> None: need_tag_buckets = collect_bucket_set - self.cached_bucket_set logger.debug(f"Need to tag buckets : {need_tag_buckets}") - for res in self.executor.map( - self.bucket_tag_service.put_bucket_tag, need_tag_buckets - ): + def _add_bucket_tag(target_bucket: str) -> tuple[str, bool]: + """Create a BucketTagAction with the endpoint for this call.""" + action = BucketTagAction( + self.key, + self.secret, + self.session_token, + getattr(self, "actual_region", self.region), + endpoint, + ) + return action.put_bucket_tag(target_bucket) + + for res in self.executor.map(_add_bucket_tag, need_tag_buckets): if res[1]: self.cached_bucket_set.add(res[0]) diff --git a/tosfs/tests/test_tag.py b/tosfs/tests/test_tag.py index 7230932..2071912 100644 --- a/tosfs/tests/test_tag.py +++ b/tosfs/tests/test_tag.py @@ -36,7 +36,7 @@ def test_bucket_tag_action(tosfs, bucket, temporary_workspace): return tag_mgr.cached_bucket_set = set() - tag_mgr.add_bucket_tag(bucket) + tag_mgr.add_bucket_tag(bucket, "https://example.com/tag-service") sleep(10) assert os.path.exists(TAGGED_BUCKETS_FILE) with open(TAGGED_BUCKETS_FILE, "r") as f: