Skip to content

Commit

Permalink
Add return type annotations to functions that always return None
Browse files Browse the repository at this point in the history
I don't see a way to make mypy require return type annotations only when
there are non-bare return statements in the function.
  • Loading branch information
dseomn committed Sep 24, 2023
1 parent 63963b7 commit 6f67570
Show file tree
Hide file tree
Showing 20 changed files with 80 additions and 78 deletions.
4 changes: 2 additions & 2 deletions rock_paper_sand/config_subcommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class Lint(subcommand.Subcommand):
"""Lints the config file."""

def run(self, args: argparse.Namespace):
def run(self, args: argparse.Namespace) -> None:
"""See base class."""
del args # Unused.
with network.null_requests_session() as session:
Expand All @@ -49,7 +49,7 @@ def run(self, args: argparse.Namespace):
class Main(subcommand.ContainerSubcommand):
"""Main config command."""

def __init__(self, parser: argparse.ArgumentParser):
def __init__(self, parser: argparse.ArgumentParser) -> None:
"""See base class."""
super().__init__(parser)
subparsers = parser.add_subparsers()
Expand Down
4 changes: 2 additions & 2 deletions rock_paper_sand/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_lint(
*,
config_data: Any,
expected_results: Any,
):
) -> None:
self.enter_context(
flagsaver.flagsaver(
(
Expand All @@ -143,7 +143,7 @@ def test_lint(

self.assertEqual(expected_results, results)

def test_example_config(self):
def test_example_config(self) -> None:
with flagsaver.flagsaver(
(
flags_and_constants.CONFIG_FILE,
Expand Down
4 changes: 2 additions & 2 deletions rock_paper_sand/flags_and_constants_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ def test_get_app_dir(
self,
env: Mapping[str, str],
expected_path: pathlib.Path,
):
) -> None:
with mock.patch.dict(os.environ, env, clear=True):
actual_path = flags_and_constants._get_app_dir(
xdg_variable_name="XDG_FOO_HOME",
relative_fallback_path=pathlib.Path("foo"),
)
self.assertEqual(expected_path, actual_path)

def test_get_app_dir_error(self):
def test_get_app_dir_error(self) -> None:
with self.assertRaisesRegex(ValueError, "No HOME directory"):
with mock.patch.dict(os.environ, {}, clear=True):
flags_and_constants._get_app_dir(
Expand Down
6 changes: 3 additions & 3 deletions rock_paper_sand/justwatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
*,
session: requests.Session,
base_url: str = _BASE_URL,
):
) -> None:
self._session = session
self._base_url = base_url
self._cache = {}
Expand Down Expand Up @@ -157,7 +157,7 @@ class _Availability:
default_factory=collections.Counter
)

def update(self, other: "_Availability"):
def update(self, other: "_Availability") -> None:
self.total_episode_count += other.total_episode_count
self.episode_count_by_offer.update(other.episode_count_by_offer)

Expand Down Expand Up @@ -198,7 +198,7 @@ def __init__(
filter_config: config_pb2.JustWatchFilter,
*,
api: Api,
):
) -> None:
self._config = filter_config
self._api = api

Expand Down
18 changes: 9 additions & 9 deletions rock_paper_sand/justwatch_subcommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@
from rock_paper_sand import subcommand


def _add_locale_arg(parser: argparse.ArgumentParser):
def _add_locale_arg(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--locale", help="JustWatch locale.", required=True)


class Locales(subcommand.Subcommand):
"""Prints the available JustWatch locales."""

def run(self, args: argparse.Namespace):
def run(self, args: argparse.Namespace) -> None:
"""See base class."""
del args # Unused.
with network.requests_session() as session:
Expand All @@ -43,12 +43,12 @@ def run(self, args: argparse.Namespace):
class MonetizationTypes(subcommand.Subcommand):
"""Prints the available JustWatch monetization types."""

def __init__(self, parser: argparse.ArgumentParser):
def __init__(self, parser: argparse.ArgumentParser) -> None:
"""See base class."""
super().__init__(parser)
_add_locale_arg(parser)

def run(self, args: argparse.Namespace):
def run(self, args: argparse.Namespace) -> None:
"""See base class."""
with network.requests_session() as session:
api = justwatch.Api(session=session)
Expand All @@ -63,12 +63,12 @@ def run(self, args: argparse.Namespace):
class Providers(subcommand.Subcommand):
"""Prints the available JustWatch providers."""

def __init__(self, parser: argparse.ArgumentParser):
def __init__(self, parser: argparse.ArgumentParser) -> None:
"""See base class."""
super().__init__(parser)
_add_locale_arg(parser)

def run(self, args: argparse.Namespace):
def run(self, args: argparse.Namespace) -> None:
"""See base class."""
with network.requests_session() as session:
api = justwatch.Api(session=session)
Expand All @@ -78,7 +78,7 @@ def run(self, args: argparse.Namespace):
class Search(subcommand.Subcommand):
"""Searches for a media item."""

def __init__(self, parser: argparse.ArgumentParser):
def __init__(self, parser: argparse.ArgumentParser) -> None:
"""See base class."""
super().__init__(parser)
_add_locale_arg(parser)
Expand All @@ -90,7 +90,7 @@ def run(
*,
out_file: IO[str] = sys.stdout,
api: justwatch.Api | None = None,
):
) -> None:
"""See base class."""
with network.requests_session() as session:
if api is None:
Expand All @@ -117,7 +117,7 @@ def run(
class Main(subcommand.ContainerSubcommand):
"""Main JustWatch API command."""

def __init__(self, parser: argparse.ArgumentParser):
def __init__(self, parser: argparse.ArgumentParser) -> None:
"""See base class."""
super().__init__(parser)
subparsers = parser.add_subparsers()
Expand Down
4 changes: 2 additions & 2 deletions rock_paper_sand/justwatch_subcommand_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@


class JustWatchSubcommandTest(parameterized.TestCase):
def setUp(self):
def setUp(self) -> None:
super().setUp()
self._mock_api = mock.create_autospec(
justwatch.Api, instance=True, spec_set=True
)

def test_search(self):
def test_search(self) -> None:
parser = argparse.ArgumentParser()
output = io.StringIO()
self._mock_api.post.return_value = {
Expand Down
26 changes: 13 additions & 13 deletions rock_paper_sand/justwatch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _offer(


class JustWatchApiTest(parameterized.TestCase):
def setUp(self):
def setUp(self) -> None:
super().setUp()
self._mock_session = mock.create_autospec(
requests.Session, spec_set=True, instance=True
Expand All @@ -60,7 +60,7 @@ def setUp(self):
)
self._mock_session.reset_mock()

def test_get(self):
def test_get(self) -> None:
self._mock_session.get.return_value.json.return_value = "foo"

data = self._api.get("bar")
Expand All @@ -75,7 +75,7 @@ def test_get(self):
self._mock_session.mock_calls,
)

def test_cache(self):
def test_cache(self) -> None:
self._mock_session.get.return_value.json.return_value = "foo"
self._api.get("bar")
self._mock_session.reset_mock()
Expand All @@ -85,7 +85,7 @@ def test_cache(self):
self.assertEqual("foo", data)
self.assertEmpty(self._mock_session.mock_calls)

def test_post(self):
def test_post(self) -> None:
self._mock_session.post.return_value.json.return_value = "foo"

data = self._api.post("bar", ["payload"])
Expand All @@ -100,7 +100,7 @@ def test_post(self):
self._mock_session.mock_calls,
)

def test_locales(self):
def test_locales(self) -> None:
self._mock_session.get.return_value.json.return_value = [
{"full_locale": "foo"},
{"full_locale": "bar"},
Expand All @@ -113,7 +113,7 @@ def test_locales(self):
f"{self._base_url}/locales/state"
)

def test_providers(self):
def test_providers(self) -> None:
self._mock_session.get.return_value.json.return_value = [
{"short_name": "foo", "clear_name": "Foo+"},
]
Expand All @@ -125,7 +125,7 @@ def test_providers(self):
f"{self._base_url}/providers/locale/en_US"
)

def test_providers_cached(self):
def test_providers_cached(self) -> None:
self._mock_session.get.return_value.json.return_value = [
{"short_name": "foo", "clear_name": "Foo+"},
]
Expand All @@ -137,7 +137,7 @@ def test_providers_cached(self):
self.assertEqual({"foo": "Foo+"}, providers)
self.assertEmpty(self._mock_session.mock_calls)

def test_provider_name(self):
def test_provider_name(self) -> None:
self._mock_session.get.return_value.json.return_value = [
{"short_name": "foo", "clear_name": "Foo+"},
]
Expand All @@ -146,11 +146,11 @@ def test_provider_name(self):

self.assertEqual("Foo+", provider_name)

def test_provider_name_not_found(self):
def test_provider_name_not_found(self) -> None:
self._mock_session.get.return_value.json.return_value = []
self.assertEqual("foo", self._api.provider_name("foo", locale="en_US"))

def test_monetization_types(self):
def test_monetization_types(self) -> None:
self._mock_session.get.return_value.json.return_value = [
{"monetization_types": ["foo"]},
{"monetization_types": ["bar", "quux"]},
Expand All @@ -166,7 +166,7 @@ def test_monetization_types(self):


class FilterTest(parameterized.TestCase):
def setUp(self):
def setUp(self) -> None:
self._mock_api = mock.create_autospec(
justwatch.Api, spec_set=True, instance=True
)
Expand Down Expand Up @@ -446,7 +446,7 @@ def test_filter(
media_item: Any,
api_data: Mapping[str, Any] = immutabledict.immutabledict(),
expected_result: media_filter.FilterResult,
):
) -> None:
self._mock_api.get.side_effect = lambda relative_url: api_data[
relative_url
]
Expand All @@ -461,7 +461,7 @@ def test_filter(

self.assertEqual(expected_result, result)

def test_possible_unknown_placeholder_datetime(self):
def test_possible_unknown_placeholder_datetime(self) -> None:
self._mock_api.get.return_value = {
"offers": [
_offer(
Expand Down
4 changes: 2 additions & 2 deletions rock_paper_sand/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


class MainCommand(subcommand.ContainerSubcommand):
def __init__(self, parser: argparse.ArgumentParser):
def __init__(self, parser: argparse.ArgumentParser) -> None:
"""See base class."""
super().__init__(parser)
subparsers = parser.add_subparsers()
Expand All @@ -52,7 +52,7 @@ def __init__(self, parser: argparse.ArgumentParser):
)


def main():
def main() -> None:
parser = argparse_flags.ArgumentParser()
main_command = MainCommand(parser)
args = parser.parse_args()
Expand Down
16 changes: 9 additions & 7 deletions rock_paper_sand/media_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def filter(self, media_item: config_pb2.MediaItem) -> FilterResult:
class Not(Filter):
"""Inverts another filter."""

def __init__(self, child: Filter, /):
def __init__(self, child: Filter, /) -> None:
self._child = child

def filter(self, media_item: config_pb2.MediaItem) -> FilterResult:
Expand All @@ -63,7 +63,9 @@ def filter(self, media_item: config_pb2.MediaItem) -> FilterResult:
class BinaryLogic(Filter):
"""Binary logic filter, i.e., "and" and "or"."""

def __init__(self, *children: Filter, op: Callable[[Iterable[bool]], bool]):
def __init__(
self, *children: Filter, op: Callable[[Iterable[bool]], bool]
) -> None:
self._children = children
self._op = op

Expand All @@ -83,7 +85,7 @@ def filter(self, media_item: config_pb2.MediaItem) -> FilterResult:
class HasParts(Filter):
"""Matches based on whether there are any child parts."""

def __init__(self, has_parts: bool):
def __init__(self, has_parts: bool) -> None:
self._has_parts = has_parts

def filter(self, media_item: config_pb2.MediaItem) -> FilterResult:
Expand All @@ -94,7 +96,7 @@ def filter(self, media_item: config_pb2.MediaItem) -> FilterResult:
class Done(Filter):
"""Matches based on the `done` field."""

def __init__(self, done: str):
def __init__(self, done: str) -> None:
self._done = multi_level_set.parse_number(done)

def filter(self, media_item: config_pb2.MediaItem) -> FilterResult:
Expand All @@ -112,7 +114,7 @@ def __init__(
self,
field_getter: Callable[[config_pb2.MediaItem], str],
matcher_config: config_pb2.StringFieldMatcher,
):
) -> None:
self._field_getter = field_getter
match matcher_config.WhichOneof("method"):
case "empty":
Expand Down Expand Up @@ -143,7 +145,7 @@ def __init__(
justwatch_factory: (
Callable[[config_pb2.JustWatchFilter], Filter] | None
) = None,
):
) -> None:
"""Initializer.
Args:
Expand All @@ -153,7 +155,7 @@ def __init__(
self._justwatch_factory = justwatch_factory
self._filter_by_name = {}

def register(self, name: str, filter_: Filter):
def register(self, name: str, filter_: Filter) -> None:
"""Registers a named filter."""
if name in self._filter_by_name:
raise ValueError(f"Filter {name!r} is defined multiple times.")
Expand Down
Loading

0 comments on commit 6f67570

Please sign in to comment.