From 649203b27055d01eaaefb752fd6fcbd14e6dcdcb Mon Sep 17 00:00:00 2001 From: yanghua Date: Thu, 14 Nov 2024 09:26:03 +0800 Subject: [PATCH] Provide a FileCredentialsProvider for assume role --- tosfs/certification.py | 184 ++++++++++++++++++++++++++++++ tosfs/exceptions.py | 11 +- tosfs/tests/test_certification.py | 182 +++++++++++++++++++++++++++++ 3 files changed, 376 insertions(+), 1 deletion(-) create mode 100644 tosfs/certification.py create mode 100644 tosfs/tests/test_certification.py diff --git a/tosfs/certification.py b/tosfs/certification.py new file mode 100644 index 0000000..f3af727 --- /dev/null +++ b/tosfs/certification.py @@ -0,0 +1,184 @@ +# ByteDance Volcengine EMR, Copyright 2024. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""It contains everything about certification via a file-based provider.""" + +import threading +from datetime import datetime +from typing import Optional +from xml.etree import ElementTree + +from tos.credential import Credentials, CredentialsProvider + +from tosfs.core import logger +from tosfs.exceptions import TosfsCertificationError + +CERTIFICATION_REFRESH_INTERVAL_MINUTES = 60 +CERTIFICATION_MAX_VALID_PERIOD_HOURS = 12 + + +class FileCredentialsProvider(CredentialsProvider): + """The class provides the credentials from a file. + + The file should be in the format of: + + + fs.tos.access-key-id + access_key + + + fs.tos.secret-access-key + secret_key + + + fs.tos.session-token + session_token + + + + It can only receive a file path which exists in the local file system. + + Note : + 1. The default refresh interval is 60 minutes. + 2. The maximum valid period for provisional certification is 12 hours. + + This provider will cache the credentials and refresh them every 60 minutes. + And note that, it only reads the credentials from the file and refreshes itself, + to guarantee the credentials are always up-to-date, + we need a service update it internally. + + Examples + -------- + >>> from tosfs import TosFileSystem + >>> tosfs = TosFileSystem( + >>> endpoint="tos-cn-beijing.volcengine.com", + >>> regions="cn-bejing", + >>> credentials_provider=FileCredentialsProvider("dummy_path")) + >>> tosfs.ls("tos://bucket/path") + + """ + + def __init__( + self, + file_path: str, + refresh_interval_min: int = CERTIFICATION_REFRESH_INTERVAL_MINUTES, + ) -> None: + """Initialize the FileCredentialsProvider.""" + self.file_path = file_path + self.refresh_interval = refresh_interval_min + if self.refresh_interval <= 0: + logger.warning( + f"Invalid refresh interval {self.refresh_interval}, " + f"set to default value: 60 minutes" + ) + self.refresh_interval = CERTIFICATION_REFRESH_INTERVAL_MINUTES + if self.refresh_interval > CERTIFICATION_MAX_VALID_PERIOD_HOURS * 60: + logger.warning( + f"Invalid refresh interval {self.refresh_interval}, " + f"set to maximum value: 60 minutes" + ) + self.refresh_interval = CERTIFICATION_REFRESH_INTERVAL_MINUTES + self.prev_refresh_time: Optional[datetime] = None + self.credentials = None + self._lock = threading.Lock() + + def get_credentials(self) -> Credentials: + """Get the credentials from the file. + + Returns + ------- + Credentials: The credentials object. + + Raises + ------ + TosfsCertificationError: If the credentials cannot be retrieved. + + """ + res = self._try_get_credentials() + if res is not None: + return res + with self._lock: + try: + res = self._try_get_credentials() + if res is not None: + return res + + with open(self.file_path, "r") as f: + logger.debug( + f"Trying to refresh the credentials from file: " + f"{self.file_path}" + ) + tree = ElementTree.parse(f) # noqa S314 + root = tree.getroot() + + access_key_element = root.find( + ".//property[name='fs.tos.access-key-id']/value" + ) + secret_key_element = root.find( + ".//property[name='fs.tos.secret-access-key']/value" + ) + session_token_element = root.find( + ".//property[name='fs.tos.session-token']/value" + ) + + access_key = ( + access_key_element.text + if access_key_element is not None + else None + ) + secret_key = ( + secret_key_element.text + if secret_key_element is not None + else None + ) + session_token = ( + session_token_element.text + if session_token_element is not None + else None + ) + + if None in ( + access_key, + secret_key, + session_token, + ): + raise TosfsCertificationError( + "Missing required credential elements in the file" + ) + + self.prev_refresh_time = datetime.now() + self.credentials = Credentials( + access_key, secret_key, session_token + ) + logger.debug( + f"Successfully refreshed the credentials from file: " + f"{self.file_path}" + ) + + return self.credentials + except Exception as e: + raise TosfsCertificationError("Get certification error: ") from e + + def _try_get_credentials(self) -> Optional[Credentials]: + if self.prev_refresh_time is None or self.credentials is None: + return None + if ( + datetime.now() - self.prev_refresh_time + ).total_seconds() / 60 > CERTIFICATION_REFRESH_INTERVAL_MINUTES: + logger.debug( + f"Credentials are expired, " + f"will try to refresh the credentials from file: " + f"{self.file_path}" + ) + return None + return self.credentials diff --git a/tosfs/exceptions.py b/tosfs/exceptions.py index dd86ddf..f89dc7c 100644 --- a/tosfs/exceptions.py +++ b/tosfs/exceptions.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""iT contains exceptions definition for the tosfs package.""" + +"""It contains exceptions definition for the tosfs package.""" class TosfsError(Exception): @@ -20,3 +21,11 @@ class TosfsError(Exception): def __init__(self, message: str): """Initialize the base class for all exceptions in the tosfs package.""" super().__init__(message) + + +class TosfsCertificationError(TosfsError): + """Exception class for certification related exception.""" + + def __init__(self, message: str): + """Initialize the exception class for certification related exception.""" + super().__init__(message) diff --git a/tosfs/tests/test_certification.py b/tosfs/tests/test_certification.py new file mode 100644 index 0000000..26a9244 --- /dev/null +++ b/tosfs/tests/test_certification.py @@ -0,0 +1,182 @@ +# ByteDance Volcengine EMR, Copyright 2024. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime, timedelta +from unittest.mock import mock_open, patch + +import pytest + +from tosfs.certification import FileCredentialsProvider +from tosfs.exceptions import TosfsCertificationError + + +@patch( + "builtins.open", + new_callable=mock_open, + read_data=""" + + + fs.tos.access-key-id + access_key + + + fs.tos.secret-access-key + secret_key + + + fs.tos.session-token + session_token + + +""", +) +def test_get_credentials(mock_file): + provider = FileCredentialsProvider("dummy_path") + credentials = provider.get_credentials() + assert credentials.access_key_id == "access_key" + assert credentials.access_key_secret == "secret_key" # noqa S105 + assert credentials.security_token == "session_token" # noqa S105 + + +@patch( + "builtins.open", + new_callable=mock_open, + read_data=""" + + + fs.tos.access-key-id + access_key + + + fs.tos.secret-access-key + secret_key + + + fs.tos.session-token + session_token + + +""", +) +def test_get_credentials_with_refresh(mock_file): + provider = FileCredentialsProvider("dummy_path") + provider.prev_refresh_time = datetime.now() - timedelta(minutes=61) + credentials = provider.get_credentials() + assert credentials.access_key_id == "access_key" + assert credentials.access_key_secret == "secret_key" # noqa S105 + assert credentials.security_token == "session_token" # noqa S105 + + +@patch( + "builtins.open", + new_callable=mock_open, + read_data=""" + + + fs.tos.access-key-id + access_key + + + fs.tos.secret-access-key + secret_key + + + fs.tos.session-token + session_token + + +""", +) +def test_get_credentials_error(mock_file): + provider = FileCredentialsProvider("dummy_path") + with patch("xml.etree.ElementTree.parse", side_effect=Exception("Parse error")): + with pytest.raises(TosfsCertificationError): + provider.get_credentials() + + +@patch( + "builtins.open", + new_callable=mock_open, + read_data=""" + + + fs.tos.access-key-id + access_key + + + fs.tos.secret-access-key + secret_key + + + fs.tos.session-token + session_token + + +""", +) +def test_wrong_file_format_error(mock_file): + provider = FileCredentialsProvider("dummy_path") + provider.prev_refresh_time = datetime.now() - timedelta(minutes=61) + with pytest.raises(TosfsCertificationError): + provider.get_credentials() + + +@patch( + "builtins.open", + new_callable=mock_open, + read_data=""" + + + fs.tos.access-key-id + access_key + + + fs.tos.secret-access-key + secret_key + + + fs.tos.session-token + session_token + + +""", +) +def test_no_refresh_within_interval(mock_file): + provider = FileCredentialsProvider("dummy_path") + provider.get_credentials() + + provider.prev_refresh_time = datetime.now() - timedelta(minutes=30) + + mock_file().read_data = """ + + + fs.tos.access-key-id + new_access_key + + + fs.tos.secret-access-key + new_secret_key + + + fs.tos.session-token + new_session_token + + + """ + + new_credentials = provider.get_credentials() + + assert new_credentials.access_key_id == "access_key" + assert new_credentials.access_key_secret == "secret_key" # noqa S105 + assert new_credentials.security_token == "session_token" # noqa S105