Skip to content

Commit

Permalink
adding entity graph
Browse files Browse the repository at this point in the history
  • Loading branch information
jhnnsrs committed Aug 18, 2024
1 parent df45448 commit 0c889a8
Show file tree
Hide file tree
Showing 15 changed files with 446 additions and 40 deletions.
46 changes: 44 additions & 2 deletions core/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from strawberry import auto
from typing import Optional
from strawberry_django.filters import FilterLookup

import strawberry_django
print("Test")


Expand Down Expand Up @@ -114,6 +114,33 @@ class ViewFilter:
is_global: auto
provenance: ProvenanceFilter | None

@strawberry.django.filter(models.PixelLabel)
class PixelLabelFilter:
value: float | None = None
view: strawberry.ID | None = None
entity_kind: strawberry.ID | None = None
entity: strawberry.ID | None = None

def filter_value(self, queryset, info):
if self.value is None:
return queryset
return queryset.filter(value=self.value)

def filter_view(self, queryset, info):
if self.view is None:
return queryset
return queryset.filter(view_id=self.view)

def filter_entity_kind(self, queryset, info):
if self.entity_kind is None:
return queryset
return queryset.filter(entity_entity_kind_id=self.entity_kind)

def filter_entity(self, queryset, info):
if self.entity is None:
return queryset
return queryset.filter(entity_id=self.entity)


@strawberry.django.filter(models.AffineTransformationView)
class AffineTransformationViewFilter(ViewFilter):
Expand All @@ -132,6 +159,9 @@ class TimepointViewFilter(ViewFilter):
ms_since_start: auto
index_since_start: auto

@strawberry.django.filter(models.PixelView)
class PixelViewFilter(ViewFilter):
pass

@strawberry.django.filter(models.OpticsView)
class OpticsViewFilter(ViewFilter):
Expand Down Expand Up @@ -179,21 +209,28 @@ def filter_ids(self, queryset, info):
return queryset.filter(id__in=self.ids)


@strawberry.django.filter(models.Image)
@strawberry_django.filter(models.Image)
class ImageFilter:
name: Optional[FilterLookup[str]]
ids: list[strawberry.ID] | None
store: ZarrStoreFilter | None
dataset: DatasetFilter | None
transformation_views: AffineTransformationViewFilter | None
timepoint_views: TimepointViewFilter | None
not_derived: bool | None = None

provenance: ProvenanceFilter | None

def filter_ids(self, queryset, info):
if self.ids is None:
return queryset
return queryset.filter(id__in=self.ids)

def filter_not_derived(self, queryset, info):
print("Filtering not derived")
if self.not_derived is None:
return queryset
return queryset.filter(origins=None)


@strawberry.django.filter(models.ROI)
Expand Down Expand Up @@ -377,6 +414,11 @@ class SpecimenFilter(IDFilterMixin, SearchFilterMixin):
class ProtocolFilter(IDFilterMixin, SearchFilterMixin):
id: auto

@strawberry.django.filter(models.ProtocolStep)
class ProtocolStepFilter(IDFilterMixin, SearchFilterMixin):
id: auto


@strawberry.django.filter(models.ProtocolStepMapping)
class ProtocolStepMappingFilter(IDFilterMixin, SearchFilterMixin):
id: auto
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Generated by Django 4.2.8 on 2024-08-16 13:16

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("core", "0002_alter_protocolstepmapping_t"),
]

operations = [
migrations.AddField(
model_name="historicalprotocolstep",
name="plate_children",
field=models.JSONField(
blank=True,
default=list,
help_text="The children of the slate",
null=True,
),
),
migrations.AddField(
model_name="protocolstep",
name="plate_children",
field=models.JSONField(
blank=True,
default=list,
help_text="The children of the slate",
null=True,
),
),
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Generated by Django 4.2.8 on 2024-08-18 15:16

from django.db import migrations


class Migration(migrations.Migration):
dependencies = [
("core", "0003_historicalprotocolstep_plate_children_and_more"),
]

operations = [
migrations.RemoveField(
model_name="entity",
name="parent",
),
migrations.RemoveField(
model_name="historicalentity",
name="parent",
),
]
18 changes: 18 additions & 0 deletions core/migrations/0005_entityrelation_only_one_relation_per_kind.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 4.2.8 on 2024-08-18 19:44

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("core", "0004_remove_entity_parent_remove_historicalentity_parent"),
]

operations = [
migrations.AddConstraint(
model_name="entityrelation",
constraint=models.UniqueConstraint(
fields=("left", "kind", "right"), name="only_one_relation_per_kind"
),
),
]
32 changes: 23 additions & 9 deletions core/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from django.db import models
from django.contrib.auth import get_user_model
from django.forms import FileField
Expand Down Expand Up @@ -342,6 +343,10 @@ class ProtocolStep(models.Model):
related_name="entities",
help_text="The reagents that were used in this step (you can specifiy properties of the reagents in the entity)",
)
plate_children = models.JSONField(
help_text="The children of the slate", null=True, blank=True,
default=list
)

history = HistoryField()

Expand Down Expand Up @@ -1127,6 +1132,15 @@ class Meta:
def __str__(self) -> str:
return f"{self.label} in {self.ontology}"

def create_entity(self, group, name: str = None, instance_kind: str = None, metrics: dict = None) -> "Entity":
return Entity.objects.create(
name=name or str(uuid.uuid4()),
group=group,
kind=self,
instance_kind=instance_kind,
metrics=metrics or {}
)


