Skip to content

Commit

Permalink
more scaffolding in pipeline classes, some utility classmethods
Browse files Browse the repository at this point in the history
  • Loading branch information
sneakers-the-rat committed Nov 5, 2024
1 parent a4f9189 commit c1e81f5
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 16 deletions.
12 changes: 11 additions & 1 deletion miniscope_io/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
"""


class ConfigurationError(ValueError):
    """
    Root of the configuration-error hierarchy.

    Raised (or subclassed) when user- or file-supplied configuration is
    invalid. Subclasses :class:`ValueError` so existing ``except ValueError``
    handlers continue to catch configuration problems.
    """


class InvalidSDException(Exception):
"""
Raised when :class:`.io.SDCard` is used with a drive that doesn't have the
Expand Down Expand Up @@ -52,7 +56,13 @@ class DeviceOpenError(DeviceError):
"""


class DeviceConfigurationError(DeviceError, ConfigurationError):
    """
    Raised when configuring a device fails.

    Participates in both the device-error and configuration-error
    hierarchies, so callers may catch it as either.
    """


class ConfigurationMismatchError(ConfigurationError):
    """
    Raised when the fields in some config model do not line up with the
    fields of the model it is configuring.
    """
166 changes: 151 additions & 15 deletions miniscope_io/models/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
"""

from abc import abstractmethod
from typing import ClassVar, Generic, TypeVar, Union
from typing import ClassVar, Final, Generic, Self, TypeVar, Union, final

from pydantic import Field

from miniscope_io.models.models import PipelineModel
from miniscope_io.exceptions import ConfigurationMismatchError
from miniscope_io.models.models import MiniscopeConfig, PipelineModel

T = TypeVar("T")
"""
Expand All @@ -19,15 +20,94 @@
"""


class Node(PipelineModel):
class NodeConfig(MiniscopeConfig):
    """
    Configuration for a single processing node.

    Declares the node's identity and how it is wired to other nodes;
    consumed by :meth:`.Node.from_config` and :func:`.connect_nodes`.
    """

    type_: str = Field(..., alias="type")
    """
    Shortname of the type of node this configuration is for.
    Subclasses should override this with a default.
    """

    # NOTE(review): ids must be unique across a pipeline config - the
    # `inputs`/`outputs` lists below are resolved against these ids when
    # nodes are connected.
    id: str
    """The unique identifier of the node"""
    inputs: list[str] = Field(default_factory=list)
    """List of Node IDs to be used as input"""
    outputs: list[str] = Field(default_factory=list)
    """List of Node IDs to be used as output"""


class PipelineConfig(MiniscopeConfig):
    """
    Configuration for the nodes within a pipeline
    """

    # NOTE(review): the mapping key appears to double as the node id when
    # resolving `inputs`/`outputs` references - presumably each key matches
    # its NodeConfig.id; confirm with callers.
    nodes: dict[str, NodeConfig] = Field(default_factory=dict)
    """The nodes that this pipeline configures"""


class Node(PipelineModel, Generic[T, U]):
    """
    A node within a processing pipeline.

    Generic over its input type ``T`` and output type ``U``. Concrete
    subclasses set :attr:`.type_` so that :meth:`.node_types` can match
    configurations to node classes, and implement :meth:`.start` and
    :meth:`.stop`.
    """

    type_: ClassVar[str]
    """
    Shortname for this type of node to match configs to node types
    """

    id: str
    """Unique identifier of the node"""

    config: Union[NodeConfig, None] = None
    """
    The configuration this node was created from, if any.

    The default is ``None`` rather than ``NodeConfig()``: ``NodeConfig``
    has required fields (``type``/``id``), so instantiating it without
    arguments raises a validation error at class-definition time.
    """

    input_type: ClassVar[type[T]]
    """Type of data this node accepts as input"""
    inputs: dict[str, Union["Source", "ProcessingNode"]] = Field(default_factory=dict)
    """Mapping of node ID to instantiated upstream node"""
    output_type: ClassVar[type[U]]
    """Type of data this node emits as output"""
    outputs: dict[str, Union["Sink", "ProcessingNode"]] = Field(default_factory=dict)
    """Mapping of node ID to instantiated downstream node"""

    @abstractmethod
    def start(self) -> None:
        """
        Start producing, processing, or receiving data
        """

    @abstractmethod
    def stop(self) -> None:
        """
        Stop producing, processing, or receiving data
        """

    @classmethod
    def from_config(cls, config: NodeConfig) -> Self:
        """
        Create a node from its config
        """
        return cls(id=config.id, config=config)

    @classmethod
    @final
    def node_types(cls) -> dict[str, type["Node"]]:
        """
        Map of all imported :attr:`.Node.type_` names to node classes.

        Walks the full subclass tree. Intermediate classes that declare
        ``type_`` only as a :class:`~typing.ClassVar` annotation without
        assigning a value (e.g. abstract bases like ``Source``/``Sink``)
        are traversed but not registered - previously, reading ``type_``
        on them raised :class:`AttributeError`.

        Raises:
            ValueError: if two node classes share the same ``type_``
        """
        node_types: dict[str, type["Node"]] = {}
        to_check = cls.__subclasses__()
        while to_check:
            node = to_check.pop()
            # always recurse, even through unregistered intermediate bases
            to_check.extend(node.__subclasses__())
            type_name = getattr(node, "type_", None)
            if type_name is None:
                # annotation-only `type_` (abstract intermediate) - skip
                continue
            if type_name in node_types:
                raise ValueError(
                    f"Repeated node type_ identifier: {type_name}, found in:\n"
                    f"- {node_types[type_name]}\n- {node}"
                )
            node_types[type_name] = node
        return node_types


