diff --git a/.gitignore b/.gitignore index 7473c86b..cf4fafd1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,5 @@ *.pyc -.cache -.pytest_cache -.ipynb_checkpoints -build -docs/_build hatchet/cython_modules/libs/graphframe_modules.*.so hatchet/cython_modules/libs/reader_modules.*.so hatchet/cython_modules/*.c @@ -12,5 +7,163 @@ hatchet/vis/*node_modules* hatchet/vis/static/*_bundle* *package-lock.json +############################################### +# Everything from here on comes from the GitHub +# gitignore template for Python projects +############################################### + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ dist/ -llnl_hatchet.egg-info/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ diff --git a/hatchet/external/__init__.py b/hatchet/external/__init__.py index 166a47b7..f06f71b2 100644 --- a/hatchet/external/__init__.py +++ b/hatchet/external/__init__.py @@ -3,6 +3,8 @@ # # SPDX-License-Identifier: MIT +from typing import TYPE_CHECKING + class VersionError(Exception): """ @@ -13,23 +15,24 @@ class VersionError(Exception): pass -try: - import IPython +if not TYPE_CHECKING: + try: + import IPython - # Testing IPython version - if int(IPython.__version__.split(".")[0]) > 7: - raise VersionError() + # Testing IPython version + if int(IPython.__version__.split(".")[0]) > 7: + raise VersionError() - from .roundtrip.roundtrip.manager import Roundtrip + from .roundtrip.roundtrip.manager import Roundtrip - # Refrencing Roundtrip here to resolve scope issues with import - Roundtrip + # Refrencing Roundtrip here to resolve scope issues with import + Roundtrip -except ImportError: - pass + except ImportError: + pass -except VersionError: - if IPython.get_ipython() is not None: - print( - "Warning: Roundtrip module could not be loaded. Requires jupyter notebook version <= 7.x." - ) + except VersionError: + if IPython.get_ipython() is not None: + print( + "Warning: Roundtrip module could not be loaded. Requires jupyter notebook version <= 7.x." 
+ ) diff --git a/hatchet/external/console.py b/hatchet/external/console.py index 1de5e145..6b7de115 100644 --- a/hatchet/external/console.py +++ b/hatchet/external/console.py @@ -34,7 +34,7 @@ import pandas as pd import numpy as np import warnings -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from ..util.colormaps import ColorMaps from ..node import Node @@ -43,14 +43,19 @@ class ConsoleRenderer: def __init__(self, unicode: bool = False, color: bool = False) -> None: self.unicode = unicode self.color = color - self.visited = [] + self.visited: List[Node] = [] + self.colors_annotations_mapping: Optional[Union[List, Dict[str, Any]]] = None + self.colors: Optional[ + Union["ConsoleRenderer.colors_enabled", "ConsoleRenderer.colors_disabled"] + ] = None + self.temporal_symbols: Dict[str, str] = {} def render( self, roots: Optional[Union[List[Node], Tuple[Node, ...]]], dataframe: pd.DataFrame, **kwargs, - ) -> str: + ) -> Union[str, bytes]: self.render_header = kwargs["render_header"] if self.render_header: @@ -79,7 +84,7 @@ def render( self.max_value = kwargs["max_value"] if self.color: - self.colors = self.colors_enabled + self.colors = self.colors_enabled() # set the colormap based on user input self.colors.colormap = ColorMaps().get_colors( self.colormap, self.invert_colormap @@ -100,7 +105,7 @@ def render( elif isinstance(self.colormap_annotations, dict): self.colors_annotations_mapping = self.colormap_annotations else: - self.colors = self.colors_disabled + self.colors = self.colors_disabled() if isinstance(self.metric_columns, (str, tuple)): self.primary_metric = self.metric_columns @@ -265,14 +270,13 @@ def render_frame( if node_depth < self.depth: # set dataframe index based on whether rank and thread are part of # the MultiIndex + df_index: Union[Tuple[Node, int, int], Tuple[Node, int], Node] = node if "rank" in dataframe.index.names and "thread" in dataframe.index.names: df_index = (node, self.rank, self.thread) elif "rank" in dataframe.index.names: df_index = (node, self.rank) elif "thread" in dataframe.index.names: df_index = (node, self.thread) - else: - df_index = node node_metric = dataframe.loc[df_index, self.primary_metric] @@ -326,10 +330,12 @@ def render_frame( # no pattern column elif self.colormap_annotations: if isinstance(self.colormap_annotations, dict): + assert isinstance(self.colors_annotations_mapping, Dict) color_annotation = self.colors_annotations_mapping[ annotation_content ] else: + assert isinstance(self.colors_annotations_mapping, List) color_annotation = self.colors_annotations.colormap[ self.colors_annotations_mapping.index(annotation_content) % len(self.colors_annotations.colormap) @@ -441,7 +447,7 @@ def _ansi_color_for_name(self, node_name: str) -> str: return self.colors.bg_white_255 + self.colors.dark_gray_255 class colors_enabled: - colormap = [] + colormap: List[str] = [] blue = "\033[34m" cyan = "\033[36m" @@ -456,9 +462,7 @@ class colors_enabled: end = "\033[0m" class colors_disabled: - colormap = ["", "", "", "", "", "", ""] + colormap: List[str] = ["", "", "", "", "", "", ""] def __getattr__(self, key: str) -> str: return "" - - colors_disabled = colors_disabled() diff --git a/hatchet/frame.py b/hatchet/frame.py index 4462640c..affd0708 100644 --- a/hatchet/frame.py +++ b/hatchet/frame.py @@ -47,9 +47,11 @@ def __init__(self, attrs: Optional[Dict[str, Any]] = None, **kwargs) -> None: if "type" not in self.attrs: self.attrs["type"] = "None" - self._tuple_repr = None + self._tuple_repr: 
Optional[Tuple[Tuple[str, Any], ...]] = None - def __eq__(self, other: "Frame") -> bool: + def __eq__(self, other: object) -> bool: + if not isinstance(other, Frame): + return NotImplemented return self.tuple_repr == other.tuple_repr def __lt__(self, other: "Frame") -> bool: @@ -69,7 +71,7 @@ def __repr__(self) -> str: return "Frame(%s)" % self @property - def tuple_repr(self) -> Tuple[Tuple[str, Any], ...]: + def tuple_repr(self) -> Optional[Tuple[Tuple[str, Any], ...]]: """Make a tuple of attributes and values based on reader.""" if not self._tuple_repr: self._tuple_repr = tuple(sorted((k, v) for k, v in self.attrs.items())) diff --git a/hatchet/graph.py b/hatchet/graph.py index 78a1fae5..91524d29 100644 --- a/hatchet/graph.py +++ b/hatchet/graph.py @@ -5,7 +5,7 @@ from collections import defaultdict from collections.abc import Iterable -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from .node import Node, traversal_order, node_traversal_order @@ -86,7 +86,7 @@ def is_tree(self) -> bool: if len(self.roots) > 1: return False - visited = {} + visited: Dict[int, int] = {} list(self.traverse(visited=visited)) return all(v == 1 for v in visited.values()) @@ -101,7 +101,7 @@ def find_merges(self) -> Dict[Node, Node]: (dict): dictionary from nodes to their merge targets """ - merges = {} # old_node -> merged_node + merges: Dict[Node, Node] = {} # old_node -> merged_node inverted_merges = defaultdict( lambda: [] ) # merged_node -> list of corresponding old_nodes @@ -125,6 +125,7 @@ def _find_child_merges(node_list): _find_child_merges(self.roots) for node in self.traverse(): + assert isinstance(node, Node) if node in processed: continue nodes = None @@ -189,6 +190,7 @@ def copy(self, old_to_new: Optional[Dict[Node, Node]] = None) -> "Graph": # first pass creates new nodes for node in self.traverse(): + assert isinstance(node, Node) old_to_new[node] = node.copy() # second pass hooks up parents and children @@ -205,7 +207,7 @@ def copy(self, old_to_new: Optional[Dict[Node, Node]] = None) -> "Graph": return graph def union( - self, other: "Graph", old_to_new: Optional[Dict[Node, Node]] = None + self, other: "Graph", old_to_new: Optional[Dict[int, Node]] = None ) -> "Graph": """Create the union of self and other and return it as a new Graph. 
@@ -365,7 +367,7 @@ def _iter_depth(node, visited): child._depth = node._depth + 1 _iter_depth(child, visited) - visited = set() + visited: Set[Node] = set() for root in self.roots: root._depth = 0 # depth of root node is 0 _iter_depth(root, visited) @@ -375,9 +377,11 @@ def enumerate_traverse(self) -> None: # if "node order" column exists, we traverse sorting by _hatchet_nid if self.node_ordering: for i, node in enumerate(self.node_order_traverse()): + assert isinstance(node, Node) node._hatchet_nid = i else: for i, node in enumerate(self.traverse()): + assert isinstance(node, Node) node._hatchet_nid = i self.enumerate_depth() @@ -386,23 +390,28 @@ def _check_enumerate_traverse(self) -> bool: # if "node order" column exists, we traverse sorting by _hatchet_nid if self.node_ordering: for i, node in enumerate(self.node_order_traverse()): + assert isinstance(node, Node) if i != node._hatchet_nid: return False else: for i, node in enumerate(self.traverse()): + assert isinstance(node, Node) if i != node._hatchet_nid: return False + return True def __len__(self) -> int: """Size of the graph in terms of number of nodes.""" return sum(1 for _ in self.traverse()) - def __eq__(self, other: "Graph") -> bool: + def __eq__(self, other: object) -> bool: """Check if two graphs have the same structure by comparing frame at each node. """ - vs = set() - vo = set() + if not isinstance(other, Graph): + return NotImplemented + vs: Set[int] = set() + vo: Set[int] = set() # if both graphs are pointing to the same object, then graphs are equal if self is other: @@ -429,7 +438,9 @@ def __eq__(self, other: "Graph") -> bool: return True - def __ne__(self, other: "Graph") -> bool: + def __ne__(self, other: object) -> bool: + if not isinstance(other, Graph): + return NotImplemented return not (self == other) @staticmethod diff --git a/hatchet/graphframe.py b/hatchet/graphframe.py index f350d465..5e1ff5e3 100644 --- a/hatchet/graphframe.py +++ b/hatchet/graphframe.py @@ -8,8 +8,8 @@ import sys import traceback from collections import defaultdict -from collections.abc import Callable -from typing import Any, Dict, List, Optional, Tuple, Union +from collections.abc import Callable, Iterable +from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast from io import TextIOWrapper import multiprocess as mp @@ -33,7 +33,7 @@ from .util.dot import trees_to_dot try: - from .cython_modules.libs import graphframe_modules as _gfm_cy + from .cython_modules.libs import graphframe_modules as _gfm_cy # type: ignore except ImportError: print("-" * 80) print( @@ -134,7 +134,7 @@ def from_hpctoolkit_latest( max_depth: Optional[int] = None, min_percentage_of_application_time: Optional[int] = None, min_percentage_of_parent_time: Optional[int] = None, - ) -> "GraphFrame": + ) -> Optional["GraphFrame"]: """ Read an HPCToolkit database directory into a new GraphFrame @@ -202,7 +202,7 @@ def from_timeseries( level: str = "loop.start_iteration", native: bool = False, string_attributes: Union[List[str], str] = [], - ) -> "GraphFrame": + ) -> List["GraphFrame"]: """Read in a native Caliper timeseries `cali` file using Caliper's python reader. 
Args: @@ -220,7 +220,9 @@ def from_timeseries( ).read_timeseries(level=level) @staticmethod - def from_spotdb(db_key: Any, list_of_ids: Optional[List] = None) -> "GraphFrame": + def from_spotdb( + db_key: Any, list_of_ids: Optional[List] = None + ) -> List["GraphFrame"]: """Read multiple graph frames from a SpotDB instance Args: @@ -280,7 +282,7 @@ def from_timemory( input: Optional[Union[str, TextIOWrapper, Dict[str, Any]]] = None, select: Optional[List[str]] = None, **_kwargs, - ) -> "GraphFrame": + ) -> Optional["GraphFrame"]: """Read in timemory data. Links: @@ -354,14 +356,17 @@ def from_timemory( pass else: try: - import timemory + import timemory # type: ignore[import-not-found] - TimemoryReader(timemory.get(hierarchy=True), select, **_kwargs).read() + return TimemoryReader( + timemory.get(hierarchy=True), select, **_kwargs + ).read() except ImportError: print( "Error! timemory could not be imported. Provide filename, file stream, or dict." ) raise + return None @staticmethod def from_literal(graph_dict: List[Dict]) -> "GraphFrame": @@ -385,7 +390,11 @@ def from_lists(*lists) -> "GraphFrame": df = pd.DataFrame({"node": list(graph.traverse())}) df["time"] = [1.0] * len(graph) - df["name"] = [n.frame["name"] for n in graph.traverse()] + name_col = [] + for n in graph.traverse(): + assert isinstance(n, Node) + name_col.append(n.frame["name"]) + df["name"] = name_col df.set_index(["node"], inplace=True) df.sort_index(inplace=True) @@ -406,9 +415,7 @@ def from_hdf(filename: str, **kwargs) -> "GraphFrame": return HDF5Reader(filename).read(**kwargs) - def to_hdf( - self, filename: str, key: str = "hatchet_graphframe", **kwargs - ) -> "GraphFrame": + def to_hdf(self, filename: str, key: str = "hatchet_graphframe", **kwargs) -> None: # import this lazily to avoid circular dependencies from .writers.hdf5_writer import HDF5Writer @@ -455,7 +462,7 @@ def deepcopy(self) -> "GraphFrame": default_metric (str): N/A metadata (dict): Copy of self's metadata """ - node_clone = {} + node_clone: Dict[Node, Node] = {} graph_copy = self.graph.copy(node_clone) dataframe_copy = self.dataframe.copy() @@ -562,7 +569,7 @@ def filter( elif isinstance(filter_obj, (list, str)) or is_hatchet_query(filter_obj): # use a callpath query to apply the filter - query = filter_obj + query: Union[Query, CompoundQuery] # If a raw Object-dialect query is provided (not already passed to ObjectQuery), # create a new ObjectQuery object. if isinstance(filter_obj, list): @@ -573,7 +580,10 @@ def filter( query = parse_string_dialect(filter_obj, multi_index_mode) # If an old-style query is provided, extract the underlying new-style query. elif issubclass(type(filter_obj), AbstractQuery): - query = filter_obj._get_new_query() + query = cast(AbstractQuery, filter_obj)._get_new_query() + else: + assert isinstance(filter_obj, (Query, CompoundQuery)) + query = filter_obj query_matches = self.query_engine.apply(query, self.graph, self.dataframe) # match_set = list(set().union(*query_matches)) # filtered_df = dataframe_copy.loc[dataframe_copy["node"].isin(match_set)] @@ -619,7 +629,7 @@ def squash(self, update_inc_cols: bool = True) -> "GraphFrame": # Maintain sets of connections to make for each old node. # Start with old -> new mapping and update as we traverse subgraphs. 
- connections = defaultdict(lambda: set()) + connections: Dict[Node, Set[Node]] = defaultdict(lambda: set()) connections.update({k: {v} for k, v in old_to_new.items()}) new_roots = [] # list of new roots @@ -657,7 +667,7 @@ def rewire(node, new_parent, visited): return connections[node] # run rewire for each root and make a new graph - visited = set() + visited: Set[Node] = set() for root in self.graph.roots: rewire(root, None, visited) graph = Graph(new_roots) @@ -750,7 +760,8 @@ def subtree_sum( out_columns = self._init_sum_columns(columns, out_columns) # sum over the output columns - for node in self.graph.traverse(order="post"): + for trav_node in self.graph.traverse(order="post"): + node = cast(Node, trav_node) if node.children: # TODO: need a better way of aggregating inclusive metrics when # TODO: there is a multi-index @@ -815,7 +826,8 @@ def subgraph_sum( return out_columns = self._init_sum_columns(columns, out_columns) - for node in self.graph.traverse(): + for trav_node in self.graph.traverse(): + node = cast(Node, trav_node) subgraph_nodes = list(node.traverse()) # TODO: need a better way of aggregating inclusive metrics when # TODO: there is a multi-index @@ -893,13 +905,15 @@ def generate_exclusive_columns( # suffix) to the generation list. else: generation_pairs.append((inc + " (exc)", inc)) + node: Node # Consider each new exclusive metric and its corresponding inclusive metric for exc, inc in generation_pairs: # Process of obtaining inclusive data for a node differs if the DataFrame has an Index vs a MultiIndex if isinstance(self.dataframe.index, pd.MultiIndex): - new_data = {} + new_data: Dict[Union[Tuple[Any, ...], Node], int] = {} # Traverse every node in the Graph - for node in self.graph.traverse(): + for trav_node in self.graph.traverse(): + node = cast(Node, trav_node) # Consider each unique portion of the MultiIndex corresponding to the current node for non_node_idx in self.dataframe.loc[(node)].index.unique(): # If there's only 1 index level besides "node", add it to a 1-element list to ensure consistent typing @@ -930,7 +944,7 @@ def generate_exclusive_columns( # Create a basic Node-metric dict for the new exclusive metric new_data = {n: -1 for n in self.dataframe.index.values} # Traverse the graph - for node in self.graph.traverse(): + for node in cast(Iterable[Node], self.graph.traverse()): # Sum up the inclusive metric values of the current node's children inc_sum = 0 for child in node.children: @@ -994,7 +1008,7 @@ def unify(self, other: "GraphFrame"): if self.graph is other.graph: return - node_map = {} + node_map: Dict[int, Node] = {} union_graph = self.graph.union(other.graph, node_map) self_index_names = self.dataframe.index.names @@ -1043,7 +1057,7 @@ def tree( render_header: bool = True, min_value: Optional[int] = None, max_value: Optional[int] = None, - ) -> str: + ) -> Union[str, bytes]: """Visualize the Hatchet graphframe as a tree Arguments: @@ -1122,8 +1136,9 @@ def to_dot( """ if metric is None: metric = self.default_metric + graph_roots = cast(List[Node], self.graph.roots) return trees_to_dot( - self.graph.roots, self.dataframe, metric, name, rank, thread, threshold + graph_roots, self.dataframe, metric, name, rank, thread, threshold ) def to_flamegraph( @@ -1142,9 +1157,10 @@ def to_flamegraph( metric = self.default_metric for root in self.graph.roots: - for hnode in root.traverse(): + for hnode in cast(Iterable[Node], root.traverse()): callpath = hnode.path() for i in range(0, len(callpath) - 1): + df_index: Union[Tuple[Node, int, int], 
Tuple[Node, int], Node] if ( "rank" in self.dataframe.index.names and "thread" in self.dataframe.index.names @@ -1285,7 +1301,9 @@ def add_nodes(hnode): return graph_literal def to_dict(self) -> Dict: - hatchet_dict = {} + hatchet_dict: Dict[ + str, Union[List[Dict[int, Dict[str, Any]]], List[str], Dict] + ] = {} """ Nodes: {hatchet_nid: {node data, children:[by-id]}} @@ -1293,7 +1311,7 @@ def to_dict(self) -> Dict: graphs = [] for root in self.graph.roots: formatted_graph_dict = {} - for n in root.traverse(): + for n in cast(Iterable[Node], root.traverse()): formatted_graph_dict[n._hatchet_nid] = { "data": n.frame.attrs, "children": [c._hatchet_nid for c in n.children], @@ -1468,7 +1486,7 @@ def groupby_aggregate( """ # create new nodes for each unique node in the old dataframe # length is equal to number of nodes in original graph - old_to_new = {} + old_to_new: Dict[Node, Node] = {} # list of new roots new_roots = [] @@ -1537,7 +1555,7 @@ def reindex(node, parent, visited): old_to_new[i] = super_node # reindex graph by traversing old graph - visited = set() + visited: Set[Node] = set() for root in self.graph.roots: reindex(root, None, visited) diff --git a/hatchet/node.py b/hatchet/node.py index 101dca45..f1205642 100644 --- a/hatchet/node.py +++ b/hatchet/node.py @@ -10,7 +10,7 @@ from .frame import Frame -def traversal_order(node: "Node") -> Tuple(Frame, int): +def traversal_order(node: "Node") -> Tuple[Frame, int]: """Deterministic key function for sorting nodes in traversals.""" return (node.frame, id(node)) @@ -36,10 +36,10 @@ def __init__( self._depth = depth self._hatchet_nid = hnid - self.parents = [] + self.parents: List["Node"] = [] if parent is not None: self.add_parent(parent) - self.children = [] + self.children: List["Node"] = [] def add_parent(self, node: "Node"): """Adds a parent to this node's list of parents.""" @@ -76,7 +76,7 @@ def path(self, attrs: Optional[Dict[str, Any]] = None) -> Tuple["Node", ...]: """ paths = self.paths() if len(paths) > 1: - raise MultiplePathError("Node has more than one path: " % paths) + raise MultiplePathError("Node has more than one path: " + str(paths)) return paths[0] def dag_equal( @@ -213,7 +213,9 @@ def value(node): def __hash__(self) -> int: return self._hatchet_nid - def __eq__(self, other: "Node") -> bool: + def __eq__(self, other: object) -> bool: + if not isinstance(other, Node): + return NotImplemented return self._hatchet_nid == other._hatchet_nid def __lt__(self, other: "Node") -> bool: @@ -233,10 +235,7 @@ def copy(self) -> "Node": @classmethod def from_lists( cls, - lists: Union[ - List[str, "Node", Union[List, Tuple]], - Tuple[str, "Node", Union[List, Tuple]], - ], + lists: Tuple[str, "Node", Union[List, Tuple]], ) -> "Node": r"""Construct a hierarchy of nodes from recursive lists. 
diff --git a/hatchet/query/__init__.py b/hatchet/query/__init__.py index 6e0da6a9..ba36e5db 100644 --- a/hatchet/query/__init__.py +++ b/hatchet/query/__init__.py @@ -6,7 +6,7 @@ # Make flake8 ignore unused names in this file # flake8: noqa: F401 -from typing import Any, Union, List, TypeVar +from typing import Any, Union, List from .query import Query from .compound import ( @@ -15,9 +15,10 @@ DisjunctionQuery, ExclusiveDisjunctionQuery, NegationQuery, + parse_string_dialect, ) from .object_dialect import ObjectQuery -from .string_dialect import StringQuery, parse_string_dialect +from .string_dialect import StringQuery from .engine import QueryEngine from .errors import ( InvalidQueryPath, @@ -41,17 +42,15 @@ parse_cypher_query, ) -BaseQueryType = TypeVar("BaseQuery", Query, ObjectQuery, StringQuery, str, List) -CompoundQueryType = TypeVar( - "CompoundQuery", +BaseQueryType = Union[Query, ObjectQuery, StringQuery, str, List] +CompoundQueryType = Union[ CompoundQuery, ConjunctionQuery, DisjunctionQuery, ExclusiveDisjunctionQuery, NegationQuery, -) -LegacyQueryType = TypeVar( - "LegacyQuery", +] +LegacyQueryType = Union[ AbstractQuery, NaryQuery, AndQuery, @@ -63,8 +62,7 @@ NotQuery, QueryMatcher, CypherQuery, -) -AnyQueryType = TypeVar("AnyQuery", BaseQueryType, CompoundQueryType, LegacyQueryType) +] def combine_via_conjunction( @@ -92,16 +90,19 @@ def negate_query(query: Union[BaseQueryType, CompoundQueryType]) -> NegationQuer return NegationQuery(query) -Query.__and__ = combine_via_conjunction -Query.__or__ = combine_via_disjunction -Query.__xor__ = combine_via_exclusive_disjunction -Query.__not__ = negate_query +# Note: skipping mypy checks here because we're monkey +# patching these operators. Per mypy Issue #2427, +# mypy doesn't like this +Query.__and__ = combine_via_conjunction # type: ignore +Query.__or__ = combine_via_disjunction # type: ignore +Query.__xor__ = combine_via_exclusive_disjunction # type: ignore +Query.__not__ = negate_query # type: ignore -CompoundQuery.__and__ = combine_via_conjunction -CompoundQuery.__or__ = combine_via_disjunction -CompoundQuery.__xor__ = combine_via_exclusive_disjunction -CompoundQuery.__not__ = negate_query +CompoundQuery.__and__ = combine_via_conjunction # type: ignore +CompoundQuery.__or__ = combine_via_disjunction # type: ignore +CompoundQuery.__xor__ = combine_via_exclusive_disjunction # type: ignore +CompoundQuery.__not__ = negate_query # type: ignore def is_hatchet_query(query_obj: Any) -> bool: diff --git a/hatchet/query/compat.py b/hatchet/query/compat.py index 7197ffc6..a2a96d6d 100644 --- a/hatchet/query/compat.py +++ b/hatchet/query/compat.py @@ -3,20 +3,13 @@ # # SPDX-License-Identifier: MIT -from abc import abstractmethod +from abc import abstractmethod, ABC -try: - from abc import ABC -except ImportError: - from abc import ABCMeta - - ABC = ABCMeta("ABC", (object,), {"__slots__": ()}) import sys import warnings from collections.abc import Callable -from typing import List, Optional, Union +from typing import List, Optional, Union, cast, TYPE_CHECKING -from ..graphframe import GraphFrame from ..node import Node from .query import Query from .compound import ( @@ -25,12 +18,15 @@ DisjunctionQuery, ExclusiveDisjunctionQuery, NegationQuery, + parse_string_dialect, ) from .object_dialect import ObjectQuery -from .string_dialect import parse_string_dialect from .engine import QueryEngine from .errors import BadNumberNaryQueryArgs, InvalidQueryPath +if TYPE_CHECKING: + from ..graphframe import GraphFrame + # QueryEngine object for 
running the legacy "apply" methods COMPATABILITY_ENGINE: QueryEngine = QueryEngine() @@ -40,7 +36,7 @@ class AbstractQuery(ABC): """Base class for all 'old-style' queries.""" @abstractmethod - def apply(self, gf: GraphFrame) -> List[Node]: + def apply(self, gf: "GraphFrame") -> List[Node]: pass def __and__(self, other: "AbstractQuery") -> "AndQuery": @@ -76,7 +72,7 @@ def __xor__(self, other: "AbstractQuery") -> "XorQuery": """ return XorQuery(self, other) - def __invert__(self) -> "NegationQuery": + def __invert__(self) -> "NotQuery": """Create a new NotQuery using this query. Returns: @@ -99,7 +95,9 @@ def __init__(self, *args) -> None: Arguments: *args (AbstractQuery, str, or list): the subqueries to be performed """ - self.compat_subqueries = [] + self.compat_subqueries: List[ + Union[QueryMatcher, CypherQuery, AbstractQuery, Query, CompoundQuery] + ] = [] if isinstance(args[0], tuple) and len(args) == 1: args = args[0] for query in args: @@ -119,7 +117,7 @@ def __init__(self, *args) -> None: high-level query or a subclass of AbstractQuery" ) - def apply(self, gf: GraphFrame) -> List[Node]: + def apply(self, gf: "GraphFrame") -> List[Node]: """Applies the query to the specified GraphFrame. Arguments: @@ -139,9 +137,10 @@ def _get_new_query(self) -> Union[Query, CompoundQuery]: """ true_subqueries = [] for subq in self.compat_subqueries: - true_subq = subq if issubclass(type(subq), AbstractQuery): - true_subq = subq._get_new_query() + true_subq = cast(AbstractQuery, subq)._get_new_query() + else: + true_subq = cast(Union[Query, CompoundQuery], subq) true_subqueries.append(true_subq) return self._convert_to_new_query(true_subqueries) @@ -291,7 +290,7 @@ def __init__(self, query: Optional[Union[List, Query]] = None) -> None: DeprecationWarning, stacklevel=2, ) - self.true_query = None + self.true_query: Optional[Union[Query, CompoundQuery]] = None if query is None: self.true_query = Query() elif isinstance(query, list): @@ -314,6 +313,7 @@ def match( Returns: (QueryMatcher): the instance of the class that called this function """ + assert isinstance(self.true_query, Query) self.true_query.match(wildcard_spec, filter_func) return self @@ -332,10 +332,11 @@ def rel( Returns: (QueryMatcher): the instance of the class that called this function """ + assert isinstance(self.true_query, Query) self.true_query.rel(wildcard_spec, filter_func) return self - def apply(self, gf: GraphFrame) -> List[Node]: + def apply(self, gf: "GraphFrame") -> List[Node]: """Apply the query to a GraphFrame. 
Arguments: diff --git a/hatchet/query/compound.py b/hatchet/query/compound.py index 64c011f6..dc35fd17 100644 --- a/hatchet/query/compound.py +++ b/hatchet/query/compound.py @@ -6,12 +6,13 @@ from abc import abstractmethod import sys -from typing import List +import re +from typing import List, Optional, Set, Union, cast from ..node import Node from ..graph import Graph from .query import Query -from .string_dialect import parse_string_dialect +from .string_dialect import StringQuery from .object_dialect import ObjectQuery from .errors import BadNumberNaryQueryArgs @@ -153,7 +154,7 @@ def _apply_op_to_results( Returns: (list): A list containing all the nodes satisfying the exclusive disjunction of the subqueries' results """ - xor_set = set() + xor_set: Set[Node] = set() for res in subquery_results: xor_set = xor_set.symmetric_difference(set(res)) return list(xor_set) @@ -189,6 +190,174 @@ def _apply_op_to_results( Returns: (list): A list containing all the nodes in the Graph not contained in the subquery's results """ - nodes = set(graph.traverse()) + trav_nodes = set(graph.traverse()) + nodes = cast(Set[Node], trav_nodes) query_nodes = set(subquery_results[0]) return list(nodes.difference(query_nodes)) + + +def parse_string_dialect( + query_str: str, multi_index_mode: str = "off" +) -> Union[StringQuery, CompoundQuery]: + """Parse all types of String-based queries, including multi-queries that leverage + the curly brace delimiters. + + Arguments: + query_str (str): the String-based query to be parsed + + Returns: + (Query or CompoundQuery): A Hatchet query object representing the String-based query + """ + # TODO Check if there's a way to prevent curly braces in a string + # from being captured + + # Find the number of curly brace-delimited regions in the query + query_str = query_str.strip() + curly_brace_elems = re.findall(r"\{(.*?)\}", query_str) + num_curly_brace_elems = len(curly_brace_elems) + # If there are no curly brace-delimited regions, just pass the query + # off to the CypherQuery constructor + if num_curly_brace_elems == 0: + if sys.version_info[0] == 2: + query_str = query_str.decode("utf-8") + return StringQuery(query_str, multi_index_mode) + # Create an iterator over the curly brace-delimited regions + curly_brace_iter = re.finditer(r"\{(.*?)\}", query_str) + # Will store curly brace-delimited regions in the WHERE clause + condition_list = None + # Will store curly brace-delimited regions that contain entire + # mid-level queries (MATCH clause and WHERE clause) + query_list = None + # If entire queries are in brace-delimited regions, store the indexes + # of the regions here so we don't consider brace-delimited regions + # within the already-captured region. + query_idxes = None + # Store which compound queries to apply to the curly brace-delimited regions + compound_ops = [] + for i, match in enumerate(curly_brace_iter): + # Get the substring within curly braces + substr = query_str[match.start() + 1 : match.end() - 1] + substr = substr.strip() + # If an entire query (MATCH + WHERE) is within curly braces, + # add the query to "query_list", and add the indexes corresponding + # to the query to "query_idxes" + if substr.startswith("MATCH"): + if query_list is None: + query_list = [] + if query_idxes is None: + query_idxes = [] + query_list.append(substr) + query_idxes.append((match.start(), match.end())) + # If the curly brace-delimited region contains only parts of a + # WHERE clause, first, check if the region is within another + # curly brace delimited region. 
If it is, do nothing (it will + # be handled later). Otherwise, add the region to "condition_list" + elif re.match(r"[a-zA-Z0-9_]+\..*", substr) is not None: + is_encapsulated_region = False + if query_idxes is not None: + for s, e in query_idxes: + if match.start() >= s or match.end() <= e: + is_encapsulated_region = True + break + if is_encapsulated_region: + continue + if condition_list is None: + condition_list = [] + condition_list.append(substr) + # If the curly brace-delimited region is neither a whole query + # or part of a WHERE clause, raise an error + else: + raise ValueError("Invalid grouping (with curly braces) within the query") + # If there is a compound operator directly after the curly brace-delimited region, + # capture the type of operator, and store the type in "compound_ops" + if i + 1 < num_curly_brace_elems: + rest_substr = query_str[match.end() :] + rest_substr = rest_substr.strip() + if rest_substr.startswith("AND"): + compound_ops.append("AND") + elif rest_substr.startswith("OR"): + compound_ops.append("OR") + elif rest_substr.startswith("XOR"): + compound_ops.append("XOR") + else: + raise ValueError("Invalid compound operator type found!") + # Each call to this function should only consider one of the full query or + # WHERE clause versions at a time. If both types were captured, raise an error + # because some type of internal logic issue occured. + if condition_list is not None and query_list is not None: + raise ValueError( + "Curly braces must be around either a full mid-level query or a set of conditions in a single mid-level query" + ) + # This branch is for the WHERE clause version + if condition_list is not None: + # Make sure you correctly gathered curly brace-delimited regions and + # compound operators + if len(condition_list) != len(compound_ops) + 1: + raise ValueError( + "Incompatible number of curly brace elements and compound operators" + ) + # Get the MATCH clause that will be shared across the subqueries + match_comp_obj = re.search(r"MATCH\s+(?P.*)\s+WHERE", query_str) + match_comp = match_comp_obj.group("match_field") + # Iterate over the compound operators + full_query: Optional[Union[StringQuery, CompoundQuery]] = None + for i, op in enumerate(compound_ops): + # If in the first iteration, set the initial query as a CypherQuery where + # the MATCH clause is the shared match clause and the WHERE clause is the + # first curly brace-delimited region + if i == 0: + query1 = "MATCH {} WHERE {}".format(match_comp, condition_list[i]) + if sys.version_info[0] == 2: + query1 = query1.decode("utf-8") + full_query = StringQuery(query1, multi_index_mode) + # Get the next query as a CypherQuery where + # the MATCH clause is the shared match clause and the WHERE clause is the + # next curly brace-delimited region + next_query = "MATCH {} WHERE {}".format(match_comp, condition_list[i + 1]) + if sys.version_info[0] == 2: + next_query = next_query.decode("utf-8") + next_string_query: Union[StringQuery, CompoundQuery] = StringQuery( + next_query, multi_index_mode + ) + # Add the next query to the full query using the compound operator + # currently being considered + if op == "AND": + assert full_query is not None + full_query = ConjunctionQuery(full_query, next_string_query) + elif op == "OR": + assert full_query is not None + full_query = DisjunctionQuery(full_query, next_string_query) + else: + assert full_query is not None + full_query = ExclusiveDisjunctionQuery(full_query, next_string_query) + return full_query + # This branch is for the full query 
version + else: + # Make sure you correctly gathered curly brace-delimited regions and + # compound operators + if len(query_list) != len(compound_ops) + 1: + raise ValueError( + "Incompatible number of curly brace elements and compound operators" + ) + # Iterate over the compound operators + full_query = None + for i, op in enumerate(compound_ops): + # If in the first iteration, set the initial query as the result + # of recursively calling this function on the first curly brace-delimited region + if i == 0: + full_query = parse_string_dialect(query_list[i]) + # Get the next query by recursively calling this function + # on the next curly brace-delimited region + next_string_query = parse_string_dialect(query_list[i + 1]) + # Add the next query to the full query using the compound operator + # currently being considered + if op == "AND": + assert full_query is not None + full_query = ConjunctionQuery(full_query, next_string_query) + elif op == "OR": + assert full_query is not None + full_query = DisjunctionQuery(full_query, next_string_query) + else: + assert full_query is not None + full_query = ExclusiveDisjunctionQuery(full_query, next_string_query) + return full_query diff --git a/hatchet/query/engine.py b/hatchet/query/engine.py index dfc5b383..26366ad7 100644 --- a/hatchet/query/engine.py +++ b/hatchet/query/engine.py @@ -5,15 +5,14 @@ from itertools import groupby import pandas as pd -from typing import List, Optional, Set, Union +from typing import Dict, List, Optional, Set, Union, cast from .errors import InvalidQueryFilter from ..node import Node, traversal_order from ..graph import Graph from .query import Query -from .compound import CompoundQuery +from .compound import CompoundQuery, parse_string_dialect from .object_dialect import ObjectQuery -from .string_dialect import parse_string_dialect class QueryEngine: @@ -21,7 +20,7 @@ class QueryEngine: def __init__(self) -> None: """Creates the QueryEngine.""" - self.search_cache = {} + self.search_cache: Dict[int, List[int]] = {} def reset_cache(self) -> None: """Resets the cache in the QueryEngine.""" @@ -42,24 +41,26 @@ def apply( """ if issubclass(type(query), Query): self.reset_cache() - matches = [] - visited = set() + matches: List[List[Node]] = [] + visited: Set[int] = set() + casted_query = cast(Query, query) for root in sorted(graph.roots, key=traversal_order): - self._apply_impl(query, dframe, root, visited, matches) + self._apply_impl(casted_query, dframe, root, visited, matches) assert len(visited) == len(graph) matched_node_set = list(set().union(*matches)) # return matches return matched_node_set elif issubclass(type(query), CompoundQuery): results = [] - for subq in query.subqueries: + compound_query = cast(CompoundQuery, query) + for subq in compound_query.subqueries: subq_obj = subq if isinstance(subq, list): subq_obj = ObjectQuery(subq) elif isinstance(subq, str): subq_obj = parse_string_dialect(subq) results.append(self.apply(subq_obj, graph, dframe)) - return query._apply_op_to_results(results, graph) + return compound_query._apply_op_to_results(results, graph) else: raise TypeError("Invalid query data type ({})".format(str(type(query)))) @@ -184,15 +185,15 @@ def _match_pattern( if query.query_pattern[match_idx][0] == "*": pattern_idx = 0 # Starting matching pattern - matches = [[pattern_root]] + matches: List[List[Node]] = [[pattern_root]] while pattern_idx < len(query): # Get the wildcard type wcard, _ = query.query_pattern[pattern_idx] - new_matches = [] + new_matches: List[List[Node]] = [] # Consider each 
existing match individually so that more # nodes can be added to them. for m in matches: - sub_match = [] + sub_match: List[Optional[List[Node]]] = [] # Get the portion of the subgraph that matches the next # part of the query. if wcard == ".": @@ -217,9 +218,9 @@ def _match_pattern( ) # Merge the next part of the match path with the # existing part. - for s in sub_match: - if s is not None: - new_matches.append(m + s) + for sm in sub_match: + if sm is not None: + new_matches.append(m + sm) new_matches = [uniq_match for uniq_match, _ in groupby(new_matches)] # Overwrite the old matches with the updated matches matches = new_matches @@ -236,7 +237,7 @@ def _apply_impl( query: Query, dframe: pd.DataFrame, node: Node, - visited: Set[Node], + visited: Set[int], matches: List[List[Node]], ) -> None: """Traverse the subgraph with the specified root, and collect all paths that match the query. diff --git a/hatchet/query/object_dialect.py b/hatchet/query/object_dialect.py index ecdc413b..cd0ed641 100644 --- a/hatchet/query/object_dialect.py +++ b/hatchet/query/object_dialect.py @@ -13,6 +13,7 @@ import re import sys from typing import Dict, List, Tuple, Union +from collections.abc import Callable, Iterable from .errors import InvalidQueryPath, InvalidQueryFilter, MultiIndexModeMismatch from ..node import Node @@ -32,14 +33,14 @@ def _process_multi_index_mode(apply_result: pd.Series, multi_index_mode: str): def _process_predicate( attr_filter: Dict[Union[str, Tuple[str, ...]], Union[str, Real]], multi_index_mode: str, -) -> bool: +) -> Callable[[Union[pd.Series, pd.DataFrame]], bool]: """Converts high-level API attribute filter to a lambda""" compops = ("<", ">", "==", ">=", "<=", "<>", "!=") # , def filter_series(df_row: pd.Series) -> bool: def filter_single_series( df_row: pd.Series, - key: Union[str, Tuple[str]], + key: Union[str, Tuple[str, ...]], single_value: Union[str, Real], ) -> bool: if key == "depth": @@ -118,14 +119,7 @@ def filter_single_series( metric_name = k if isinstance(k, (tuple, list)) and len(k) == 1: metric_name = k[0] - try: - _ = iter(v) - # Manually raise TypeError if v is a string so that - # the string is processed as a non-iterable - if isinstance(v, str): - raise TypeError - # Runs if v is not iterable (e.g., list, tuple, etc.) 
- except TypeError: + if isinstance(v, str) or not isinstance(v, Iterable): matches = matches and filter_single_series(df_row, metric_name, v) else: for single_value in v: @@ -208,11 +202,7 @@ def filter_single_dframe( metric_name = k if isinstance(k, (tuple, list)) and len(k) == 1: metric_name = k[0] - try: - _ = iter(v) - if isinstance(v, str): - raise TypeError - except TypeError: + if isinstance(v, str) or not isinstance(v, Iterable): matches = matches and filter_single_dframe(node, df_row, metric_name, v) else: for single_value in v: diff --git a/hatchet/query/query.py b/hatchet/query/query.py index 0cc6139e..d3be624e 100644 --- a/hatchet/query/query.py +++ b/hatchet/query/query.py @@ -3,8 +3,8 @@ # # SPDX-License-Identifier: MIT -from typing import Tuple -from collections.abc import Callable +from typing import List, Tuple, Union +from collections.abc import Callable, Iterator from .errors import InvalidQueryPath @@ -14,10 +14,10 @@ class Query(object): def __init__(self) -> None: """Create new Query""" - self.query_pattern = [] + self.query_pattern: List[Tuple[Union[str, int], Callable]] = [] def match( - self, quantifier: str = ".", predicate: Callable = lambda row: True + self, quantifier: Union[str, int] = ".", predicate: Callable = lambda row: True ) -> "Query": """Start a query with a root node described by the arguments. @@ -34,7 +34,7 @@ def match( return self def rel( - self, quantifier: str = ".", predicate: Callable = lambda row: True + self, quantifier: Union[str, int] = ".", predicate: Callable = lambda row: True ) -> "Query": """Add a new node to the end of the query. @@ -53,7 +53,7 @@ def rel( return self def relation( - self, quantifer: str = ".", predicate: Callable = lambda row: True + self, quantifer: Union[str, int] = ".", predicate: Callable = lambda row: True ) -> "Query": """Alias to Query.rel. Add a new node to the end of the query. @@ -70,12 +70,12 @@ def __len__(self) -> int: """Returns the length of the query.""" return len(self.query_pattern) - def __iter__(self) -> Tuple[str, Callable]: + def __iter__(self) -> Iterator[Tuple[Union[str, int], Callable]]: """Allows users to iterate over the Query like a list.""" return iter(self.query_pattern) def _add_node( - self, quantifer: str = ".", predicate: Callable = lambda row: True + self, quantifer: Union[str, int] = ".", predicate: Callable = lambda row: True ) -> None: """Add a node to the query. 
diff --git a/hatchet/query/string_dialect.py b/hatchet/query/string_dialect.py index d013e93d..6b5f5292 100644 --- a/hatchet/query/string_dialect.py +++ b/hatchet/query/string_dialect.py @@ -7,7 +7,7 @@ import re import sys from collections.abc import Callable -from typing import Any, Optional, Tuple, Union +from typing import Any, Dict, Optional, List, Union, TYPE_CHECKING import pandas as pd # noqa: F401 from pandas.api.types import is_numeric_dtype, is_string_dtype # noqa: F401 import numpy as np # noqa: F401 @@ -17,7 +17,6 @@ from .errors import InvalidQueryPath, InvalidQueryFilter, RedundantQueryFilterWarning from .query import Query -from .compound import CompoundQuery # PEG grammar for the String-based dialect @@ -124,12 +123,12 @@ def __init__(self, cypher_query: str, multi_index_mode: str = "off") -> None: e.message ) ) - self.wcards = [] - self.wcard_pos = {} + self.wcards: List[List[Any]] = [] + self.wcard_pos: Dict[str, int] = {} self._parse_path(model.path_expr) - self.filters = [[] for _ in self.wcards] + self.filters: List[List[Any]] = [[] for _ in self.wcards] self._parse_conditions(model.cond_expr) - self.lambda_filters = [None for _ in self.wcards] + self.lambda_filters: List[Optional[str]] = [None for _ in self.wcards] self._build_lambdas() self._build_query() @@ -185,7 +184,7 @@ def _parse_path(self, path_obj: Any) -> None: nodes = path_obj.path.nodes idx = len(self.wcards) for n in nodes: - new_node = [n.wcard, n.name] + new_node: List[Any] = [n.wcard, n.name] if n.wcard is None or n.wcard == "" or n.wcard == 0: new_node[0] = "." self.wcards.append(new_node) @@ -231,9 +230,7 @@ def _is_binary_cond(self, obj: Any) -> bool: return True return False - def _parse_binary_cond( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_binary_cond(self, obj: Any) -> List[Optional[str]]: """Top level function for parsing binary predicates.""" if cname(obj) == "AndCond": return self._parse_and_cond(obj) @@ -241,31 +238,25 @@ def _parse_binary_cond( return self._parse_or_cond(obj) raise RuntimeError("Bad Binary Condition") - def _parse_or_cond(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_or_cond(self, obj: Any) -> List[Optional[str]]: """Top level function for parsing predicates combined with logical OR.""" converted_subcond = self._parse_unary_cond(obj.subcond) converted_subcond[0] = "or" return converted_subcond - def _parse_and_cond( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_and_cond(self, obj: Any) -> List[Optional[str]]: """Top level function for parsing predicates combined with logical AND.""" converted_subcond = self._parse_unary_cond(obj.subcond) converted_subcond[0] = "and" return converted_subcond - def _parse_unary_cond( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_unary_cond(self, obj: Any) -> List[Optional[str]]: """Top level function for parsing unary predicates.""" if cname(obj) == "NotCond": return self._parse_not_cond(obj) return self._parse_single_cond(obj) - def _parse_not_cond( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_not_cond(self, obj: Any) -> List[Optional[str]]: """Parse predicates containing the logical NOT operator.""" converted_subcond = self._parse_single_cond(obj.subcond) converted_subcond[2] = "not {}".format(converted_subcond[2]) @@ -273,16 +264,14 @@ def _parse_not_cond( def _run_method_based_on_multi_idx_mode( self, method_name: str, obj: Any - ) -> 
Tuple[Optional[str], str, str, Optional[str]]: + ) -> List[Optional[str]]: real_method_name = method_name if self.multi_index_mode != "off": real_method_name = method_name + "_multi_idx" method = eval("StringQuery.{}".format(real_method_name)) return method(self, obj) - def _parse_single_cond( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_single_cond(self, obj: Any) -> List[Optional[str]]: """Top level function for parsing individual numeric or string predicates.""" if self._is_str_cond(obj): return self._parse_str(obj) @@ -298,7 +287,7 @@ def _parse_single_cond( return self._run_method_based_on_multi_idx_mode("_parse_not_leaf", obj) raise RuntimeError("Bad Single Condition") - def _parse_none(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_none(self, obj: Any) -> List[Optional[str]]: """Parses 'property IS NONE'.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -330,9 +319,7 @@ def _add_aggregation_call_to_multi_idx_predicate(self, predicate: str) -> str: return predicate + ".any()" return predicate + ".all()" - def _parse_none_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_none_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -360,9 +347,7 @@ def _parse_none_multi_idx( None, ] - def _parse_not_none( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_not_none(self, obj: Any) -> List[Optional[str]]: """Parses 'property IS NOT NONE'.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -389,9 +374,7 @@ def _parse_not_none( None, ] - def _parse_not_none_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_not_none_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -419,7 +402,7 @@ def _parse_not_none_multi_idx( None, ] - def _parse_leaf(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_leaf(self, obj: Any) -> List[Optional[str]]: """Parses 'node IS LEAF'.""" return [ None, @@ -428,9 +411,7 @@ def _parse_leaf(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]] None, ] - def _parse_leaf_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_leaf_multi_idx(self, obj: Any) -> List[Optional[str]]: return [ None, obj.name, @@ -438,9 +419,7 @@ def _parse_leaf_multi_idx( None, ] - def _parse_not_leaf( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_not_leaf(self, obj: Any) -> List[Optional[str]]: """Parses 'node IS NOT LEAF'.""" return [ None, @@ -449,9 +428,7 @@ def _parse_not_leaf( None, ] - def _parse_not_leaf_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_not_leaf_multi_idx(self, obj: Any) -> List[Optional[str]]: return [ None, obj.name, @@ -487,7 +464,7 @@ def _is_num_cond(self, obj: Any) -> bool: return True return False - def _parse_str(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str(self, obj: Any) -> List[Optional[str]]: """Function that redirects processing of string predicates to the correct function. 
""" @@ -505,7 +482,7 @@ def _parse_str(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: return self._run_method_based_on_multi_idx_mode("_parse_str_match", obj) raise RuntimeError("Bad String Op Class") - def _parse_str_eq(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_eq(self, obj: Any) -> List[Optional[str]]: """Processes string equivalence predicates.""" return [ None, @@ -525,9 +502,7 @@ def _parse_str_eq(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str ), ] - def _parse_str_eq_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_eq_multi_idx(self, obj: Any) -> List[Optional[str]]: return [ None, obj.name, @@ -548,9 +523,7 @@ def _parse_str_eq_multi_idx( ), ] - def _parse_str_starts_with( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_starts_with(self, obj: Any) -> List[Optional[str]]: """Processes string 'startswith' predicates.""" return [ None, @@ -570,9 +543,7 @@ def _parse_str_starts_with( ), ] - def _parse_str_starts_with_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_starts_with_multi_idx(self, obj: Any) -> List[Optional[str]]: return [ None, obj.name, @@ -593,9 +564,7 @@ def _parse_str_starts_with_multi_idx( ), ] - def _parse_str_ends_with( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_ends_with(self, obj: Any) -> List[Optional[str]]: """Processes string 'endswith' predicates.""" return [ None, @@ -615,9 +584,7 @@ def _parse_str_ends_with( ), ] - def _parse_str_ends_with_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_ends_with_multi_idx(self, obj: Any) -> List[Optional[str]]: return [ None, obj.name, @@ -638,9 +605,7 @@ def _parse_str_ends_with_multi_idx( ), ] - def _parse_str_contains( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_contains(self, obj: Any) -> List[Optional[str]]: """Processes string 'contains' predicates.""" return [ None, @@ -660,9 +625,7 @@ def _parse_str_contains( ), ] - def _parse_str_contains_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_contains_multi_idx(self, obj: Any) -> List[Optional[str]]: return [ None, obj.name, @@ -683,9 +646,7 @@ def _parse_str_contains_multi_idx( ), ] - def _parse_str_match( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_match(self, obj: Any) -> List[Optional[str]]: """Processes string regex match predicates.""" return [ None, @@ -705,9 +666,7 @@ def _parse_str_match( ), ] - def _parse_str_match_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_match_multi_idx(self, obj: Any) -> List[Optional[str]]: return [ None, obj.name, @@ -728,7 +687,7 @@ def _parse_str_match_multi_idx( ), ] - def _parse_num(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num(self, obj: Any) -> List[Optional[str]]: """Function that redirects processing of numeric predicates to the correct function. 
""" @@ -752,7 +711,7 @@ def _parse_num(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: return self._run_method_based_on_multi_idx_mode("_parse_num_not_inf", obj) raise RuntimeError("Bad Number Op Class") - def _parse_num_eq(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_eq(self, obj: Any) -> List[Optional[str]]: """Processes numeric equivalence predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val == -1: @@ -825,9 +784,7 @@ def _parse_num_eq(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str ), ] - def _parse_num_eq_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_eq_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val == -1: return [ @@ -903,7 +860,7 @@ def _parse_num_eq_multi_idx( ), ] - def _parse_num_lt(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_lt(self, obj: Any) -> List[Optional[str]]: """Processes numeric less-than predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -969,9 +926,7 @@ def _parse_num_lt(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str ), ] - def _parse_num_lt_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_lt_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -1040,7 +995,7 @@ def _parse_num_lt_multi_idx( ), ] - def _parse_num_gt(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_gt(self, obj: Any) -> List[Optional[str]]: """Processes numeric greater-than predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -1106,9 +1061,7 @@ def _parse_num_gt(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str ), ] - def _parse_num_gt_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_gt_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -1177,7 +1130,7 @@ def _parse_num_gt_multi_idx( ), ] - def _parse_num_lte(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_lte(self, obj: Any) -> List[Optional[str]]: """Processes numeric less-than-or-equal-to predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -1243,9 +1196,7 @@ def _parse_num_lte(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[st ), ] - def _parse_num_lte_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_lte_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -1314,7 +1265,7 @@ def _parse_num_lte_multi_idx( ), ] - def _parse_num_gte(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_gte(self, obj: Any) -> List[Optional[str]]: """Processes numeric greater-than-or-equal-to predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -1380,9 +1331,7 @@ def _parse_num_gte(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[st ), ] - def _parse_num_gte_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_gte_multi_idx(self, obj: Any) -> 
List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -1451,7 +1400,7 @@ def _parse_num_gte_multi_idx( ), ] - def _parse_num_nan(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_nan(self, obj: Any) -> List[Optional[str]]: """Processes predicates that check for NaN.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1482,9 +1431,7 @@ def _parse_num_nan(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[st ), ] - def _parse_num_nan_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_nan_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -1516,9 +1463,7 @@ def _parse_num_nan_multi_idx( ), ] - def _parse_num_not_nan( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_not_nan(self, obj: Any) -> List[Optional[str]]: """Processes predicates that check for NaN.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1549,9 +1494,7 @@ def _parse_num_not_nan( ), ] - def _parse_num_not_nan_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_not_nan_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -1583,7 +1526,7 @@ def _parse_num_not_nan_multi_idx( ), ] - def _parse_num_inf(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_inf(self, obj: Any) -> List[Optional[str]]: """Processes predicates that check for Infinity.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1614,9 +1557,7 @@ def _parse_num_inf(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[st ), ] - def _parse_num_inf_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_inf_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -1648,9 +1589,7 @@ def _parse_num_inf_multi_idx( ), ] - def _parse_num_not_inf( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_not_inf(self, obj: Any) -> List[Optional[str]]: """Processes predicates that check for not-Infinity.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1681,9 +1620,7 @@ def _parse_num_not_inf( ), ] - def _parse_num_not_inf_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_not_inf_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -1714,162 +1651,3 @@ def _parse_num_not_inf_multi_idx( else "'{}'".format(obj.prop.ids[0]) ), ] - - -def parse_string_dialect( - query_str: str, multi_index_mode: str = "off" -) -> Union[StringQuery, CompoundQuery]: - """Parse all types of String-based queries, including multi-queries that leverage - the curly brace delimiters. 
- - Arguments: - query_str (str): the String-based query to be parsed - - Returns: - (Query or CompoundQuery): A Hatchet query object representing the String-based query - """ - # TODO Check if there's a way to prevent curly braces in a string - # from being captured - - # Find the number of curly brace-delimited regions in the query - query_str = query_str.strip() - curly_brace_elems = re.findall(r"\{(.*?)\}", query_str) - num_curly_brace_elems = len(curly_brace_elems) - # If there are no curly brace-delimited regions, just pass the query - # off to the CypherQuery constructor - if num_curly_brace_elems == 0: - if sys.version_info[0] == 2: - query_str = query_str.decode("utf-8") - return StringQuery(query_str, multi_index_mode) - # Create an iterator over the curly brace-delimited regions - curly_brace_iter = re.finditer(r"\{(.*?)\}", query_str) - # Will store curly brace-delimited regions in the WHERE clause - condition_list = None - # Will store curly brace-delimited regions that contain entire - # mid-level queries (MATCH clause and WHERE clause) - query_list = None - # If entire queries are in brace-delimited regions, store the indexes - # of the regions here so we don't consider brace-delimited regions - # within the already-captured region. - query_idxes = None - # Store which compound queries to apply to the curly brace-delimited regions - compound_ops = [] - for i, match in enumerate(curly_brace_iter): - # Get the substring within curly braces - substr = query_str[match.start() + 1 : match.end() - 1] - substr = substr.strip() - # If an entire query (MATCH + WHERE) is within curly braces, - # add the query to "query_list", and add the indexes corresponding - # to the query to "query_idxes" - if substr.startswith("MATCH"): - if query_list is None: - query_list = [] - if query_idxes is None: - query_idxes = [] - query_list.append(substr) - query_idxes.append((match.start(), match.end())) - # If the curly brace-delimited region contains only parts of a - # WHERE clause, first, check if the region is within another - # curly brace delimited region. If it is, do nothing (it will - # be handled later). Otherwise, add the region to "condition_list" - elif re.match(r"[a-zA-Z0-9_]+\..*", substr) is not None: - is_encapsulated_region = False - if query_idxes is not None: - for s, e in query_idxes: - if match.start() >= s or match.end() <= e: - is_encapsulated_region = True - break - if is_encapsulated_region: - continue - if condition_list is None: - condition_list = [] - condition_list.append(substr) - # If the curly brace-delimited region is neither a whole query - # or part of a WHERE clause, raise an error - else: - raise ValueError("Invalid grouping (with curly braces) within the query") - # If there is a compound operator directly after the curly brace-delimited region, - # capture the type of operator, and store the type in "compound_ops" - if i + 1 < num_curly_brace_elems: - rest_substr = query_str[match.end() :] - rest_substr = rest_substr.strip() - if rest_substr.startswith("AND"): - compound_ops.append("AND") - elif rest_substr.startswith("OR"): - compound_ops.append("OR") - elif rest_substr.startswith("XOR"): - compound_ops.append("XOR") - else: - raise ValueError("Invalid compound operator type found!") - # Each call to this function should only consider one of the full query or - # WHERE clause versions at a time. If both types were captured, raise an error - # because some type of internal logic issue occured. 
- if condition_list is not None and query_list is not None: - raise ValueError( - "Curly braces must be around either a full mid-level query or a set of conditions in a single mid-level query" - ) - # This branch is for the WHERE clause version - if condition_list is not None: - # Make sure you correctly gathered curly brace-delimited regions and - # compound operators - if len(condition_list) != len(compound_ops) + 1: - raise ValueError( - "Incompatible number of curly brace elements and compound operators" - ) - # Get the MATCH clause that will be shared across the subqueries - match_comp_obj = re.search(r"MATCH\s+(?P.*)\s+WHERE", query_str) - match_comp = match_comp_obj.group("match_field") - # Iterate over the compound operators - full_query = None - for i, op in enumerate(compound_ops): - # If in the first iteration, set the initial query as a CypherQuery where - # the MATCH clause is the shared match clause and the WHERE clause is the - # first curly brace-delimited region - if i == 0: - query1 = "MATCH {} WHERE {}".format(match_comp, condition_list[i]) - if sys.version_info[0] == 2: - query1 = query1.decode("utf-8") - full_query = StringQuery(query1, multi_index_mode) - # Get the next query as a CypherQuery where - # the MATCH clause is the shared match clause and the WHERE clause is the - # next curly brace-delimited region - next_query = "MATCH {} WHERE {}".format(match_comp, condition_list[i + 1]) - if sys.version_info[0] == 2: - next_query = next_query.decode("utf-8") - next_query = StringQuery(next_query, multi_index_mode) - # Add the next query to the full query using the compound operator - # currently being considered - if op == "AND": - full_query = full_query & next_query - elif op == "OR": - full_query = full_query | next_query - else: - full_query = full_query ^ next_query - return full_query - # This branch is for the full query version - else: - # Make sure you correctly gathered curly brace-delimited regions and - # compound operators - if len(query_list) != len(compound_ops) + 1: - raise ValueError( - "Incompatible number of curly brace elements and compound operators" - ) - # Iterate over the compound operators - full_query = None - for i, op in enumerate(compound_ops): - # If in the first iteration, set the initial query as the result - # of recursively calling this function on the first curly brace-delimited region - if i == 0: - full_query = parse_string_dialect(query_list[i]) - # Get the next query by recursively calling this function - # on the next curly brace-delimited region - next_query = parse_string_dialect(query_list[i + 1]) - # Add the next query to the full query using the compound operator - # currently being considered - if op == "AND": - full_query = full_query & next_query - elif op == "OR": - full_query = full_query | next_query - else: - full_query = full_query ^ next_query - return full_query diff --git a/hatchet/readers/caliper_native_reader.py b/hatchet/readers/caliper_native_reader.py index bfb95aa2..04ae2e63 100644 --- a/hatchet/readers/caliper_native_reader.py +++ b/hatchet/readers/caliper_native_reader.py @@ -7,7 +7,8 @@ import pandas as pd import numpy as np import os -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast +from collections.abc import Callable import caliperreader as cr @@ -60,24 +61,26 @@ def __init__( native (bool): use native metric names or user-readable metric names string_attributes (str or list): Adds existing string attributes from within the caliper file 
to the dataframe """ - self.filename_or_caliperreader = filename_or_caliperreader + self.filename_or_caliperreader: Union[str, cr.CaliperReader] = ( + filename_or_caliperreader + ) self.filename_ext = "" self.use_native_metric_names = native self.string_attributes = string_attributes - self.df_nodes = {} - self.metric_cols = [] - self.record_data_cols = [] - self.node_dicts = [] - self.callpath_to_node = {} - self.idx_to_node = {} - self.callpath_to_idx = {} - self.global_nid = 0 - self.node_ordering = False - self.gf_list = [] - self.timeseries_level = None + self.df_nodes: Optional[pd.DataFrame] = None + self.metric_cols: List[str] = [] + self.record_data_cols: List[str] = [] + # self.node_dicts = [] + self.callpath_to_node: Dict[Tuple[str, ...], Node] = {} + self.idx_to_node: Dict[int, Dict[str, Any]] = {} + self.callpath_to_idx: Dict[Tuple[str, ...], int] = {} + self.global_nid: int = 0 + self.node_ordering: bool = False + self.gf_list: List[hatchet.graphframe.GraphFrame] = [] + self.timeseries_level: Optional[str] = None - self.default_metric = None + self.default_metric: Optional[str] = None self.timer = Timer() @@ -87,8 +90,9 @@ def __init__( if isinstance(self.string_attributes, str): self.string_attributes = [self.string_attributes] - def _create_metric_df(self, metrics: List[str]) -> pd.DataFrame: + def _create_metric_df(self, metrics: List[Dict[str, Any]]) -> pd.DataFrame: """Make a list of metric columns and create a dataframe, group by node""" + assert isinstance(self.filename_or_caliperreader, cr.CaliperReader) for col in self.record_data_cols: if self.filename_or_caliperreader.attribute(col).is_value(): self.metric_cols.append(col) @@ -114,8 +118,9 @@ def _reset_metrics(self, metrics: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def read_metrics(self, ctx: str = "path") -> List[pd.DataFrame]: """append each metrics table to a list and return the list, split on timeseries_level if exists""" + assert isinstance(self.filename_or_caliperreader, cr.CaliperReader) metric_dfs = [] - all_metrics = [] + all_metrics: List[Dict[str, Any]] = [] next_timestep = 0 cur_timestep = 0 records = self.filename_or_caliperreader.records @@ -175,9 +180,12 @@ def read_metrics(self, ctx: str = "path") -> List[pd.DataFrame]: or item in self.string_attributes ): try: - node_dict[item] = self.__cali_type_dict[ - attr_type - ](record[item]) + node_dict[item] = ( + cast( + Callable, + self.__cali_type_dict[attr_type], + )(record[item]) + ) if item not in self.record_data_cols: self.record_data_cols.append(item) except ValueError as e: @@ -199,9 +207,10 @@ def read_metrics(self, ctx: str = "path") -> List[pd.DataFrame]: return metric_dfs def create_graph(self, ctx: str = "path") -> List[Node]: + assert isinstance(self.filename_or_caliperreader, cr.CaliperReader) list_roots = [] - def _create_parent(child_node: Node, parent_callpath: Any) -> None: + def _create_parent(child_node: Node, parent_callpath: Tuple[str, ...]) -> None: """We may encounter a parent node in the callpath before we see it as a child node. In this case, we need to create a hatchet node for the parent. @@ -353,7 +362,9 @@ def _create_parent(child_node: Node, parent_callpath: Any) -> None: return list_roots - def _parse_metadata(self, mdata: Dict[str, str]) -> Dict[str, str]: + def _parse_metadata( + self, mdata: Dict[str, Union[List, str]] + ) -> Dict[str, Union[List, int, float, str]]: """Convert Caliper Metadata values into correct Python objects.
Args: @@ -362,7 +373,7 @@ def _parse_metadata(self, mdata: Dict[str, str]) -> Dict[str, str]: Return: (dict[str: str]): modified metadata """ - parsed_mdata = {} + parsed_mdata: Dict[str, Union[List, int, float, str]] = {} for k, v in mdata.items(): # environment information service brings in different metadata types if isinstance(v, list): @@ -428,7 +439,7 @@ def read(self) -> hatchet.graphframe.GraphFrame: rank_list = range(0, num_ranks) # create a standard dict to be used for filling all missing rows - default_metric_dict = {} + default_metric_dict: Dict[str, Any] = {} for idx, col in enumerate(self.record_data_cols): if self.filename_or_caliperreader.attribute(col).is_value(): default_metric_dict[list(self.record_data_cols)[idx]] = 0 @@ -437,7 +448,7 @@ def read(self) -> hatchet.graphframe.GraphFrame: default_metric_dict["nid"] = np.nan # create a list of dicts, one dict for each missing row - missing_nodes = [] + missing_nodes: List[Dict[str, Any]] = [] for iteridx, row in self.df_nodes.iterrows(): # check if df_nodes row exists in df_fixed_data metric_rows = df_fixed_data.loc[metrics["nid"] == row["nid"]] diff --git a/hatchet/readers/caliper_reader.py b/hatchet/readers/caliper_reader.py index a298f70f..e8ae25f3 100644 --- a/hatchet/readers/caliper_reader.py +++ b/hatchet/readers/caliper_reader.py @@ -9,7 +9,8 @@ import subprocess import os import math -from typing import List, Union +from typing import Any, Dict, List, Union, cast +from collections.abc import Callable from io import TextIOWrapper import pandas as pd @@ -43,21 +44,21 @@ def __init__( self.query = query self.node_ordering = False - self.json_data = {} - self.json_cols = {} - self.json_cols_mdata = {} - self.json_nodes = {} + self.json_data: List[List[Union[int, float]]] = [] + self.json_cols: List[str] = [] + self.json_cols_mdata: List[Dict[str, Any]] = [] + self.json_nodes: List[Dict[str, Any]] = [] - self.metadata = {} + self.metadata: Dict[str, Any] = {} - self.idx_to_label = {} - self.idx_to_node = {} + self.idx_to_label: Dict[int, str] = {} + self.idx_to_node: Dict[int, Dict[str, Union[int, str, Node]]] = {} self.timer = Timer() self.nid_col_name = "nid" if isinstance(self.filename_or_stream, str): - _, self.filename_ext = os.path.splitext(filename_or_stream) + _, self.filename_ext = os.path.splitext(cast(str, filename_or_stream)) def read_json_sections(self) -> None: # if cali-query exists, extract data from .cali to a file-like object @@ -65,19 +66,20 @@ def read_json_sections(self) -> None: cali_query = which("cali-query") if not cali_query: raise ValueError("from_caliper() needs cali-query to query .cali file") - cali_json = subprocess.Popen( + assert isinstance(self.filename_or_stream, str) + cali_json_popen = subprocess.Popen( [cali_query, "-q", self.query, self.filename_or_stream], stdout=subprocess.PIPE, ) - self.filename_or_stream = cali_json.stdout + self.filename_or_stream = cast(TextIOWrapper, cali_json_popen.stdout) # if filename_or_stream is a str, then open the file, otherwise # directly load the file-like object if isinstance(self.filename_or_stream, str): - with open(self.filename_or_stream) as cali_json: + with open(cast(str, self.filename_or_stream)) as cali_json: json_obj = json.load(cali_json) else: - json_obj = json.loads(self.filename_or_stream.read().decode("utf-8")) + json_obj = json.loads(self.filename_or_stream.read()) # read various sections of the Caliper JSON file self.json_data = json_obj["data"] @@ -121,27 +123,30 @@ def read_json_sections(self) -> None: self.json_data.remove(i) # change column
names - for idx, item in enumerate(self.json_cols): - if item == self.path_col_name: + for idx, col_item in enumerate(self.json_cols): + if col_item == self.path_col_name: # this column is just a pointer into the nodes section self.json_cols[idx] = self.nid_col_name # make other columns consistent with other readers - if item == "mpi.rank": + if col_item == "mpi.rank": self.json_cols[idx] = "rank" - if item == "module#cali.sampler.pc": + if col_item == "module#cali.sampler.pc": self.json_cols[idx] = "module" - if item == "sum#time.duration" or item == "sum#avg#sum#time.duration": + if ( + col_item == "sum#time.duration" + or col_item == "sum#avg#sum#time.duration" + ): self.json_cols[idx] = "time" if ( - item == "inclusive#sum#time.duration" - or item == "sum#avg#inclusive#sum#time.duration" + col_item == "inclusive#sum#time.duration" + or col_item == "sum#avg#inclusive#sum#time.duration" ): self.json_cols[idx] = "time (inc)" # make list of metric columns - self.metric_columns = [] - for idx, item in enumerate(self.json_cols_mdata): - if self.json_cols[idx] != "rank" and item["is_value"] is True: + self.metric_columns: List[str] = [] + for idx, col_mdata_item in enumerate(self.json_cols_mdata): + if self.json_cols[idx] != "rank" and col_mdata_item["is_value"] is True: self.metric_columns.append(self.json_cols[idx]) def create_graph(self) -> List[Node]: @@ -162,7 +167,7 @@ def create_graph(self) -> List[Node]: # If there is a node orderering, assign to the _hatchet_nid if "Node order" in self.json_cols: self.node_ordering = True - order = self.json_data[idx][0] + order = cast(int, self.json_data[idx][0]) if "parent" not in node: # since this node does not have a parent, this is a root graph_root = Node( @@ -177,7 +182,9 @@ def create_graph(self) -> List[Node]: } self.idx_to_node[idx] = node_dict else: - parent_hnode = (self.idx_to_node[node["parent"]])["node"] + parent_hnode = cast( + Node, (self.idx_to_node[node["parent"]])["node"] + ) hnode = Node( Frame({"type": self.node_type, "name": node_label}), hnid=order, @@ -214,7 +221,7 @@ def read(self) -> hatchet.graphframe.GraphFrame: if self.both_hierarchies is True: # create dict that stores aggregation function for each column - agg_dict = {} + agg_dict: Dict[str, Callable] = {} for idx, item in enumerate(self.json_cols_mdata): col = self.json_cols[idx] if col != "rank" and col != "nid": @@ -285,7 +292,7 @@ def read(self) -> hatchet.graphframe.GraphFrame: # only need to do something if there are more than one # file:line number entries for the node if len(line_groups.size()) > 1: - sn_hnode = self.idx_to_node[nid]["node"] + sn_hnode = cast(Node, self.idx_to_node[nid]["node"]) for line, line_group in line_groups: # create the node label diff --git a/hatchet/readers/dataframe_reader.py b/hatchet/readers/dataframe_reader.py index 50d21e0d..1caea64b 100644 --- a/hatchet/readers/dataframe_reader.py +++ b/hatchet/readers/dataframe_reader.py @@ -9,20 +9,9 @@ import pandas as pd -from abc import abstractmethod +from abc import abstractmethod, ABC from typing import Dict, List -# TODO The ABC class was introduced in Python 3.4. 
-# When support for earlier versions is (eventually) dropped, -# this entire "try-except" block can be reduced to: -# from abc import ABC -try: - from abc import ABC -except ImportError: - from abc import ABCMeta - - ABC = ABCMeta("ABC", (object,), {"__slots__": ()}) - def _get_node_from_df_iloc(df: pd.DataFrame, ind: int) -> Node: node = None @@ -38,7 +27,7 @@ def _get_node_from_df_iloc(df: pd.DataFrame, ind: int) -> Node: def _get_parents_and_children(df: pd.DataFrame) -> Dict[Node, Dict[str, List[int]]]: - rel_dict = {} + rel_dict: Dict[Node, Dict[str, List[int]]] = {} for i in range(len(df)): node = _get_node_from_df_iloc(df, i) if node not in rel_dict: diff --git a/hatchet/readers/gprof_dot_reader.py b/hatchet/readers/gprof_dot_reader.py index 981388cc..76ccadcb 100644 --- a/hatchet/readers/gprof_dot_reader.py +++ b/hatchet/readers/gprof_dot_reader.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: MIT import re -from typing import List +from typing import Dict, List, Union import pandas as pd import pydot @@ -23,8 +23,8 @@ class GprofDotReader: def __init__(self, filename: str) -> None: self.dotfile = filename - self.name_to_hnode = {} - self.name_to_dict = {} + self.name_to_hnode: Dict[str, Node] = {} + self.name_to_dict: Dict[str, Dict[str, Union[str, Node]]] = {} self.timer = Timer() diff --git a/hatchet/readers/hpctoolkit_reader.py b/hatchet/readers/hpctoolkit_reader.py index 7d72ce82..2ddf57b3 100644 --- a/hatchet/readers/hpctoolkit_reader.py +++ b/hatchet/readers/hpctoolkit_reader.py @@ -8,7 +8,7 @@ import re import os import traceback -from typing import Any, Dict, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -18,11 +18,11 @@ try: import xml.etree.cElementTree as ET except ImportError: - import xml.etree.ElementTree as ET + import xml.etree.ElementTree as ET # type: ignore[no-redef] # cython imports try: - import hatchet.cython_modules.libs.reader_modules as _crm + import hatchet.cython_modules.libs.reader_modules as _crm # type: ignore except ImportError: print("-" * 80) print( @@ -40,7 +40,8 @@ from hatchet.frame import Frame -src_file = 0 +src_file: Optional[str] = None +shared_metrics: Optional[Any] = None # TODO replace the "Any" type hint with numpy.typing.ArrayLike @@ -51,7 +52,7 @@ def init_shared_array(buf_: Any) -> None: shared_metrics = buf_ -def read_metricdb_file(args: Tuple[str, int, int, int, int, Tuple[int, int]]) -> None: +def read_metricdb_file(args: Tuple[str, int, int, int, int, List[int]]) -> None: """Read a single metricdb file into a 1D array.""" ( filename, @@ -85,7 +86,7 @@ def read_metricdb_file(args: Tuple[str, int, int, int, int, Tuple[int, int]]) -> rank * num_threads_per_rank + num_cpu_threads_per_rank + (thread - 500) ) * num_nodes - arr[rank_offset : rank_offset + num_nodes, :num_metrics].flat = arr1d.flat + arr[rank_offset : rank_offset + num_nodes, :num_metrics].flat = arr1d.flat # type: ignore[misc] arr[rank_offset : rank_offset + num_nodes, num_metrics] = range(1, num_nodes + 1) arr[rank_offset : rank_offset + num_nodes, num_metrics + 1] = rank arr[rank_offset : rank_offset + num_nodes, num_metrics + 2] = thread @@ -136,17 +137,17 @@ def __init__(self, dir_name: str) -> None: self.num_metrics = struct.unpack(">i", metricdb.read(4))[0] else: raise ValueError( - "HPCToolkitReader doesn't support endian '%s'" % endian + "HPCToolkitReader doesn't support endian '{}'".format(endian) ) - self.load_modules = {} - self.src_files = {} - self.procedure_names = {} - self.metric_names =
{} + self.load_modules: Dict = {} + self.src_files: Dict = {} + self.procedure_names: Dict = {} + self.metric_names: Dict = {} # this list of dicts will hold all the node information such as # procedure name, load module, filename, etc. for all the nodes - self.node_dicts = [] + self.node_dicts: List[Dict[str, Union[int, str, Node]]] = [] self.timer = Timer() @@ -428,7 +429,7 @@ def create_node_dict( src_file: str, line: int, module: str, - ) -> Dict[str, Union[int, str, Node]]: + ) -> Dict[str, Any]: """Create a dict with all the node attributes.""" node_dict = { "nid": nid, diff --git a/hatchet/readers/hpctoolkit_reader_latest.py b/hatchet/readers/hpctoolkit_reader_latest.py index 97c1cd46..d201912b 100644 --- a/hatchet/readers/hpctoolkit_reader_latest.py +++ b/hatchet/readers/hpctoolkit_reader_latest.py @@ -6,7 +6,7 @@ import os import re import struct -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import pandas as pd @@ -68,18 +68,18 @@ def __init__( self._meta_file = None self._profile_file = None - self._functions = {} - self._source_files = {} - self._load_modules = {} - self._metric_descriptions = {} - self._summary_profile = {} + self._functions: Dict[int, Dict[str, Any]] = {} + self._source_files: Dict[int, Dict[str, Any]] = {} + self._load_modules: Dict[int, Dict[str, Any]] = {} + self._metric_descriptions: Dict = {} + self._summary_profile: Dict = {} - self._time_metric = None - self._inclusive_metrics = {} - self._exclusive_metrics = {} + self._time_metric: Optional[str] = None + self._inclusive_metrics: Dict = {} + self._exclusive_metrics: Dict = {} - self._cct_roots = [] - self._metrics_table = [] + self._cct_roots: List[Node] = [] + self._metrics_table: List[Dict[str, Any]] = [] for file_path in os.listdir(self._dir_path): if file_path.split(".")[-1] == "db": @@ -278,7 +278,7 @@ def _parse_context( ): continue - frame = {"type": NODE_TYPE_MAPPING[lexicalType]} + frame: Dict[str, Union[str, int]] = {"type": NODE_TYPE_MAPPING[lexicalType]} if nFlexWords: if lexicalType == 0: @@ -365,7 +365,7 @@ def _read_summary_profile( def _read_cct( self, - ) -> None: + ) -> Optional[GraphFrame]: with open(self._meta_file, "rb") as file: meta_db = file.read() @@ -410,7 +410,7 @@ def _read_cct( if im in table.columns.tolist(): inclusive_metrics.append(im) - for em in (list(self._exclusive_metrics.values()),): + for em in list(self._exclusive_metrics.values()): if em in table.columns.tolist(): exclusive_metrics.append(em) @@ -424,8 +424,9 @@ def _read_cct( print("DATA IMPORTED") return graphframe + return None - def read(self) -> GraphFrame: + def read(self) -> Optional[GraphFrame]: self._read_metric_descriptions() self._read_summary_profile() return self._read_cct() diff --git a/hatchet/readers/literal_reader.py b/hatchet/readers/literal_reader.py index f3e3fe1f..8f2b5f71 100644 --- a/hatchet/readers/literal_reader.py +++ b/hatchet/readers/literal_reader.py @@ -3,7 +3,8 @@ # # SPDX-License-Identifier: MIT -from typing import Any, Dict, List +from typing import Any, Dict, List, cast +from collections.abc import Iterable import pandas as pd @@ -61,7 +62,7 @@ class LiteralReader: (GraphFrame): graphframe containing data from dictionaries """ - def __init__(self, graph_dict: Dict) -> None: + def __init__(self, graph_dict: List[Dict]) -> None: """Read from list of dictionaries. graph_dict (dict): List of dictionaries encoding nodes. 
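As context for the graph_dict: List[Dict] annotation above, the sketch below shows the list-of-dictionaries shape LiteralReader is typed against. It assumes Hatchet's documented literal format (each node dictionary carries "frame", "metrics", and "children" keys); the function names and metric values are purely illustrative.

from hatchet import GraphFrame

# One root node with a single child: "frame" holds identifying attributes,
# "metrics" holds per-node measurements, and "children" nests further node dicts.
graph_dict = [
    {
        "frame": {"name": "main", "type": "function"},
        "metrics": {"time (inc)": 130.0, "time": 10.0},
        "children": [
            {
                "frame": {"name": "solve", "type": "function"},
                "metrics": {"time (inc)": 120.0, "time": 120.0},
                "children": [],
            }
        ],
    }
]

# GraphFrame.from_literal() constructs a LiteralReader from this list and calls read().
gf = GraphFrame.from_literal(graph_dict)
print(gf.dataframe)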
@@ -156,7 +157,7 @@ def read(self) -> hatchet.graphframe.GraphFrame: graph = Graph(list_roots) # test if nids are already loaded - if -1 in [n._hatchet_nid for n in graph.traverse()]: + if -1 in [n._hatchet_nid for n in cast(Iterable[Node], graph.traverse())]: graph.enumerate_traverse() else: graph.enumerate_depth() diff --git a/hatchet/readers/pyinstrument_reader.py b/hatchet/readers/pyinstrument_reader.py index cc739a76..621cea1b 100644 --- a/hatchet/readers/pyinstrument_reader.py +++ b/hatchet/readers/pyinstrument_reader.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: MIT import json -from typing import Any, Dict +from typing import Any, Dict, List import pandas as pd @@ -17,9 +17,9 @@ class PyinstrumentReader: def __init__(self, filename: str) -> None: self.pyinstrument_json_filename = filename - self.graph_dict = {} - self.list_roots = [] - self.node_dicts = [] + self.graph_dict: Dict[str, Any] = {} + self.list_roots: List[Node] = [] + self.node_dicts: List[Dict[str, Any]] = [] def create_graph(self) -> Graph: def parse_node_literal(child_dict: Dict[str, Any], hparent: Node) -> None: diff --git a/hatchet/readers/spotdb_reader.py b/hatchet/readers/spotdb_reader.py index 5352019e..dccd973a 100644 --- a/hatchet/readers/spotdb_reader.py +++ b/hatchet/readers/spotdb_reader.py @@ -3,7 +3,7 @@ # # SPDX-License-Identifier: MIT -from typing import Any, Dict, Optional, List +from typing import Any, Dict, Optional, List, Set import pandas as pd @@ -50,9 +50,9 @@ def __init__( self.regionprofile = regionprofile self.attr_info = attr_info self.metadata = metadata - self.df_data = [] - self.roots = {} - self.metric_columns = set() + self.df_data: List[Dict[str, Any]] = [] + self.roots: Dict[str, Node] = {} + self.metric_columns: Set[str] = set() self.timer = Timer() @@ -178,7 +178,7 @@ def read(self) -> List[hatchet.graphframe.GraphFrame]: Returns: List of GraphFrames, one for each entry that was found """ - import spotdb + import spotdb # type: ignore[import-not-found] if isinstance(self.db_key, str): db = spotdb.connect(self.db_key) diff --git a/hatchet/readers/tau_reader.py b/hatchet/readers/tau_reader.py index 4d192a0a..8cebab25 100644 --- a/hatchet/readers/tau_reader.py +++ b/hatchet/readers/tau_reader.py @@ -6,7 +6,8 @@ import re import os import glob -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Union, cast +from collections.abc import Iterable import pandas as pd import hatchet.graphframe from hatchet.node import Node @@ -19,13 +20,13 @@ class TAUReader: def __init__(self, dirname: str) -> None: self.dirname = dirname - self.node_dicts = [] - self.callpath_to_node = {} - self.rank_thread_to_data = {} - self.filepath_to_data = {} - self.inc_metrics = [] - self.exc_metrics = [] - self.columns = [] + self.node_dicts: List[Dict[str, Any]] = [] + self.callpath_to_node: Dict[Tuple[str, ...], Node] = {} + # self.rank_thread_to_data = {} + # self.filepath_to_data = {} + self.inc_metrics: List[str] = [] + self.exc_metrics: List[str] = [] + self.columns: List[str] = [] self.multiple_ranks = False self.multiple_threads = False @@ -33,7 +34,7 @@ def create_node_dict( self, node: Node, columns: List[str], - metric_values: Tuple[Any, ...], + metric_values: Union[List[Any], Tuple[Any, ...]], name: str, filename: str, module: str, @@ -59,7 +60,7 @@ def create_node_dict( def create_graph(self) -> List[Node]: def _get_name_file_module( is_parent: bool, node_info: str, symbol: str - ) -> Tuple[str, str, str]: + ) -> List[str]: """This function gets the name, file and 
module information for a node using the corresponding line in the output file. Example line: [UNWIND] [@] [{} {}] @@ -74,18 +75,18 @@ def _get_name_file_module( # formats. Example formats are given in comments. if symbol == " [@] ": # Check if there is a [@] symbol. - node_info = node_info.split(symbol) + split_node_info = node_info.split(symbol) # We don't need file and module information if it's a parent node. if not is_parent: - file = node_info[0].split()[1] - if "[{" in node_info[1]: + file = split_node_info[0].split()[1] + if "[{" in split_node_info[1]: # Sometimes we see file and module information inside of [{}] # Example: [UNWIND] [@] [{} {}] - name_and_module = node_info[1].split(" [{") + name_and_module = split_node_info[1].split(" [{") module = name_and_module[1].split()[0].strip("}") else: # Example: [UNWIND] [@] - name_and_module = node_info[1].split() + name_and_module = split_node_info[1].split() module = name_and_module[1] # Check if module is in file. @@ -99,46 +100,46 @@ def _get_name_file_module( name = "[UNWIND] " + name_and_module[0] else: # We just need to take name if it is a parent - name = "[UNWIND] " + node_info[1].split()[0] + name = "[UNWIND] " + split_node_info[1].split()[0] elif symbol == " C ": # Check if there is a C symbol. # "C" symbol means it's a C function. - node_info = node_info.split(symbol) - name = node_info[0] + split_node_info = node_info.split(symbol) + name = split_node_info[0] # We don't need file and module information if it's a parent node. if not is_parent: - if "[{" in node_info[1]: + if "[{" in split_node_info[1]: # Example: C [{} {}] - node_info = node_info[1].split() - file = node_info[0].strip("}[{") + split_node_info = split_node_info[1].split() + file = split_node_info[0].strip("}[{") else: if "[{" in node_info: # If there isn't C or [@] # Example: [] [{} {}] - node_info = node_info.split(" [{") - name = node_info[0] + split_node_info = node_info.split(" [{") + name = split_node_info[0] # We don't need file and module information if it's a parent node. if not is_parent: - file = node_info[1].split()[0].strip("}{") + file = split_node_info[1].split()[0].strip("}{") else: # Example 1: [] # Example 2: [] # Example 3: name = node_info - node_info = node_info.split() + split_node_info = node_info.split() # We need to take module information from the first example. # Another example is "[CONTEXT] .TAU application" which contradicts # with the first example. So we check if there is "\" symbol which # will show the module information in this case. - if len(node_info) == 3 and "/" in name: - name = node_info[0] + " " + node_info[1] + if len(split_node_info) == 3 and "/" in name: + name = split_node_info[0] + " " + split_node_info[1] # We don't need file and module information if it's a parent node. if not is_parent: - module = node_info[2] + module = split_node_info[2] return [name, file, module] - def _get_line_numbers(node_info: str) -> Tuple[str, str]: - start_line, end_line = 0, 0 + def _get_line_numbers(node_info: str) -> List[str]: + start_line, end_line = "0", "0" # There should be [{}] symbols if there is line number information. 
if "[{" in node_info: tmp_module_or_file_line = ( @@ -149,9 +150,9 @@ def _get_line_numbers(node_info: str) -> Tuple[str, str]: if "-" in line_numbers: # Sometimes there is "-" between start line and end line # Example: {341,1}-{396,1} - line_numbers = line_numbers.split("-") - start_line = line_numbers[0].split(",")[0] - end_line = line_numbers[1].split(",")[0] + split_line_numbers = line_numbers.split("-") + start_line = split_line_numbers[0].split(",")[0] + end_line = split_line_numbers[1].split(",")[0] else: if "," in line_numbers: # Sometimes we don't have "-". @@ -160,7 +161,7 @@ def _get_line_numbers(node_info: str) -> Tuple[str, str]: end_line = line_numbers.split(",")[1] return [start_line, end_line] - def _create_parent(child_node: Node, parent_callpath: str) -> None: + def _create_parent(child_node: Node, parent_callpath: Tuple[str, ...]) -> None: """In TAU output, sometimes we see a node as a parent in the callpath before we see it as a leaf node. In this case, we need to create a hatchet node for the parent. @@ -210,7 +211,7 @@ def _construct_column_list(first_rank_filenames: List[str]) -> List[str]: all metric files of a rank as a tuple and only loads the second line (metadata) of these files. """ - columns = [] + columns: List[str] = [] for file_index in range(len(first_rank_filenames)): with open(first_rank_filenames[file_index], "r") as f: # Skip the first line: "192 templated_functions_MULTI_TIME" @@ -252,7 +253,7 @@ def _construct_column_list(first_rank_filenames: List[str]) -> List[str]: # Each tuple stores all the metric files of a rank. # We process one rank at a time. # Example: [(metric1/profile.x.0.0, metric2/profile.x.0.0), ...] - profile_filenames = list(zip(*profile_filenames)) + profile_filenames = list(cast(Iterable[List[str]], zip(*profile_filenames))) # Get column information from the metric files of a rank. self.columns = _construct_column_list(profile_filenames[0]) @@ -281,7 +282,7 @@ def _construct_column_list(first_rank_filenames: List[str]) -> List[str]: root_line = re.match(r"\"(.*)\"\s(.*)\sG", file_data[0][0]) root_name = root_line.group(1).strip(" ") # convert it to a tuple to use it as a key in callpath_to_node dictionary - root_callpath = tuple([root_name]) + root_callpath: Tuple[str, ...] 
= tuple([root_name]) root_values = list(map(int, root_line.group(2).split(" ")[:-1])) # After first profile.0.0.0, only get Excl and Incl metric values @@ -343,7 +344,7 @@ def _construct_column_list(first_rank_filenames: List[str]) -> List[str]: # Example: ".TAU application => foo() => bar()" 31 0 155019 155019 0 GROUP="TAU_SAMPLE|TAU_CALLPATH" callpath_line_regex = re.match(r"\"(.*)\"\s(.*)\sG", line) # callpath: ".TAU application => foo() => bar()" - callpath = [ + callpath: Union[List[str], Tuple[str, ...]] = [ name.strip(" ") for name in callpath_line_regex.group(1).split("=>") ] @@ -435,9 +436,9 @@ def _construct_column_list(first_rank_filenames: List[str]) -> List[str]: # module leaf_name_file_module[2], # start line - leaf_line_numbers[0], + int(leaf_line_numbers[0]), # end line - leaf_line_numbers[1], + int(leaf_line_numbers[1]), rank, thread, ) diff --git a/hatchet/readers/timemory_reader.py b/hatchet/readers/timemory_reader.py index e4ff0d91..cf4ec115 100644 --- a/hatchet/readers/timemory_reader.py +++ b/hatchet/readers/timemory_reader.py @@ -9,7 +9,7 @@ import glob import re from io import TextIOWrapper -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast from hatchet.graphframe import GraphFrame from ..node import Node from ..graph import Graph @@ -22,7 +22,7 @@ class TimemoryReader: def __init__( self, - input: Union[str, TextIOWrapper, Dict], + timemory_input: Union[str, TextIOWrapper, Dict], select: Optional[List[str]] = None, **_kwargs, ) -> None: @@ -47,18 +47,20 @@ def __init__( identical name/file/line/etc. info but from different ranks are not combined """ - self.graph_dict = {"timemory": {}} - self.input = input - self.default_metric = None + self.graph_dict: Dict[str, Dict[str, Any]] = {"timemory": {}} + self.input = timemory_input + self.default_metric: Optional[str] = None self.timer = Timer() - self.metric_cols = [] - self.properties = {} + self.metric_cols: List[str] = [] + self.properties: Dict[str, Any] = {} self.include_tid = True self.include_nid = True self.multiple_ranks = False self.multiple_threads = False - self.callpath_to_node_dict = {} # (callpath, rank, thread): - self.callpath_to_node = {} # (callpath): + self.callpath_to_node_dict: Dict[ + Tuple, Dict[str, Any] + ] = {} # (callpath, rank, thread): + self.callpath_to_node: Dict[Tuple[str, ...], Node] = {} # (callpath): # the per_thread and per_rank settings make sure that # squashing doesn't collapse the threads/ranks @@ -271,7 +273,7 @@ def parse_node( _node_data: Dict[str, Any], _hparent: Node, _rank: int, - _parent_callpath: Tuple[str], + _parent_callpath: Tuple[str, ...], ) -> None: """Create callpath_to_node_dict for one node and then call the function recursively on all children. @@ -296,7 +298,7 @@ def parse_node( _prop = self.properties[_metric_name] _frame_attrs, _extra = get_name_line_file(_node_data["node"]["prefix"]) - callpath = _parent_callpath + (_frame_attrs["name"],) + callpath: Tuple[str, ...] = _parent_callpath + (_frame_attrs["name"],) # check if the node already exits. 
_hnode = self.callpath_to_node.get(callpath) @@ -315,7 +317,9 @@ def parse_node( # for the Frame(_keys) effectively circumvent Hatchet's # default behavior of combining similar thread/rank entries _tid_dict = _frame_attrs if self.per_thread else _extra - _rank_dict = _frame_attrs if self.per_rank else _extra + _rank_dict: Dict[str, Union[str, int]] = cast( + Dict[str, Union[str, int]], _frame_attrs if self.per_rank else _extra + ) # handle the rank _rank_dict["rank"] = collapse_ids(_rank, self.per_rank) @@ -324,10 +328,10 @@ def parse_node( self.include_nid = False # extract some relevant data - _tid_dict["thread"] = collapse_ids( - _node_data["node"]["tid"], self.per_thread + _tid_dict["thread"] = cast( + str, collapse_ids(_node_data["node"]["tid"], self.per_thread) ) - _extra["pid"] = collapse_ids(_node_data["node"]["pid"], False) + _extra["pid"] = cast(str, collapse_ids(_node_data["node"]["pid"], False)) _extra["count"] = _node_data["node"]["inclusive"]["entry"]["laps"] # check if there are multiple threads @@ -598,7 +602,7 @@ def read(self) -> GraphFrame: if isinstance(self.input, dict): self.graph_dict = self.input # check if the input is a directory and get '.tree.json' files if true. - elif os.path.isdir(self.input): + elif isinstance(self.input, str) and os.path.isdir(self.input): tree_files = glob.glob(self.input + "/*.tree.json") for file in tree_files: # read all files that end with .tree.json. diff --git a/hatchet/util/colormaps.py b/hatchet/util/colormaps.py index c1036605..fd3b9d49 100644 --- a/hatchet/util/colormaps.py +++ b/hatchet/util/colormaps.py @@ -127,7 +127,7 @@ def get_colors(self, colormap: str, invert_colormap: bool) -> List[str]: self.colors = self.Spectral.copy() else: raise ValueError( - self.colormap + colormap + " is an incorrect colormap. Select one BrBG, PiYg, PRGn," + " PuOr, RdBu, RdGy, RdYlBu, RdYlGn, or Spectral." ) diff --git a/hatchet/util/dot.py b/hatchet/util/dot.py index 8811de48..8f9797fa 100644 --- a/hatchet/util/dot.py +++ b/hatchet/util/dot.py @@ -33,7 +33,7 @@ def trees_to_dot( all_edges = "" # call to_dot for each root in the graph - visited = [] + visited: List[Node] = [] for root in roots: (nodes, edges) = to_dot( root, dataframe, metric, name, rank, thread, threshold, visited @@ -57,20 +57,22 @@ def to_dot( visited: List[Node], ) -> Tuple[str, str]: """Write to graphviz dot format.""" - colormap = matplotlib.cm.Reds + # Tell mypy to ignore Reds here because mpl.cm is + # dynamically generated. 
So, mypy cannot discover that + # Reds exists + colormap = matplotlib.cm.Reds # type: ignore[attr-defined] min_time = dataframe[metric].min() max_time = dataframe[metric].max() def add_nodes_and_edges(hnode: Node) -> Tuple[str, str]: # set dataframe index based on if rank is a part of the index + df_index: Union[Tuple[Node, int, int], Tuple[Node, int], Node] = hnode if "rank" in dataframe.index.names and "thread" in dataframe.index.names: df_index = (hnode, rank, thread) elif "rank" in dataframe.index.names: df_index = (hnode, rank) elif "thread" in dataframe.index.names: df_index = (hnode, thread) - else: - df_index = hnode node_time = dataframe.loc[df_index, metric] node_name = dataframe.loc[df_index, name] diff --git a/hatchet/util/executable.py b/hatchet/util/executable.py index 2e6b42d3..0ff7f2e5 100644 --- a/hatchet/util/executable.py +++ b/hatchet/util/executable.py @@ -14,9 +14,9 @@ def which(executable: str) -> Optional[str]: executable (str): executable to search for """ path = os.environ.get("PATH", "/usr/sbin:/usr/bin:/sbin:/bin") - path = path.split(os.pathsep) + split_path = path.split(os.pathsep) - for directory in path: + for directory in split_path: exe = os.path.join(directory, executable) if os.path.isfile(exe) and os.access(exe, os.X_OK): return exe diff --git a/hatchet/util/profiler.py b/hatchet/util/profiler.py index ac8d4984..b98c5b51 100644 --- a/hatchet/util/profiler.py +++ b/hatchet/util/profiler.py @@ -11,10 +11,7 @@ from datetime import datetime -try: - from StringIO import StringIO # python2 -except ImportError: - from io import StringIO # python3 +from io import StringIO # python3 import pstats diff --git a/hatchet/util/timer.py b/hatchet/util/timer.py index f40f3043..1720d3c1 100644 --- a/hatchet/util/timer.py +++ b/hatchet/util/timer.py @@ -7,15 +7,16 @@ from contextlib import contextmanager from datetime import datetime, timedelta from io import StringIO +from typing import Optional class Timer(object): """Simple phase timer with a context manager.""" def __init__(self) -> None: - self._phase = None - self._start_time = None - self._times = OrderedDict() + self._phase: Optional[str] = None + self._start_time: Optional[datetime] = None + self._times: OrderedDict = OrderedDict() def start_phase(self, phase: str) -> timedelta: now = datetime.now() diff --git a/hatchet/writers/dataframe_writer.py b/hatchet/writers/dataframe_writer.py index 3adc37a0..d75b0aca 100644 --- a/hatchet/writers/dataframe_writer.py +++ b/hatchet/writers/dataframe_writer.py @@ -7,18 +7,7 @@ from hatchet.graphframe import GraphFrame import pandas as pd -from abc import abstractmethod - -# TODO The ABC class was introduced in Python 3.4. -# When support for earlier versions is (eventually) dropped, -# this entire "try-except" block can be reduced to: -# from abc import ABC -try: - from abc import ABC -except ImportError: - from abc import ABCMeta - - ABC = ABCMeta("ABC", (object,), {"__slots__": ()}) +from abc import abstractmethod, ABC def _get_node_from_df_iloc(df: pd.DataFrame, ind: int) -> Node: diff --git a/pyproject.toml b/pyproject.toml index a1261f27..12214108 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,14 @@ authors = [ license = "MIT" [tool.mypy] -exclude = "hatchet/tests" +exclude = [ + "hatchet/tests", + "hatchet/vis", + "hatchet/external/roundtrip", + "setup.py", +] +strict_optional = false +disable_error_code = "import-untyped" [tool.ruff] line-length = 88
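The [tool.mypy] table above is discovered automatically when mypy is run from the repository root (for example, mypy hatchet/ or python -m mypy hatchet/), so the exclude list and strict_optional = false apply without extra command-line flags. As a rough, stand-alone illustration (not Hatchet code) of the annotate-then-narrow pattern the reader changes in this diff rely on:

from typing import Optional, Union, cast


class ExampleReader:
    """Hypothetical reader, only to illustrate the typing pattern."""

    def __init__(self, source: Union[str, bytes]) -> None:
        # Attributes get explicit annotations up front...
        self.source: Union[str, bytes] = source
        self.default_metric: Optional[str] = None

    def read(self) -> str:
        # ...and are narrowed before use: after the assert, mypy treats
        # self.source as str, while runtime behavior is unchanged.
        assert isinstance(self.source, str)
        self.default_metric = "time"
        return self.source.upper()


reader = ExampleReader("time (inc)")
print(reader.read())  # -> "TIME (INC)"
# cast() is the other narrowing tool used throughout the diff; it has no runtime effect.
metric = cast(str, reader.default_metric)
print(metric)  # -> "time"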