diff --git a/core/filters.py b/core/filters.py index c151fec..6472ec7 100644 --- a/core/filters.py +++ b/core/filters.py @@ -4,7 +4,7 @@ from strawberry import auto from typing import Optional from strawberry_django.filters import FilterLookup - +import strawberry_django print("Test") @@ -114,6 +114,33 @@ class ViewFilter: is_global: auto provenance: ProvenanceFilter | None +@strawberry.django.filter(models.PixelLabel) +class PixelLabelFilter: + value: float | None = None + view: strawberry.ID | None = None + entity_kind: strawberry.ID | None = None + entity: strawberry.ID | None = None + + def filter_value(self, queryset, info): + if self.value is None: + return queryset + return queryset.filter(value=self.value) + + def filter_view(self, queryset, info): + if self.view is None: + return queryset + return queryset.filter(view_id=self.view) + + def filter_entity_kind(self, queryset, info): + if self.entity_kind is None: + return queryset + return queryset.filter(entity_entity_kind_id=self.entity_kind) + + def filter_entity(self, queryset, info): + if self.entity is None: + return queryset + return queryset.filter(entity_id=self.entity) + @strawberry.django.filter(models.AffineTransformationView) class AffineTransformationViewFilter(ViewFilter): @@ -132,6 +159,9 @@ class TimepointViewFilter(ViewFilter): ms_since_start: auto index_since_start: auto +@strawberry.django.filter(models.PixelView) +class PixelViewFilter(ViewFilter): + pass @strawberry.django.filter(models.OpticsView) class OpticsViewFilter(ViewFilter): @@ -179,7 +209,7 @@ def filter_ids(self, queryset, info): return queryset.filter(id__in=self.ids) -@strawberry.django.filter(models.Image) +@strawberry_django.filter(models.Image) class ImageFilter: name: Optional[FilterLookup[str]] ids: list[strawberry.ID] | None @@ -187,6 +217,7 @@ class ImageFilter: dataset: DatasetFilter | None transformation_views: AffineTransformationViewFilter | None timepoint_views: TimepointViewFilter | None + not_derived: bool | None = None provenance: ProvenanceFilter | None @@ -194,6 +225,12 @@ def filter_ids(self, queryset, info): if self.ids is None: return queryset return queryset.filter(id__in=self.ids) + + def filter_not_derived(self, queryset, info): + print("Filtering not derived") + if self.not_derived is None: + return queryset + return queryset.filter(origins=None) @strawberry.django.filter(models.ROI) @@ -377,6 +414,11 @@ class SpecimenFilter(IDFilterMixin, SearchFilterMixin): class ProtocolFilter(IDFilterMixin, SearchFilterMixin): id: auto +@strawberry.django.filter(models.ProtocolStep) +class ProtocolStepFilter(IDFilterMixin, SearchFilterMixin): + id: auto + + @strawberry.django.filter(models.ProtocolStepMapping) class ProtocolStepMappingFilter(IDFilterMixin, SearchFilterMixin): id: auto diff --git a/core/migrations/0003_historicalprotocolstep_plate_children_and_more.py b/core/migrations/0003_historicalprotocolstep_plate_children_and_more.py new file mode 100644 index 0000000..69d1e90 --- /dev/null +++ b/core/migrations/0003_historicalprotocolstep_plate_children_and_more.py @@ -0,0 +1,32 @@ +# Generated by Django 4.2.8 on 2024-08-16 13:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("core", "0002_alter_protocolstepmapping_t"), + ] + + operations = [ + migrations.AddField( + model_name="historicalprotocolstep", + name="plate_children", + field=models.JSONField( + blank=True, + default=list, + help_text="The children of the slate", + null=True, + ), + ), + migrations.AddField( + model_name="protocolstep", + name="plate_children", + field=models.JSONField( + blank=True, + default=list, + help_text="The children of the slate", + null=True, + ), + ), + ] diff --git a/core/migrations/0004_remove_entity_parent_remove_historicalentity_parent.py b/core/migrations/0004_remove_entity_parent_remove_historicalentity_parent.py new file mode 100644 index 0000000..fd629f5 --- /dev/null +++ b/core/migrations/0004_remove_entity_parent_remove_historicalentity_parent.py @@ -0,0 +1,20 @@ +# Generated by Django 4.2.8 on 2024-08-18 15:16 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("core", "0003_historicalprotocolstep_plate_children_and_more"), + ] + + operations = [ + migrations.RemoveField( + model_name="entity", + name="parent", + ), + migrations.RemoveField( + model_name="historicalentity", + name="parent", + ), + ] diff --git a/core/migrations/0005_entityrelation_only_one_relation_per_kind.py b/core/migrations/0005_entityrelation_only_one_relation_per_kind.py new file mode 100644 index 0000000..bfb456f --- /dev/null +++ b/core/migrations/0005_entityrelation_only_one_relation_per_kind.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.8 on 2024-08-18 19:44 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("core", "0004_remove_entity_parent_remove_historicalentity_parent"), + ] + + operations = [ + migrations.AddConstraint( + model_name="entityrelation", + constraint=models.UniqueConstraint( + fields=("left", "kind", "right"), name="only_one_relation_per_kind" + ), + ), + ] diff --git a/core/models.py b/core/models.py index 2b02647..c3b8fae 100644 --- a/core/models.py +++ b/core/models.py @@ -1,3 +1,4 @@ +import uuid from django.db import models from django.contrib.auth import get_user_model from django.forms import FileField @@ -342,6 +343,10 @@ class ProtocolStep(models.Model): related_name="entities", help_text="The reagents that were used in this step (you can specifiy properties of the reagents in the entity)", ) + plate_children = models.JSONField( + help_text="The children of the slate", null=True, blank=True, + default=list + ) history = HistoryField() @@ -1127,6 +1132,15 @@ class Meta: def __str__(self) -> str: return f"{self.label} in {self.ontology}" + def create_entity(self, group, name: str = None, instance_kind: str = None, metrics: dict = None) -> "Entity": + return Entity.objects.create( + name=name or str(uuid.uuid4()), + group=group, + kind=self, + instance_kind=instance_kind, + metrics=metrics or {} + ) + class EntityGroup(models.Model): """An EntityGroup is a collection of Entities. @@ -1176,15 +1190,6 @@ class Entity(models.Model): related_name="entities", help_text="The group this entity belongs to", ) - - parent = models.ForeignKey( - "self", - on_delete=models.CASCADE, - related_name="parts", - null=True, - blank=True, - help_text="The entity this entity is part of", - ) kind = models.ForeignKey( EntityKind, on_delete=models.CASCADE, @@ -1245,6 +1250,15 @@ class EntityRelation(models.Model): help_text="Associated metrics this relation", ) + class Meta: + constraints = [ + models.UniqueConstraint( + fields=["left", "kind", "right"], + name="only_one_relation_per_kind", + ) + ] + + class EntityMetric(models.Model): kind = models.OneToOneField( diff --git a/core/mutations/entity.py b/core/mutations/entity.py index 0ce5942..0457cce 100644 --- a/core/mutations/entity.py +++ b/core/mutations/entity.py @@ -41,7 +41,6 @@ def create_entity( group=group, kind=input_kind, defaults=dict( - parent=models.Entity.objects.get(id=input.parent) if input.parent else None, name=input.name or uuid.uuid4().hex, instance_kind=input.instance_kind, ) diff --git a/core/mutations/image.py b/core/mutations/image.py index 78f1590..31a86bc 100644 --- a/core/mutations/image.py +++ b/core/mutations/image.py @@ -13,7 +13,9 @@ PartialSpecimenViewInput, PartialAcquisitionViewInput, PartialAffineTransformationViewInput, + PartialPixelViewInput, PartialScaleViewInput, + _create_pixel_view_from_partial, view_kwargs_from_input, ) from django.conf import settings @@ -204,7 +206,7 @@ class FromArrayLikeInput: channel_views: list[PartialChannelViewInput] | None = None transformation_views: list[PartialAffineTransformationViewInput] | None = None acquisition_views: list[PartialAcquisitionViewInput] | None = None - label_views: list[PartialLabelViewInput] | None = None + pixel_views: list[PartialPixelViewInput] | None = None specimen_views: list[PartialSpecimenViewInput] | None = None rgb_views: list[PartialRGBViewInput] | None = None timepoint_views: list[PartialTimepointViewInput] | None = None @@ -280,16 +282,6 @@ def from_array_like( **view_kwargs_from_input(specimenview), ) - if input.label_views is not None: - for labelview in input.label_views: - models.LabelView.objects.create( - image=image, - fluorophore_id=labelview.fluorophore, - primary_antibody_id=labelview.primary_antibody, - secondary_antibody_id=labelview.secondary_antibody, - **view_kwargs_from_input(labelview), - ) - if input.scale_views is not None: for scaleview in input.scale_views: models.ScaleView.objects.create( @@ -370,6 +362,10 @@ def from_array_like( **view_kwargs_from_input(transformationview), ) + if input.pixel_views: + for pixelview in input.pixel_views: + _create_pixel_view_from_partial(image, pixelview) + return image diff --git a/core/mutations/protocol_step.py b/core/mutations/protocol_step.py index a112e98..8b289cd 100644 --- a/core/mutations/protocol_step.py +++ b/core/mutations/protocol_step.py @@ -1,10 +1,25 @@ from kante.types import Info import strawberry -from core import types, models +from core import types, models, scalars import uuid +@strawberry.input +class PlateChildInput: + id: strawberry.ID | None = None + type: str | None = None + text: str | None = None + children: list["PlateChildInput"] | None = None + value: str | None = None + color: str | None = None + fontSize: str | None = None + backgroundColor: str | None = None + bold: bool | None = None + italic: bool | None = None + underline: bool | None = None + + @strawberry.input @@ -13,6 +28,19 @@ class ProtocolStepInput: reagents: list[strawberry.ID] | None = None kind: strawberry.ID description: str | None = None + plate_children: list[PlateChildInput] | None = None + + +@strawberry.input +class UpdateProtocolStepInput: + name: str | None = None + id: strawberry.ID + reagents: list[strawberry.ID] | None = None + kind: strawberry.ID | None = None + description: str | None = None + plate_children: list[PlateChildInput] | None = None + + @strawberry.input class MapProtocolStepInput: @@ -38,11 +66,12 @@ def create_protocol_step( - step, _ = models.ProtocolStep.objects.update_or_create( + step, _ = models.ProtocolStep.objects.get_or_create( name=input.name, defaults=dict( kind=input_kind, description=input.description or "", + plate_children=input.plate_children or [], ), ) @@ -64,6 +93,37 @@ def map_protocol_step(info: Info, input: MapProtocolStepInput) -> types.Protocol return mapping +def child_to_str(child): + if child.get("children", []) is None: + return " ".join([child_to_str(c) for c in child["children"]]), + else: + return child.get("value", child.get("text", "")) or "" + + +def plate_children_to_str(children): + return " ".join([child_to_str(c) for c in children]) + + +def update_protocol_step( + info: Info, + input: UpdateProtocolStepInput, +) -> types.ProtocolStep: + step = models.ProtocolStep.objects.get(id=input.id) + step.name = input.name if input.name else step.name + step.description = input.description if input.description else plate_children_to_str([strawberry.asdict(i) for i in input.plate_children]) + step.plate_children = [strawberry.asdict(i) for i in input.plate_children] if input.plate_children else step.plate_children + step.kind = models.EntityKind.objects.get(id=input.kind) if input.kind else step.kind + + + if input.reagents: + + step.reagents.clear() + for reagent in input.reagents: + step.reagents.add(models.Entity.objects.get(id=reagent)) + + step.save() + return step + def delete_protocol_step( info: Info, diff --git a/core/mutations/view.py b/core/mutations/view.py index 1ab5960..b3fd919 100644 --- a/core/mutations/view.py +++ b/core/mutations/view.py @@ -1,3 +1,4 @@ +from typing import List from kante.types import Info import strawberry from core import types, models, scalars, enums @@ -81,6 +82,21 @@ class PartialScaleViewInput(ViewInput): scale_c: float | None = None +@strawberry.input() +class RangePixelLabel: + group: ID | None = None + entity_kind: ID + min: int + max: int + +@strawberry_django.input(models.PixelView) +class PartialPixelViewInput(ViewInput): + linked_view: ID | None = None + range_labels: List[RangePixelLabel] | None = None + + + + @strawberry_django.input(models.SpecimenView) @@ -154,6 +170,10 @@ class SpecimenViewInput(PartialSpecimenViewInput): image: ID +@strawberry_django.input(models.PixelView) +class PixelViewInput(PartialPixelViewInput): + image: ID + def view_kwargs_from_input(input: ChannelViewInput) -> dict: is_global = all( x is None @@ -268,6 +288,22 @@ def create_rgb_view( ) return view +def create_specimen_view( + info: Info, + input: SpecimenViewInput, +) -> types.SpecimenView: + image = models.Image.objects.get(id=input.image) + + view = models.SpecimenView.objects.create( + image=image, + specimen=models.Specimen.objects.get(id=input.specimen) if input.specimen else None, + step=models.ProtocolStep.objects.get(id=input.step) + if input.step + else None + **view_kwargs_from_input(input), + ) + return view + def delete_rgb_view( info: Info, @@ -425,3 +461,43 @@ def delete_optics_view( item = models.OpticsView.objects.get(id=input.id) item.delete() return input.id + + +def _create_pixel_view_from_partial(image, input: PartialPixelViewInput) -> types.PixelView: + view = models.PixelView.objects.create( + image=image, + **view_kwargs_from_input(input), + ) + + + + if input.range_labels: + for range_label in input.range_labels: + + if range_label.group: + group = models.EntityGroup.objects.get(id=input.group) + else: + group, _ = models.EntityGroup.objects.get_or_create(name="All entitites") + + for i in range(range_label.min, range_label.max + 1): + x = models.EntityKind.objects.get(id=range_label.entity_kind) + + models.PixelLabel.objects.create( + view=view, + entity=x.create_entity(group), + value=i, + ) + + return view + + + +def create_pixel_view( + info: Info, + input: PixelViewInput, +) -> types.PixelView: + image = models.Image.objects.get(id=input.image) + return _create_pixel_view_from_partial(image, input) + + + diff --git a/core/queries/__init__.py b/core/queries/__init__.py index b21af5c..95fefe5 100644 --- a/core/queries/__init__.py +++ b/core/queries/__init__.py @@ -1,2 +1,3 @@ from .image import * -from .knowledge_graph import * \ No newline at end of file +from .knowledge_graph import * +from .entity_graph import * \ No newline at end of file diff --git a/core/queries/entity_graph.py b/core/queries/entity_graph.py new file mode 100644 index 0000000..2cba3e2 --- /dev/null +++ b/core/queries/entity_graph.py @@ -0,0 +1,109 @@ +import strawberry +from core import models + + +@strawberry.type +class EntityNodeMetric: + data_kind: str + kind: str + metric_id: str + value: str + + +@strawberry.type +class EntityNode: + id: str + is_root: bool = False + subtitle: str + label: str + metrics: list[EntityNodeMetric] + + +@strawberry.type +class EntityRelationEdge: + id: str + label: str + source: str + target: str + metrics: list[EntityNodeMetric] + +@strawberry.type +class EntityGraph: + nodes: list[EntityNode] + edges: list[EntityRelationEdge] + + + +def entity_graph(id: strawberry.ID) -> EntityGraph: + """ + Query the knowledge graph for information about a given entity. + + Args: + query: The entity to search for in the knowledge graph. + + Returns: + A dictionary containing information about the entity. + """ + + nodes = [] + edges = [] + + entity = models.Entity.objects.get(id=id) + + + def parse_entity(entity, is_root=False): + + node = EntityNode(id=entity.id, subtitle=entity.name, metrics=[], label=entity.kind.label, is_root=is_root) + + metric_map = {} + + relation_metric_map = {} + + for key, value in entity.metrics.items(): + + if key not in metric_map: + metric_map[key] = models.EntityMetric.objects.get(id=key) + + metric = metric_map[key] + node.metrics.append(EntityNodeMetric(data_kind=metric.data_kind, kind=metric.kind.label, value=value, metric_id=key)) + + nodes.append(node) + + outgoing_relations = [] + first_partners_nodes = [] + + + for entity_relation in models.EntityRelation.objects.prefetch_related('kind__kind', "right__kind", "left__kind").filter(left=entity): + outgoing_relations.append(entity_relation) + edge = EntityRelationEdge(id=entity_relation.id, label=entity_relation.kind.kind.label, source=entity.id, target=entity_relation.right.id, metrics=[]) + first_partners_nodes.append(entity_relation.right) + + + node = EntityNode(id=entity_relation.right.id, label=entity_relation.right.kind.label, subtitle=entity_relation.right.name, metrics=[]) + + if node.id not in [n.id for n in nodes]: + nodes.append(node) + if edge.id not in [n.id for n in edges]: + edges.append(edge) + + + for entity_relation in models.EntityRelation.objects.prefetch_related('kind__kind', "right__kind", "left__kind").filter(right=entity): + outgoing_relations.append(entity_relation) + edge = EntityRelationEdge(id=entity_relation.id, label=entity_relation.kind.kind.label, source=entity_relation.left.id, target=entity.id, metrics=[]) + first_partners_nodes.append(entity_relation.left) + + + node = EntityNode(id=entity_relation.left.id, label=entity_relation.left.kind.label, subtitle=entity_relation.left.name, metrics=[]) + + if node.id not in [n.id for n in nodes]: + nodes.append(node) + if edge.id not in [n.id for n in edges]: + edges.append(edge) + + + + + parse_entity(entity, is_root=True) + + + return EntityGraph(nodes=nodes, edges=edges) \ No newline at end of file diff --git a/core/queries/knowledge_graph.py b/core/queries/knowledge_graph.py index 9d589f1..82de916 100644 --- a/core/queries/knowledge_graph.py +++ b/core/queries/knowledge_graph.py @@ -3,30 +3,30 @@ @strawberry.type -class EntityNodeMetric: +class EntityKindNodeMetric: data_kind: str kind: str @strawberry.type -class EntityNode: +class EntityKindNode: id: str label: str - metrics: list[EntityNodeMetric] + metrics: list[EntityKindNodeMetric] @strawberry.type -class EntityRelationEdge: +class EntityKindRelationEdge: id: str label: str source: str target: str - metrics: list[EntityNodeMetric] + metrics: list[EntityKindNodeMetric] @strawberry.type class KnowledgeGraph: - nodes: list[EntityNode] - edges: list[EntityRelationEdge] + nodes: list[EntityKindNode] + edges: list[EntityKindRelationEdge] @@ -49,19 +49,19 @@ def knowledge_graph(id: strawberry.ID) -> KnowledgeGraph: # Get the entity kind for entity_kind in entity_kinds: - node = EntityNode(id=entity_kind.id, label=entity_kind.label, metrics=[]) + node = EntityKindNode(id=entity_kind.id, label=entity_kind.label, metrics=[]) for metric in models.EntityMetric.objects.filter(kind=entity_kind): - node.metrics.append(EntityNodeMetric(data_kind=metric.data_kind, kind=metric.kind.label)) + node.metrics.append(EntityKindNodeMetric(data_kind=metric.data_kind, kind=metric.kind.label)) nodes.append(node) for entity_relation_kind in models.EntityRelationKind.objects.filter(left_kind__in=entity_kinds, right_kind__in=entity_kinds): - edge = EntityRelationEdge(id=entity_relation_kind.id, label=entity_relation_kind.kind.label, source=entity_relation_kind.left_kind.id, target=entity_relation_kind.right_kind.id, metrics=[]) + edge = EntityKindRelationEdge(id=entity_relation_kind.id, label=entity_relation_kind.kind.label, source=entity_relation_kind.left_kind.id, target=entity_relation_kind.right_kind.id, metrics=[]) for metric in models.RelationMetric.objects.filter(kind=entity_relation_kind.kind): - edge.metrics.append(EntityNodeMetric(data_kind=metric.data_kind, kind=metric.kind.label)) + edge.metrics.append(EntityKindNodeMetric(data_kind=metric.data_kind, kind=metric.kind.label)) edges.append(edge) diff --git a/core/scalars.py b/core/scalars.py index d13efff..c0dee97 100644 --- a/core/scalars.py +++ b/core/scalars.py @@ -17,6 +17,13 @@ parse_value=lambda v: v, ) +UntypedPlateChild = strawberry.scalar( + NewType("UntypedPlateChild", object), + description="The `UntypedPlateChild` scalar type represents a plate child", + serialize=lambda v: v, + parse_value=lambda v: v, +) + FileLike = strawberry.scalar( NewType("FileLike", str), @@ -125,4 +132,5 @@ description="The `MetricMap` scalar type represents a matrix values as specified by", serialize=lambda v: v, parse_value=lambda v: v, -) \ No newline at end of file +) + diff --git a/core/types.py b/core/types.py index eb24882..604f77a 100644 --- a/core/types.py +++ b/core/types.py @@ -281,7 +281,7 @@ class Specimen: @strawberry.django.field() def label(self, info: Info) -> str: - return f"{self.entity.name} subjected to {self.protocol.name}" + return f"{self.entity.name} in {self.protocol.name}" @strawberry_django.type(models.Experiment, filters=filters.ExperimentFilter, pagination=True) @@ -313,7 +313,11 @@ class ProtocolStepMapping: step: "ProtocolStep" -@strawberry_django.type(models.ProtocolStep, filters=filters.ProtocolFilter, pagination=True) + + + + +@strawberry_django.type(models.ProtocolStep, filters=filters.ProtocolStepFilter, pagination=True) class ProtocolStep: id: auto name: str @@ -326,6 +330,10 @@ class ProtocolStep: mappings: List["ProtocolStepMapping"] views: List["SpecimenView"] + @strawberry.django.field() + def plate_children(self, info) -> List[scalars.UntypedPlateChild]: + return self.plate_children if self.plate_children else [{"id": 1, "type": "p", "children": [{"text": self.description or "No description"}]}] + @strawberry_django.type(models.Table, filters=filters.TableFilter, pagination=True) @@ -886,6 +894,22 @@ class SpecimenView(View): specimen: Specimen step: ProtocolStep | None = None +@strawberry_django.type( + models.PixelView, filters=filters.PixelViewFilter, pagination=True +) +class PixelView(View): + id: auto + labels: list["PixelLabel"] + + +@strawberry_django.type( + models.PixelLabel, filters=filters.PixelLabelFilter, pagination=True +) +class PixelLabel: + id: auto + view: PixelView + value: int + entity: "Entity" diff --git a/mikro_server/schema.py b/mikro_server/schema.py index c59eb51..2d9e874 100644 --- a/mikro_server/schema.py +++ b/mikro_server/schema.py @@ -71,6 +71,7 @@ class Query: specimen_views: list[types.SpecimenView] = strawberry_django.field() knowledge_graph = strawberry_django.field(resolver=queries.knowledge_graph) + entity_graph = strawberry_django.field(resolver=queries.entity_graph) tables: list[types.Table] = strawberry_django.field() @@ -383,6 +384,9 @@ class Mutation: delete_protocol_step = strawberry_django.mutation( resolver=mutations.delete_protocol_step, ) + update_protocol_step = strawberry_django.mutation( + resolver=mutations.update_protocol_step, + ) @@ -490,6 +494,9 @@ class Mutation: create_channel_view = strawberry_django.mutation( resolver=mutations.create_channel_view ) + create_specimen_view = strawberry_django.mutation( + resolver=mutations.create_specimen_view + ) create_well_position_view = strawberry_django.mutation( resolver=mutations.create_well_position_view )