class EntityGroup(models.Model):
"""An EntityGroup is a collection of Entities.
Expand Down Expand Up @@ -1176,15 +1190,6 @@ class Entity(models.Model):
related_name="entities",
help_text="The group this entity belongs to",
)

parent = models.ForeignKey(
"self",
on_delete=models.CASCADE,
related_name="parts",
null=True,
blank=True,
help_text="The entity this entity is part of",
)
kind = models.ForeignKey(
EntityKind,
on_delete=models.CASCADE,
Expand Down Expand Up @@ -1245,6 +1250,15 @@ class EntityRelation(models.Model):
help_text="Associated metrics this relation",
)

class Meta:
constraints = [
models.UniqueConstraint(
fields=["left", "kind", "right"],
name="only_one_relation_per_kind",
)
]



class EntityMetric(models.Model):
kind = models.OneToOneField(
Expand Down
1 change: 0 additions & 1 deletion core/mutations/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def create_entity(
group=group,
kind=input_kind,
defaults=dict(
parent=models.Entity.objects.get(id=input.parent) if input.parent else None,
name=input.name or uuid.uuid4().hex,
instance_kind=input.instance_kind,
)
Expand Down
18 changes: 7 additions & 11 deletions core/mutations/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
PartialSpecimenViewInput,
PartialAcquisitionViewInput,
PartialAffineTransformationViewInput,
PartialPixelViewInput,
PartialScaleViewInput,
_create_pixel_view_from_partial,
view_kwargs_from_input,
)
from django.conf import settings
Expand Down Expand Up @@ -204,7 +206,7 @@ class FromArrayLikeInput:
channel_views: list[PartialChannelViewInput] | None = None
transformation_views: list[PartialAffineTransformationViewInput] | None = None
acquisition_views: list[PartialAcquisitionViewInput] | None = None
label_views: list[PartialLabelViewInput] | None = None
pixel_views: list[PartialPixelViewInput] | None = None
specimen_views: list[PartialSpecimenViewInput] | None = None
rgb_views: list[PartialRGBViewInput] | None = None
timepoint_views: list[PartialTimepointViewInput] | None = None
Expand Down Expand Up @@ -280,16 +282,6 @@ def from_array_like(
**view_kwargs_from_input(specimenview),
)

if input.label_views is not None:
for labelview in input.label_views:
models.LabelView.objects.create(
image=image,
fluorophore_id=labelview.fluorophore,
primary_antibody_id=labelview.primary_antibody,
secondary_antibody_id=labelview.secondary_antibody,
**view_kwargs_from_input(labelview),
)

if input.scale_views is not None:
for scaleview in input.scale_views:
models.ScaleView.objects.create(
Expand Down Expand Up @@ -370,6 +362,10 @@ def from_array_like(
**view_kwargs_from_input(transformationview),
)

if input.pixel_views:
for pixelview in input.pixel_views:
_create_pixel_view_from_partial(image, pixelview)

return image


Expand Down
64 changes: 62 additions & 2 deletions core/mutations/protocol_step.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
from kante.types import Info
import strawberry
from core import types, models
from core import types, models, scalars
import uuid



@strawberry.input
class PlateChildInput:
id: strawberry.ID | None = None
type: str | None = None
text: str | None = None
children: list["PlateChildInput"] | None = None
value: str | None = None
color: str | None = None
fontSize: str | None = None
backgroundColor: str | None = None
bold: bool | None = None
italic: bool | None = None
underline: bool | None = None




@strawberry.input
Expand All @@ -13,6 +28,19 @@ class ProtocolStepInput:
reagents: list[strawberry.ID] | None = None
kind: strawberry.ID
description: str | None = None
plate_children: list[PlateChildInput] | None = None


@strawberry.input
class UpdateProtocolStepInput:
name: str | None = None
id: strawberry.ID
reagents: list[strawberry.ID] | None = None
kind: strawberry.ID | None = None
description: str | None = None
plate_children: list[PlateChildInput] | None = None



@strawberry.input
class MapProtocolStepInput:
Expand All @@ -38,11 +66,12 @@ def create_protocol_step(



step, _ = models.ProtocolStep.objects.update_or_create(
step, _ = models.ProtocolStep.objects.get_or_create(
name=input.name,
defaults=dict(
kind=input_kind,
description=input.description or "",
plate_children=input.plate_children or [],
),
)

Expand All @@ -64,6 +93,37 @@ def map_protocol_step(info: Info, input: MapProtocolStepInput) -> types.Protocol
return mapping


def child_to_str(child):
if child.get("children", []) is None:
return " ".join([child_to_str(c) for c in child["children"]]),
else:
return child.get("value", child.get("text", "")) or ""


def plate_children_to_str(children):
return " ".join([child_to_str(c) for c in children])


def update_protocol_step(
info: Info,
input: UpdateProtocolStepInput,
) -> types.ProtocolStep:
step = models.ProtocolStep.objects.get(id=input.id)
step.name = input.name if input.name else step.name
step.description = input.description if input.description else plate_children_to_str([strawberry.asdict(i) for i in input.plate_children])
step.plate_children = [strawberry.asdict(i) for i in input.plate_children] if input.plate_children else step.plate_children
step.kind = models.EntityKind.objects.get(id=input.kind) if input.kind else step.kind


if input.reagents:

step.reagents.clear()
for reagent in input.reagents:
step.reagents.add(models.Entity.objects.get(id=reagent))

step.save()
return step


def delete_protocol_step(
info: Info,
Expand Down
Loading

0 comments on commit 0c889a8

Please sign in to comment.