Skip to content

Commit

Permalink
feat: Support fanning out parent record into multiple child syncs
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed Jan 10, 2024
1 parent 2fe3e8d commit c99f492
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 6 deletions.
29 changes: 23 additions & 6 deletions singer_sdk/streams/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,17 +1026,18 @@ def _process_record(
partition_context: The partition context.
"""
partition_context = partition_context or {}
child_context = copy.copy(
self.get_child_context(record=record, context=child_context),
)
for key, val in partition_context.items():
# Add state context to records if not already present
if key not in record:
record[key] = val

# Sync children, except when primary mapper filters out the record
if self.stream_maps[0].get_filter_result(record):
self._sync_children(child_context)
for context in self.generate_child_contexts(
record=record,
context=child_context,
):
# Sync children, except when primary mapper filters out the record
if self.stream_maps[0].get_filter_result(record):
self._sync_children(copy.copy(context))

def _sync_records( # noqa: C901
self,
Expand Down Expand Up @@ -1289,6 +1290,22 @@ def get_child_context(self, record: dict, context: dict | None) -> dict | None:

return context or record

def generate_child_contexts(
self,
record: dict,
context: dict | None,
) -> t.Iterable[dict | None]:
"""Generate child contexts.
Args:
record: Individual record in the stream.
context: Stream partition or context dictionary.
Yields:
A child context for each child stream.
"""
yield self.get_child_context(record=record, context=context)

# Abstract Methods

@abc.abstractmethod
Expand Down
99 changes: 99 additions & 0 deletions tests/core/test_parent_child.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,102 @@ def test_child_deselected_parent(tap_with_deselected_parent: MyTap):
assert all(msg["type"] == SingerMessageType.RECORD for msg in child_record_messages)
assert all(msg["stream"] == child_stream.name for msg in child_record_messages)
assert all("pid" in msg["record"] for msg in child_record_messages)


def test_one_parent_many_children(tap: MyTap):
"""Test tap output with parent stream deselected."""

class ParentMany(Stream):
"""A parent stream."""

name = "parent_many"
schema: t.ClassVar[dict] = {
"type": "object",
"properties": {
"id": {"type": "integer"},
"children": {"type": "array", "items": {"type": "integer"}},
},
}

def get_records(
self,
context: dict | None, # noqa: ARG002
) -> t.Iterable[dict | tuple[dict, dict | None]]:
yield {"id": "1", "children": [1, 2, 3]}

def generate_child_contexts(
self,
record: dict,
context: dict | None, # noqa: ARG002
) -> t.Iterable[dict | None]:
for child_id in record["children"]:
yield {"child_id": child_id, "pid": record["id"]}

class ChildMany(Stream):
"""A child stream."""

name = "child_many"
schema: t.ClassVar[dict] = {
"type": "object",
"properties": {
"id": {"type": "integer"},
"pid": {"type": "integer"},
},
}
parent_stream_type = ParentMany

def get_records(self, context: dict | None):
"""Get dummy records."""
yield {
"id": context["child_id"],
"composite_id": f"{context['pid']}-{context['child_id']}",
}

class MyTapMany(Tap):
"""A tap with streams having a parent-child relationship."""

name = "my-tap-many"

def discover_streams(self):
"""Discover streams."""
return [
ParentMany(self),
ChildMany(self),
]

tap = MyTapMany()
parent_stream = tap.streams["parent_many"]
child_stream = tap.streams["child_many"]

messages = _get_messages(tap)

# Parent schema is emitted
assert messages[1]
assert messages[1]["type"] == SingerMessageType.SCHEMA
assert messages[1]["stream"] == parent_stream.name
assert messages[1]["schema"] == parent_stream.schema

# Child schemas are emitted
schema_messages = messages[2:9:3]
assert schema_messages
assert all(msg["type"] == SingerMessageType.SCHEMA for msg in schema_messages)
assert all(msg["stream"] == child_stream.name for msg in schema_messages)
assert all(msg["schema"] == child_stream.schema for msg in schema_messages)

# Child records are emitted
child_record_messages = messages[3:10:3]
assert child_record_messages
assert all(msg["type"] == SingerMessageType.RECORD for msg in child_record_messages)
assert all(msg["stream"] == child_stream.name for msg in child_record_messages)
assert all("pid" in msg["record"] for msg in child_record_messages)

# State messages are emitted
state_messages = messages[4:11:3]
assert state_messages
assert all(msg["type"] == SingerMessageType.STATE for msg in state_messages)

# Parent record is emitted
assert messages[11]
assert messages[11]["type"] == SingerMessageType.RECORD

raise AssertionError

0 comments on commit c99f492

Please sign in to comment.