class Source(Node, Generic[U]):
"""A source of data in a processing pipeline"""

output_type: ClassVar[type[U]]
outputs: list[Union["Sink", "ProcessingNode"]] = Field(default_factory=list)
inputs: Final[None] = None
input_type: ClassVar[None] = None

@abstractmethod
def process(self) -> U:
Expand All @@ -47,8 +127,8 @@ def process(self) -> U:
class Sink(Node, Generic[T]):
"""A sink of data in a processing pipeline"""

input_type: ClassVar[type[T]]
inputs: list[Union["Source", "ProcessingNode"]] = Field(default_factory=list)
output_type: ClassVar[None] = None
outputs: Final[None] = None

@abstractmethod
def process(self, data: T) -> None:
Expand All @@ -68,11 +148,6 @@ class ProcessingNode(Node, Generic[T, U]):
An intermediate processing node that transforms some input to output
"""

input_type: ClassVar[type[T]]
inputs: list[Union["Source", "ProcessingNode"]] = Field(default_factory=list)
output_type: ClassVar[type[U]]
outputs: list[Union["Sink", "ProcessingNode"]] = Field(default_factory=list)

@abstractmethod
def process(self, data: T) -> U:
"""
Expand All @@ -92,9 +167,25 @@ class Pipeline(PipelineModel):
A graph of nodes transforming some input source(s) to some output sink(s)
"""

sources: list["Source"] = Field(default_factory=list)
processing_nodes: list["ProcessingNode"] = Field(default_factory=list)
sinks: list["Sink"] = Field(default_factory=list)
nodes: dict[str, Node] = Field(default_factory=dict)
"""
Dictionary mapping all nodes from their ID to the instantiated node.
"""

@property
def sources(self) -> dict[str, "Source"]:
"""All :class:`.Source` nodes in the processing graph"""
return {k: v for k, v in self.nodes.items() if isinstance(v, Source)}

@property
def processing_nodes(self) -> dict[str, "ProcessingNode"]:
"""All :class:`.ProcessingNode` s in the processing graph"""
return {k: v for k, v in self.nodes.items() if isinstance(v, ProcessingNode)}

@property
def sinks(self) -> dict[str, "Sink"]:
"""All :class:`.Sink` nodes in the processing graph"""
return {k: v for k, v in self.nodes.items() if isinstance(v, Sink)}

@abstractmethod
def process(self) -> None:
Expand All @@ -106,3 +197,48 @@ def process(self) -> None:
result/status object, as any data intended to be received/processed by
downstream objects should be accessed via a :class:`.Sink` .
"""

    @abstractmethod
    def start(self) -> None:
        """
        Start processing data with the pipeline graph.

        NOTE(review): presumably implementations start each contained
        node (:meth:`.Node.start`) - confirm in concrete pipelines.
        """

    @abstractmethod
    def stop(self) -> None:
        """
        Stop processing data with the pipeline graph.

        NOTE(review): presumably implementations stop each contained
        node (:meth:`.Node.stop`) - confirm in concrete pipelines.
        """

@classmethod
def from_config(cls, config: PipelineConfig) -> Self:
"""
Instantiate a pipeline model from its configuration
"""
types = Node.node_types()

nodes = {k: types[v.type_].from_config(v) for k, v in config.nodes.items()}
nodes = connect_nodes(nodes)
return cls(nodes=nodes)


def connect_nodes(nodes: "dict[str, Node]") -> "dict[str, Node]":
    """
    Provide each instantiated node with references to the nodes named in
    its configuration.

    Mutates the nodes in place, filling the ``inputs``/``outputs`` mappings
    from the id lists in each ``node.config``, and returns the same
    dictionary.

    Raises:
        ConfigurationMismatchError: if a config declares inputs/outputs for
            a node type that allows none, or references an unknown node id.
    """
    for node in nodes.values():
        # Node types that allow no inputs pin ``inputs`` to None
        # (e.g. ``inputs: Final[None] = None`` on sources).
        if node.config.inputs and node.inputs is None:
            raise ConfigurationMismatchError(
                "inputs found in node configuration, but node type allows no inputs!\n"
                f"node: {node.model_dump()}"
            )
        # BUGFIX: was ``not hasattr(node, "outputs")``, which never fires -
        # sinks *have* an ``outputs`` attribute, set to None. Check for None,
        # symmetric with the inputs guard above.
        if node.config.outputs and node.outputs is None:
            raise ConfigurationMismatchError(
                "outputs found in node configuration, but node type allows no outputs!\n"
                f"node: {node.model_dump()}"
            )
        # Robustness: surface dangling references as a configuration error
        # instead of a bare KeyError from the update below.
        missing = [
            ref
            for ref in (*node.config.inputs, *node.config.outputs)
            if ref not in nodes
        ]
        if missing:
            raise ConfigurationMismatchError(
                f"Node configuration references nonexistent node ids: {missing}\n"
                f"node: {node.model_dump()}"
            )
        # BUGFIX: guard against None before updating - sources/sinks pin one
        # side to None, so the unconditional ``.update`` raised AttributeError
        # even when the corresponding config list was empty.
        if node.inputs is not None:
            node.inputs.update({ref: nodes[ref] for ref in node.config.inputs})
        if node.outputs is not None:
            node.outputs.update({ref: nodes[ref] for ref in node.config.outputs})
    return nodes

0 comments on commit c1e81f5

Please sign in to comment.