From 070d844aae63c5bb4b662bfef4c6429fb537e3f2 Mon Sep 17 00:00:00 2001 From: vinoyang Date: Wed, 9 Oct 2024 11:51:33 +0800 Subject: [PATCH] Support tag via session token (#173) --- tosfs/core.py | 19 +++++++++++++------ tosfs/tag.py | 33 +++++++++++++++++++++++++++------ 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/tosfs/core.py b/tosfs/core.py index d7dd728..4c8e865 100644 --- a/tosfs/core.py +++ b/tosfs/core.py @@ -105,7 +105,7 @@ def __init__( key: str = "", secret: str = "", region: Optional[str] = None, - security_token: Optional[str] = None, + session_token: Optional[str] = None, max_retry_num: int = 20, max_connections: int = 1024, connection_timeout: int = 10, @@ -144,8 +144,8 @@ def __init__( The secret access key(sk) to access the TOS service. region : str, optional The region of the TOS service. - security_token : str, optional - The temporary security token to access the TOS service. + session_token : str, optional + The temporary session token to access the TOS service. max_retry_num : int, optional The maximum number of retries for a failed request (default is 20). max_connections : int, optional @@ -231,7 +231,7 @@ def __init__( secret, endpoint_url, region, - security_token=security_token, + security_token=session_token, max_retry_count=0, max_connections=max_connections, connection_time=connection_timeout, @@ -2012,10 +2012,17 @@ def _init_tag_manager(self) -> None: if isinstance(auth, CredentialProviderAuth): credentials = auth.credentials_provider.get_credentials() self.bucket_tag_mgr = BucketTagMgr( - credentials.get_ak(), credentials.get_sk(), auth.region + credentials.get_ak(), + credentials.get_sk(), + credentials.get_security_token(), + auth.region, ) else: - raise TosfsError("Currently only support CredentialProviderAuth type") + raise TosfsError( + "Currently only support CredentialProviderAuth type, " + "please check if you set (ak & sk) or (session token) " + "correctly." + ) @staticmethod def _fill_dir_info( diff --git a/tosfs/tag.py b/tosfs/tag.py index 27b4f98..4368e08 100644 --- a/tosfs/tag.py +++ b/tosfs/tag.py @@ -21,7 +21,7 @@ import os import threading from concurrent.futures import ThreadPoolExecutor -from typing import Any +from typing import Any, Optional from volcengine.ApiInfo import ApiInfo from volcengine.base.Service import Service @@ -125,11 +125,23 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any: BucketTagAction._instance = object.__new__(cls) return BucketTagAction._instance - def __init__(self, key: str, secret: str, region: str) -> None: + def __init__( + self, + key: Optional[str], + secret: Optional[str], + session_token: Optional[str], + region: str, + ) -> None: """Init BucketTagAction.""" super().__init__(self.get_service_info(region), self.get_api_info()) - self.set_ak(key) - self.set_sk(secret) + if key: + self.set_ak(key) + + if secret: + self.set_sk(secret) + + if session_token: + self.set_session_token(session_token) @staticmethod def get_api_info() -> dict: @@ -192,14 +204,23 @@ def get_instance(*args: Any, **kwargs: Any) -> Any: class BucketTagMgr: """BucketTagMgr is a class to manage the tag of bucket.""" - def __init__(self, key: str, secret: str, region: str): + def __init__( + self, + key: Optional[str], + secret: Optional[str], + session_token: Optional[str], + region: str, + ): """Init BucketTagMgr.""" self.executor = ThreadPoolExecutor(max_workers=THREAD_POOL_SIZE) self.cached_bucket_set: set = set() self.key = key self.secret = secret + self.session_token = session_token self.region = region - self.bucket_tag_service = BucketTagAction(self.key, self.secret, self.region) + self.bucket_tag_service = BucketTagAction( + self.key, self.secret, self.session_token, self.region + ) def add_bucket_tag(self, bucket: str) -> None: """Add tag for bucket."""