Commit 7ad292e
Fixes logic for handling string values of predicate_row_aggregator
ilumsden committed Nov 11, 2024
1 parent 86da99c commit 7ad292e
Showing 2 changed files with 64 additions and 39 deletions.
36 changes: 35 additions & 1 deletion hatchet/query/engine.py
@@ -4,6 +4,7 @@
# SPDX-License-Identifier: MIT

from itertools import groupby
from collections.abc import Iterable
import pandas as pd

from .errors import InvalidQueryFilter
@@ -14,6 +15,22 @@
from .string_dialect import parse_string_dialect


def _all_aggregator(pred_result):
    # A pd.Series is itself Iterable, so check for it first and use the
    # vectorized reduction; plain booleans pass through unchanged.
    if isinstance(pred_result, pd.Series):
        return pred_result.all()
    if isinstance(pred_result, Iterable):
        return all(pred_result)
    return pred_result


def _any_aggregator(pred_result):
    if isinstance(pred_result, pd.Series):
        return pred_result.any()
    if isinstance(pred_result, Iterable):
        return any(pred_result)
    return pred_result


class QueryEngine:
"""Class for applying queries to GraphFrames."""

@@ -40,6 +57,20 @@ def apply(self, query, graph, dframe, predicate_row_aggregator):
aggregator = predicate_row_aggregator
if predicate_row_aggregator is None:
aggregator = query.default_aggregator
elif predicate_row_aggregator == "all":
aggregator = _all_aggregator
elif predicate_row_aggregator == "any":
aggregator = _any_aggregator
elif predicate_row_aggregator == "off":
if isinstance(dframe.index, pd.MultiIndex):
raise ValueError(
"'predicate_row_aggregator' cannot be 'off' when the DataFrame has a row multi-index"
)
aggregator = None
elif not callable(predicate_row_aggregator):
raise ValueError(
"Invalid value provided for 'predicate_row_aggregator'"
)
self.reset_cache()
matches = []
visited = set()
@@ -84,7 +115,10 @@ def _cache_node(self, node, query, dframe, predicate_row_aggregator):
else:
row = dframe.loc[node]
predicate_result = filter_func(row)
if not isinstance(predicate_result, bool):
if (
not isinstance(predicate_result, bool)
and predicate_row_aggregator is not None
):
predicate_result = predicate_row_aggregator(predicate_result)
if predicate_result:
matches.append(i)
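For reference, here is a standalone sketch (pandas only) of how a per-row predicate result is collapsed into a single boolean per node. The function names all_rows/any_rows are illustrative stand-ins for the _all_aggregator/_any_aggregator helpers added above, not part of this commit:

from collections.abc import Iterable

import pandas as pd


def all_rows(pred_result):
    # Check pd.Series before the generic Iterable case (a Series is also iterable),
    # reduce other iterables with built-in all(), and pass scalars through.
    if isinstance(pred_result, pd.Series):
        return pred_result.all()
    if isinstance(pred_result, Iterable):
        return all(pred_result)
    return pred_result


def any_rows(pred_result):
    if isinstance(pred_result, pd.Series):
        return pred_result.any()
    if isinstance(pred_result, Iterable):
        return any(pred_result)
    return pred_result


per_rank = pd.Series([True, False, True])  # e.g., one predicate result per MPI rank
print(all_rows(per_rank))  # False: not every row satisfies the predicate
print(any_rows(per_rank))  # True: at least one row does
print(all_rows(True))      # True: scalar results are returned unchanged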
67 changes: 29 additions & 38 deletions hatchet/tests/query.py
@@ -25,7 +25,6 @@
ExclusiveDisjunctionQuery,
NegationQuery,
)
from hatchet.query.errors import MultiIndexModeMismatch


def test_construct_object_dialect():
@@ -512,9 +511,9 @@ def test_apply_indices(calc_pi_hpct_db):
],
]
matches = list(set().union(*matches))
query = ObjectQuery(path, predicate_row_aggregator="all")
query = ObjectQuery(path)
engine = QueryEngine()
assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches)
assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")) == sorted(matches)

gf.drop_index_levels()
assert engine.apply(query, gf.graph, gf.dataframe) == matches
@@ -588,7 +587,7 @@ def test_object_dialect_depth_index_levels(calc_pi_hpct_db):
gf = GraphFrame.from_hpctoolkit(str(calc_pi_hpct_db))
root = gf.graph.roots[0]

query = ObjectQuery([("*", {"depth": "<= 2"})], predicate_row_aggregator="all")
query = ObjectQuery([("*", {"depth": "<= 2"})])
engine = QueryEngine()
matches = [
[root, root.children[0], root.children[0].children[0]],
@@ -599,22 +598,22 @@
[root.children[0].children[1]],
]
matches = list(set().union(*matches))
assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches)
assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")) == sorted(matches)

query = ObjectQuery([("*", {"depth": 0})], predicate_row_aggregator="all")
query = ObjectQuery([("*", {"depth": 0})])
matches = [root]
assert engine.apply(query, gf.graph, gf.dataframe) == matches
assert engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") == matches

with pytest.raises(InvalidQueryFilter):
query = ObjectQuery([{"depth": "hello"}], predicate_row_aggregator="all")
engine.apply(query, gf.graph, gf.dataframe)
query = ObjectQuery([{"depth": "hello"}])
engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")


def test_object_dialect_node_id_index_levels(calc_pi_hpct_db):
gf = GraphFrame.from_hpctoolkit(str(calc_pi_hpct_db))
root = gf.graph.roots[0]

query = ObjectQuery([("*", {"node_id": "<= 2"})], predicate_row_aggregator="all")
query = ObjectQuery([("*", {"node_id": "<= 2"})])
engine = QueryEngine()
matches = [
[root, root.children[0]],
@@ -624,15 +623,15 @@ def test_object_dialect_node_id_index_levels(calc_pi_hpct_db):
[root.children[0].children[0]],
]
matches = list(set().union(*matches))
assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches)
assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")) == sorted(matches)

query = ObjectQuery([("*", {"node_id": 0})], predicate_row_aggregator="all")
query = ObjectQuery([("*", {"node_id": 0})])
matches = [root]
assert engine.apply(query, gf.graph, gf.dataframe) == matches
assert engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") == matches

with pytest.raises(InvalidQueryFilter):
query = ObjectQuery([{"node_id": "hello"}], predicate_row_aggregator="all")
engine.apply(query, gf.graph, gf.dataframe)
query = ObjectQuery([{"node_id": "hello"}])
engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")


def test_object_dialect_multi_condition_one_attribute(mock_graph_literal):
@@ -1283,7 +1282,7 @@ def test_object_dialect_all_mode(tau_profile_dir):
gf = GraphFrame.from_tau(tau_profile_dir)
engine = QueryEngine()
query = ObjectQuery(
[".", ("+", {"time (inc)": ">= 17983.0"})], predicate_row_aggregator="all"
[".", ("+", {"time (inc)": ">= 17983.0"})]
)
roots = gf.graph.roots
matches = [
@@ -1292,7 +1291,7 @@
roots[0].children[6].children[1],
roots[0].children[0],
]
assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches)
assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")) == sorted(matches)


def test_string_dialect_all_mode(tau_profile_dir):
@@ -1301,8 +1300,7 @@
query = StringQuery(
"""MATCH (".")->("+", p)
WHERE p."time (inc)" >= 17983.0
""",
predicate_row_aggregator="all",
"""
)
roots = gf.graph.roots
matches = [
@@ -1311,19 +1309,19 @@
roots[0].children[6].children[1],
roots[0].children[0],
]
assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches)
assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")) == sorted(matches)


def test_object_dialect_any_mode(tau_profile_dir):
gf = GraphFrame.from_tau(tau_profile_dir)
engine = QueryEngine()
query = ObjectQuery([{"time": "< 24.0"}], predicate_row_aggregator="any")
query = ObjectQuery([{"time": "< 24.0"}])
roots = gf.graph.roots
matches = [
roots[0].children[2],
roots[0].children[6].children[3],
]
assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches)
assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="any")) == sorted(matches)


def test_string_dialect_any_mode(tau_profile_dir):
@@ -1332,31 +1330,24 @@
query = StringQuery(
"""MATCH (".", p)
WHERE p."time" < 24.0
""",
predicate_row_aggregator="any",
"""
)
roots = gf.graph.roots
matches = [
roots[0].children[2],
roots[0].children[6].children[3],
]
assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches)
assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="any")) == sorted(matches)


def test_predicate_row_aggregator_assertion_error(tau_profile_dir):
with pytest.raises(AssertionError):
_ = ObjectQuery([".", ("*", {"name": "test"})], predicate_row_aggregator="foo")
with pytest.raises(AssertionError):
_ = StringQuery(
""" MATCH (".")->("*", p)
WHERE p."name" = "test"
""",
predicate_row_aggregator="foo",
)
gf = GraphFrame.from_tau(tau_profile_dir)
engine = QueryEngine()
query = ObjectQuery([".", ("*", {"name": "test"})])
with pytest.raises(ValueError):
engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="foo")
query = ObjectQuery(
[".", ("*", {"time (inc)": "> 17983.0"})], predicate_row_aggregator="off"
[".", ("*", {"time (inc)": "> 17983.0"})]
)
engine = QueryEngine()
with pytest.raises(MultiIndexModeMismatch):
engine.apply(query, gf.graph, gf.dataframe)
with pytest.raises(ValueError):
engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="off")
