Skip to content

Commit

Permalink
Refactor StaticInspector to TypesDatabase
Browse files Browse the repository at this point in the history
Reflects the current implementation why more.
  • Loading branch information
lagru committed Oct 7, 2024
1 parent 848a1ff commit ef8de11
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 30 deletions.
10 changes: 5 additions & 5 deletions src/docstub/_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,18 +346,18 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
self._stack.pop()


class StaticInspector:
"""Static analysis of Python packages.
class TypesDatabase:
"""A static database of collected types usable as an annotation.
Attributes
----------
current_source : ~.PackageFile | None
Examples
--------
>>> from docstub._analysis import StaticInspector, common_known_imports
>>> inspector = StaticInspector(known_imports=common_known_imports())
>>> inspector.query("Any")
>>> from docstub._analysis import TypesDatabase, common_known_imports
>>> db = TypesDatabase(known_imports=common_known_imports())
>>> db.query("Any")
('Any', <KnownImport 'from typing import Any'>)
"""

Expand Down
10 changes: 5 additions & 5 deletions src/docstub/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from ._analysis import (
KnownImport,
StaticInspector,
TypeCollector,
TypesDatabase,
common_known_imports,
)
from ._cache import FileCache
Expand Down Expand Up @@ -145,12 +145,12 @@ def main(source_dir, out_dir, config_path, verbose):
config = _load_configuration(config_path)
known_imports = _build_import_map(config, source_dir)

inspector = StaticInspector(
types_db = TypesDatabase(
source_pkgs=[source_dir.parent.resolve()], known_imports=known_imports
)
# and the stub transformer
stub_transformer = Py2StubTransformer(
inspector=inspector, replace_doctypes=config.replace_doctypes
types_db=types_db, replace_doctypes=config.replace_doctypes
)

if not out_dir:
Expand Down Expand Up @@ -182,14 +182,14 @@ def main(source_dir, out_dir, config_path, verbose):
fo.write(stub_content)

# Report basic statistics
successful_queries = inspector.stats["successful_queries"]
successful_queries = types_db.stats["successful_queries"]
click.secho(f"{successful_queries} matched annotations", fg="green")

grammar_errors = stub_transformer.transformer.stats["grammar_errors"]
if grammar_errors:
click.secho(f"{grammar_errors} grammar violations", fg="red")

unknown_doctypes = inspector.stats["unknown_doctypes"]
unknown_doctypes = types_db.stats["unknown_doctypes"]
if unknown_doctypes:
click.secho(f"{len(unknown_doctypes)} unknown doctypes:", fg="red")
click.echo(" " + "\n ".join(set(unknown_doctypes)))
Expand Down
17 changes: 8 additions & 9 deletions src/docstub/_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,12 @@ class DoctypeTransformer(lark.visitors.Transformer):
[('tuple', 0, 5), ('int', 9, 12)]
"""

def __init__(self, *, inspector=None, replace_doctypes=None, **kwargs):
def __init__(self, *, types_db=None, replace_doctypes=None, **kwargs):
"""
Parameters
----------
inspector : ~.StaticInspector
A dictionary mapping atomic names used in doctypes to information such
as where to import from or how to replace the name itself.
types_db : ~.TypesDatabase
A static database of collected types usable as an annotation.
replace_doctypes : dict[str, str], optional
Replacements for human-friendly aliases.
kwargs : dict[Any, Any], optional
Expand All @@ -170,7 +169,7 @@ def __init__(self, *, inspector=None, replace_doctypes=None, **kwargs):
if replace_doctypes is None:
replace_doctypes = {}

self.inspector = inspector
self.types_db = types_db
self.replace_doctypes = replace_doctypes

self._collected_imports = None
Expand Down Expand Up @@ -302,16 +301,16 @@ def contains(self, tree):
def literals(self, tree):
out = ", ".join(tree.children)
out = f"Literal[{out}]"
if self.inspector is not None:
_, known_import = self.inspector.query("Literal")
if self.types_db is not None:
_, known_import = self.types_db.query("Literal")
if known_import:
self._collected_imports.add(known_import)
return out

def _find_import(self, qualname, meta):
"""Match type names to known imports."""
if self.inspector is not None:
annotation_name, known_import = self.inspector.query(qualname)
if self.types_db is not None:
annotation_name, known_import = self.types_db.query(qualname)
else:
annotation_name = None
known_import = None
Expand Down
14 changes: 7 additions & 7 deletions src/docstub/_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class Py2StubTransformer(cst.CSTTransformer):
Attributes
----------
inspector : ~._analysis.StaticInspector
types_db : ~.TypesDatabase
"""

METADATA_DEPENDENCIES = (cst.metadata.PositionProvider,)
Expand All @@ -198,17 +198,17 @@ class Py2StubTransformer(cst.CSTTransformer):
_Annotation_Any = cst.Annotation(cst.Name("Any"))
_Annotation_None = cst.Annotation(cst.Name("None"))

def __init__(self, *, inspector=None, replace_doctypes=None):
def __init__(self, *, types_db=None, replace_doctypes=None):
"""
Parameters
----------
inspector : ~._analysis.StaticInspector
types_db : ~.TypesDatabase
replace_doctypes : dict[str, str]
"""
self.inspector = inspector
self.types_db = types_db
self.replace_doctypes = replace_doctypes
self.transformer = DoctypeTransformer(
inspector=inspector, replace_doctypes=replace_doctypes
types_db=types_db, replace_doctypes=replace_doctypes
)
# Relevant docstring for the current context
self._scope_stack = None # Entered module, class or function scopes
Expand All @@ -225,8 +225,8 @@ def current_source(self):
@current_source.setter
def current_source(self, value):
self._current_source = value
if self.inspector is not None:
self.inspector.current_source = value
if self.types_db is not None:
self.types_db.current_source = value

def python_to_stub(self, source, *, module_path=None):
"""Convert Python source code to stub-file ready code.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_analysis.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pytest

from docstub._analysis import KnownImport, StaticInspector
from docstub._analysis import KnownImport, TypesDatabase


class Test_StaticInspector:
class Test_TypesDatabase:
known_imports = { # noqa: RUF012
"dict": KnownImport(builtin_name="dict"),
"np": KnownImport(import_name="numpy", import_alias="np"),
Expand Down Expand Up @@ -48,9 +48,9 @@ class Test_StaticInspector:
]
)
def test_query(self, name, exp_annotation, exp_import_line):
inspector = StaticInspector(known_imports=self.known_imports.copy())
db = TypesDatabase(known_imports=self.known_imports.copy())

annotation, known_import = inspector.query(name)
annotation, known_import = db.query(name)

if exp_annotation is None and exp_import_line is None:
assert exp_annotation is annotation
Expand Down

0 comments on commit ef8de11

Please sign in to comment.