Skip to content

Commit

Permalink
Support tag via session token (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghua authored Oct 9, 2024
1 parent 9662ce9 commit 070d844
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 12 deletions.
19 changes: 13 additions & 6 deletions tosfs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(
key: str = "",
secret: str = "",
region: Optional[str] = None,
security_token: Optional[str] = None,
session_token: Optional[str] = None,
max_retry_num: int = 20,
max_connections: int = 1024,
connection_timeout: int = 10,
Expand Down Expand Up @@ -144,8 +144,8 @@ def __init__(
The secret access key(sk) to access the TOS service.
region : str, optional
The region of the TOS service.
security_token : str, optional
The temporary security token to access the TOS service.
session_token : str, optional
The temporary session token to access the TOS service.
max_retry_num : int, optional
The maximum number of retries for a failed request (default is 20).
max_connections : int, optional
Expand Down Expand Up @@ -231,7 +231,7 @@ def __init__(
secret,
endpoint_url,
region,
security_token=security_token,
security_token=session_token,
max_retry_count=0,
max_connections=max_connections,
connection_time=connection_timeout,
Expand Down Expand Up @@ -2012,10 +2012,17 @@ def _init_tag_manager(self) -> None:
if isinstance(auth, CredentialProviderAuth):
credentials = auth.credentials_provider.get_credentials()
self.bucket_tag_mgr = BucketTagMgr(
credentials.get_ak(), credentials.get_sk(), auth.region
credentials.get_ak(),
credentials.get_sk(),
credentials.get_security_token(),
auth.region,
)
else:
raise TosfsError("Currently only support CredentialProviderAuth type")
raise TosfsError(
"Currently only support CredentialProviderAuth type, "
"please check if you set (ak & sk) or (session token) "
"correctly."
)

@staticmethod
def _fill_dir_info(
Expand Down
33 changes: 27 additions & 6 deletions tosfs/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from typing import Any, Optional

from volcengine.ApiInfo import ApiInfo
from volcengine.base.Service import Service
Expand Down Expand Up @@ -125,11 +125,23 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
BucketTagAction._instance = object.__new__(cls)
return BucketTagAction._instance

def __init__(self, key: str, secret: str, region: str) -> None:
def __init__(
self,
key: Optional[str],
secret: Optional[str],
session_token: Optional[str],
region: str,
) -> None:
"""Init BucketTagAction."""
super().__init__(self.get_service_info(region), self.get_api_info())
self.set_ak(key)
self.set_sk(secret)
if key:
self.set_ak(key)

if secret:
self.set_sk(secret)

if session_token:
self.set_session_token(session_token)

@staticmethod
def get_api_info() -> dict:
Expand Down Expand Up @@ -192,14 +204,23 @@ def get_instance(*args: Any, **kwargs: Any) -> Any:
class BucketTagMgr:
"""BucketTagMgr is a class to manage the tag of bucket."""

def __init__(self, key: str, secret: str, region: str):
def __init__(
self,
key: Optional[str],
secret: Optional[str],
session_token: Optional[str],
region: str,
):
"""Init BucketTagMgr."""
self.executor = ThreadPoolExecutor(max_workers=THREAD_POOL_SIZE)
self.cached_bucket_set: set = set()
self.key = key
self.secret = secret
self.session_token = session_token
self.region = region
self.bucket_tag_service = BucketTagAction(self.key, self.secret, self.region)
self.bucket_tag_service = BucketTagAction(
self.key, self.secret, self.session_token, self.region
)

def add_bucket_tag(self, bucket: str) -> None:
"""Add tag for bucket."""
Expand Down

0 comments on commit 070d844

Please sign in to comment.