Skip to content

Commit

Permalink
Merge pull request #230 from topoteretes/COG-533-pydantic-unit-tests
Browse files Browse the repository at this point in the history
Cog 533 pydantic unit tests
  • Loading branch information
0xideas authored Nov 19, 2024
2 parents 1dd07cd + ab1328d commit c3757cc
Show file tree
Hide file tree
Showing 11 changed files with 471 additions and 113 deletions.
149 changes: 82 additions & 67 deletions cognee/modules/graph/utils/get_graph_from_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
from datetime import datetime, timezone

from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model

def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):

def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=None):

if not added_nodes:
added_nodes = {}
if not added_edges:
added_edges = {}

nodes = []
edges = []

Expand All @@ -12,87 +20,94 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
for field_name, field_value in data_point:
if field_name == "_metadata":
continue

if isinstance(field_value, DataPoint):
elif isinstance(field_value, DataPoint):
excluded_properties.add(field_name)

property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)

for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True

for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]

if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[str(edge_key)] = True

for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name

if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
}))
added_edges[str(edge_key)] = True
continue

if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
data_point,
field_name,
field_value,
nodes,
edges,
added_nodes,
added_edges,
)

elif (
isinstance(field_value, list)
and len(field_value) > 0
and isinstance(field_value[0], DataPoint)
):
excluded_properties.add(field_name)

for item in field_value:
property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)

for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True

for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]

if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[edge_key] = True

for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name

if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
"metadata": {
"type": "list"
},
}))
added_edges[edge_key] = True
continue

data_point_properties[field_name] = field_value
n_edges_before = len(edges)
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
data_point, field_name, item, nodes, edges, added_nodes, added_edges
)
edges = edges[:n_edges_before] + [
(*edge[:3], {**edge[3], "metadata": {"type": "list"}})
for edge in edges[n_edges_before:]
]
else:
data_point_properties[field_name] = field_value

