diff --git a/setup.py b/setup.py index 6b2c47f..4bbd6d6 100644 --- a/setup.py +++ b/setup.py @@ -16,10 +16,11 @@ # under the License. """The script for setting up stmdency.""" +from __future__ import annotations + import logging import os from distutils.dir_util import remove_tree -from typing import List from setuptools import Command, setup @@ -30,7 +31,7 @@ class CleanCommand(Command): """Command to clean up python api before setup by running `python setup.py pre_clean`.""" description = "Clean up project root" - user_options: List[str] = [] + user_options: list[str] = [] clean_list = [ "build", "htmlcov", diff --git a/src/stmdency/visitors/base.py b/src/stmdency/visitors/base.py index 4ece885..22507fd 100644 --- a/src/stmdency/visitors/base.py +++ b/src/stmdency/visitors/base.py @@ -4,7 +4,7 @@ import libcst as cst import libcst.matchers as m -from libcst import Assign, FunctionDef, Import, ImportFrom +from libcst import Assign, ClassDef, FunctionDef, Import, ImportFrom from stmdency.models.node import StmdencyNode from stmdency.visitors.assign import AssignVisitor @@ -22,7 +22,7 @@ class BaseVisitor(cst.CSTVisitor): stack: dict[str, StmdencyNode] = field(default_factory=dict) # Add scope to determine if the node is in the same scope - scope: dict[cst.CSTNode] = field(default_factory=set) + scope: set[cst.CSTNode] = field(default_factory=set) def handle_import(self, node: Import | ImportFrom) -> None: """Handle `import` / `from xx import xxx` statement and parse/add to stack.""" @@ -44,6 +44,19 @@ def visit_ImportFrom(self, node: ImportFrom) -> bool | None: self.handle_import(node) return True + def visit_ClassDef(self, node: ClassDef) -> bool | None: + """Handle class definition, pass to ClassDefVisitor and add scope. + + the reason add scope is to skip the visit_Assign in current class + """ + self.scope.add(node) + self.stack.update([(node.name.value, StmdencyNode(node=node))]) + return True + + def leave_ClassDef(self, original_node: ClassDef) -> None: + """Remove class definition in scope.""" + self.scope.remove(original_node) + def visit_FunctionDef(self, node: FunctionDef) -> bool | None: """Handle function definition, pass to FunctionVisitor and add scope. diff --git a/tests/extractor/test_assign.py b/tests/extractor/test_assign.py index f6777cb..28590bc 100644 --- a/tests/extractor/test_assign.py +++ b/tests/extractor/test_assign.py @@ -1,4 +1,4 @@ -from typing import Dict +from __future__ import annotations import pytest @@ -81,5 +81,5 @@ @pytest.mark.parametrize("name, source, expects", assign_cases) -def test_assign(name: str, source: str, expects: Dict[str, str]) -> None: +def test_assign(name: str, source: str, expects: dict[str, str]) -> None: assert_extract(name, source, expects) diff --git a/tests/extractor/test_function.py b/tests/extractor/test_function.py index 147e082..bf19d9f 100644 --- a/tests/extractor/test_function.py +++ b/tests/extractor/test_function.py @@ -1,4 +1,4 @@ -from typing import Dict +from __future__ import annotations import pytest @@ -296,5 +296,5 @@ def foo(): @pytest.mark.parametrize("name, source, expects", func_cases) -def test_func(name: str, source: str, expects: Dict[str, str]) -> None: +def test_func(name: str, source: str, expects: dict[str, str]) -> None: assert_extract(name, source, expects) diff --git a/tests/extractor/test_import.py b/tests/extractor/test_import.py index e8a7b1f..7dd7db3 100644 --- a/tests/extractor/test_import.py +++ b/tests/extractor/test_import.py @@ -1,4 +1,4 @@ -from typing import Dict +from __future__ import annotations import pytest @@ -124,5 +124,5 @@ def bar(): @pytest.mark.parametrize("name, source, expects", import_cases) -def test_import(name: str, source: str, expects: Dict[str, str]) -> None: +def test_import(name: str, source: str, expects: dict[str, str]) -> None: assert_extract(name, source, expects) diff --git a/tests/extractor/test_module_class.py b/tests/extractor/test_module_class.py new file mode 100644 index 0000000..e3f40d4 --- /dev/null +++ b/tests/extractor/test_module_class.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import pytest + +from tests.testing import assert_extract + +import_cases = [ + ( + "simple class with init", + """ + class Foo: + def __init__(self, arg1): + self.arg1 = arg1 + def foo(): + f = Foo(arg1=1) + print(f) + """, + { + "foo": """\ + class Foo: + def __init__(self, arg1): + self.arg1 = arg1\n\n + def foo(): + f = Foo(arg1=1) + print(f) + """, + }, + ), +] + + +@pytest.mark.parametrize("name, source, expects", import_cases) +def test_import(name: str, source: str, expects: dict[str, str]) -> None: + assert_extract(name, source, expects) diff --git a/tests/testing.py b/tests/testing.py index df2e250..3ac4096 100644 --- a/tests/testing.py +++ b/tests/testing.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import textwrap -from typing import Dict from stmdency.extractor import Extractor -def assert_extract(name: str, source: str, expects: Dict[str, str]) -> None: +def assert_extract(name: str, source: str, expects: dict[str, str]) -> None: wrap_source = textwrap.dedent(source) extractor = Extractor(source=wrap_source) for expect in expects: