diff --git a/tosfs/retry.py b/tosfs/retry.py index 7197f6d..fc0d5f2 100644 --- a/tosfs/retry.py +++ b/tosfs/retry.py @@ -33,14 +33,14 @@ from tosfs.exceptions import TosfsError -CONFLICT_CODE = "409" -TOO_MANY_REQUESTS_CODE = "429" -SERVICE_UNAVAILABLE = "503" +CONFLICT_CODE = 409 +TOO_MANY_REQUESTS_CODE = 429 +SERVICE_UNAVAILABLE = 503 TOS_SERVER_RETRYABLE_STATUS_CODES = { CONFLICT_CODE, TOO_MANY_REQUESTS_CODE, - "500", # INTERNAL_SERVER_ERROR, + 500, # INTERNAL_SERVER_ERROR, SERVICE_UNAVAILABLE, } @@ -111,13 +111,18 @@ def retryable_func_executor( except InterruptedError as ie: raise TosfsError(f"Request {func} interrupted.") from ie else: - raise e + _rethrow_retryable_exception(e) # Note: maybe not all the retryable exceptions are warped by `TosError` # Will pay attention to those cases except Exception as e: raise TosfsError(f"{e}") from e +def _rethrow_retryable_exception(e: TosError) -> None: + """For debug purpose.""" + raise e + + def is_retryable_exception(e: TosError) -> bool: """Check if the exception is retryable.""" return _is_retryable_tos_server_exception(e) or _is_retryable_tos_client_exception( diff --git a/tosfs/tests/test_retry.py b/tosfs/tests/test_retry.py new file mode 100644 index 0000000..ab13630 --- /dev/null +++ b/tosfs/tests/test_retry.py @@ -0,0 +1,82 @@ +# ByteDance Volcengine EMR, Copyright 2024. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ByteDance Volcengine EMR, Copyright 2024. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock + +import pytest +import requests +from tos.exceptions import TosClientError, TosServerError +from tos.http import Response +from urllib3.exceptions import ProtocolError + +from tosfs.retry import is_retryable_exception + +mock_resp = Mock(spec=requests.Response) + +mock_resp.status_code = 429 +mock_resp.headers = {"content-length": "123", "x-tos-request-id": "test-id"} +mock_resp.iter_content = Mock(return_value=[b"chunk1", b"chunk2", b"chunk3"]) +mock_resp.json = Mock(return_value={"key": "value"}) + +response = Response(mock_resp) + + +@pytest.mark.parametrize( + ("exception", "expected"), + [ + ( + TosServerError( + response, + "Exceed account external rate limit. Too much throughput in a " + "short period of time, please slow down.", + "ExceedAccountExternalRateLimit", + "KmsJSKDKhjasdlKmsduwRETYHB", + "", + "0004-00000001", + ), + True, + ), + ( + TosClientError( + "http request timeout", + ConnectionError( + ProtocolError( + "Connection aborted.", + ConnectionResetError(104, "Connection reset by peer"), + ) + ), + ), + True, + ), + ], +) +def test_is_retry_exception( + exception, + expected, +): + assert is_retryable_exception(exception) == expected