SimpleDataPointModel = copy_model(
type(data_point),
include_fields = {
include_fields={
"_metadata": (dict, data_point._metadata),
},
exclude_fields = excluded_properties,
exclude_fields=excluded_properties,
)

if include_root:
nodes.append(SimpleDataPointModel(**data_point_properties))
nodes.append(SimpleDataPointModel(**data_point_properties))

return nodes, edges


def add_nodes_and_edges(
    data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
):
    """Merge the graph of one nested value into the parent's node/edge lists.

    The graph of ``field_value`` is built with ``get_graph_from_model``; every
    node and edge from it that has not been recorded yet is appended to
    ``nodes`` / ``edges``. Afterwards an edge labelled ``field_name`` is added
    from ``data_point`` to each node returned by ``get_own_properties`` for
    that sub-graph.

    ``nodes``, ``edges``, ``added_nodes`` and ``added_edges`` are mutated in
    place and also returned, in that order, as a 4-tuple.
    """
    # The recursive call gets copies of the bookkeeping dicts, so whatever it
    # records internally never leaks into the caller's dicts.
    child_nodes, child_edges = get_graph_from_model(
        field_value, dict(added_nodes), dict(added_edges)
    )

    for child_node in child_nodes:
        node_key = str(child_node.id)
        if node_key in added_nodes:
            continue
        nodes.append(child_node)
        added_nodes[node_key] = True

    for child_edge in child_edges:
        # Edge key = source id + target id + relationship name, concatenated.
        child_edge_key = str(child_edge[0]) + str(child_edge[1]) + child_edge[2]
        if child_edge_key in added_edges:
            continue
        edges.append(child_edge)
        added_edges[child_edge_key] = True

    for owned_node in get_own_properties(child_nodes, child_edges):
        link_key = str(data_point.id) + str(owned_node.id) + field_name
        if link_key in added_edges:
            continue

        link_properties = {
            "source_node_id": data_point.id,
            "target_node_id": owned_node.id,
            "relationship_name": field_name,
            "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
        }
        edges.append((data_point.id, owned_node.id, field_name, link_properties))
        added_edges[link_key] = True

    return nodes, edges, added_nodes, added_edges


def get_own_properties(property_nodes, property_edges):
own_properties = []

Expand Down
44 changes: 28 additions & 16 deletions cognee/modules/graph/utils/get_model_instance_from_graph.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,41 @@
from typing import Callable

from pydantic_core import PydanticUndefined

from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model


def get_model_instance_from_graph(
    nodes: list[DataPoint],
    edges: list[tuple[str, str, str, dict[str, str]]],
    entity_id: str,
):
    """Rebuild a nested model instance from flat graph nodes and edges.

    Every edge attaches its target node to its source node under the field
    named by the edge label. The source node is replaced in the lookup table
    by an instance of a copied model type (via ``copy_model``) that declares
    that extra field.

    Args:
        nodes: The graph nodes, looked up by their ``id``.
        edges: ``(source_id, target_id, label, properties)`` tuples. The
            properties element may be omitted. If
            ``properties["metadata"]["type"] == "list"`` the target is
            appended to a list-valued field; otherwise it is stored as a
            single value.
        entity_id: Id of the node to return once all edges are applied.

    Returns:
        The (possibly rebuilt) node registered under ``entity_id``.

    Raises:
        KeyError: If an edge or ``entity_id`` references an unknown node id.
    """
    node_map = {node.id: node for node in nodes}

    for edge in edges:
        # Tolerate 3-element edges that carry no properties dict.
        source_node_id, target_node_id, edge_label, *rest = edge
        edge_properties = rest[0] if rest else {}

        source_node = node_map[source_node_id]
        target_node = node_map[target_node_id]
        edge_metadata = edge_properties.get("metadata", {})
        edge_type = edge_metadata.get("type", "default")

        if edge_type == "list":
            NewModel = copy_model(
                type(source_node),
                {edge_label: (list[type(target_node)], PydanticUndefined)},
            )
            source_node_dict = source_node.model_dump()
            # Accumulate: several "list" edges with the same label extend the
            # same field instead of overwriting it.
            source_node_edge_label_values = source_node_dict.get(edge_label, [])
            source_node_dict[edge_label] = source_node_edge_label_values + [target_node]

            node_map[source_node_id] = NewModel(**source_node_dict)
        else:
            NewModel = copy_model(
                type(source_node), {edge_label: (type(target_node), PydanticUndefined)}
            )

            # The rebuilt model is the *source* node with the relation set, so
            # it must be stored under the source id; storing it under the
            # target id would clobber the target node and leave the source
            # without the relation.
            node_map[source_node_id] = NewModel(
                **source_node.model_dump(), **{edge_label: target_node}
            )

    return node_map[entity_id]
4 changes: 3 additions & 1 deletion cognee/modules/storage/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def copy_model(model: DataPoint, include_fields: dict = {}, exclude_fields: list
**include_fields
}

return create_model(model.__name__, **final_fields)
model = create_model(model.__name__, **final_fields)
model.model_rebuild()
return model

def get_own_properties(data_point: DataPoint):
properties = {}
Expand Down
18 changes: 3 additions & 15 deletions cognee/tests/unit/interfaces/graph/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from datetime import datetime, timezone
from enum import Enum
from typing import Optional

import pytest

from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import (
get_graph_from_model,
get_model_instance_from_graph,
)


class CarTypeName(Enum):
Expand Down Expand Up @@ -47,8 +42,8 @@ class Person(DataPoint):
_metadata: dict = dict(index_fields=["name"])


@pytest.fixture(scope="session")
def graph_outputs():
@pytest.fixture(scope="function")
def boris():
boris = Person(
id="boris",
name="Boris",
Expand All @@ -70,11 +65,4 @@ def graph_outputs():
"expires_on": "2025-11-06",
},
)
nodes, edges = get_graph_from_model(boris)

car, person = nodes[0], nodes[1]
edge = edges[0]

parsed_person = get_model_instance_from_graph(nodes, edges, "boris")

return (car, person, edge, parsed_person)
return boris
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import warnings

import pytest

from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import (
PERSON_NAMES,
count_society,
create_organization_recursive,
)


@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
def test_society_nodes_and_edges(recursive_depth):
    """Check node and edge counts of the graph built from a recursive society.

    Every DataPoint instance in the society must appear as a node, and the
    number of edges must be one less than the number of nodes.
    """
    import sys

    # Report the test as skipped (rather than silently passing with a
    # warning) on interpreters where it cannot run.
    if sys.version_info < (3, 11):
        pytest.skip(
            "The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11"
        )

    society = create_organization_recursive(
        "society", "Society", PERSON_NAMES, recursive_depth
    )

    n_organizations, n_persons = count_society(society)
    society_counts_total = n_organizations + n_persons

    nodes, edges = get_graph_from_model(society)

    assert (
        len(nodes) == society_counts_total
    ), f"{society_counts_total = } != {len(nodes) = }, not all DataPoint instances were found"

    assert len(edges) == (
        len(nodes) - 1
    ), f"{(len(nodes) - 1) = } != {len(edges) = }, there have to be n_nodes - 1 edges, as each node has exactly one parent node, except for the root node"
Loading

0 comments on commit c3757cc

Please sign in to comment.