Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] New Taxon models for faster species views #490

Draft
wants to merge 20 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions ami/base/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import datetime
import logging
import typing

from django.db import models
from django.utils import timezone

import ami.tasks

Expand Down Expand Up @@ -27,3 +32,55 @@ def update_calculated_fields(self, *args, **kwargs):

class Meta:
abstract = True


def update_calculated_fields_in_bulk(
    qs: models.QuerySet[BaseModel] | None = None,
    Model: type[models.Model] | None = None,
    pks: list[typing.Any] | None = None,
    fields: list[str] | None = None,
    last_updated: datetime.datetime | None = None,
    save=True,
) -> int:
    """
    Update the calculated fields for instances of a model in bulk.

    This function is called by a migration (and maintenance tasks) to refresh
    pre-calculated fields for all instances of a model.

    :param qs: optional queryset restricting which instances are updated.
        If omitted, ``Model`` must be given and all of its instances are used.
    :param Model: model class to update (derived from ``qs`` when provided).
    :param pks: optional primary keys to further restrict the queryset.
    :param fields: field names passed to ``bulk_update`` when saving.
        Must be non-empty when ``save`` is True (``bulk_update`` requires it).
    :param last_updated: only touch instances whose
        ``calculated_fields_updated_at`` is null or at/before this time.
    :param save: when False, instances are updated in memory only.
    :return: number of instances updated (rows written when ``save`` is True,
        otherwise the number of instances refreshed in memory).
    """
    fields = fields or []  # avoid a shared mutable default argument

    # `is not None` rather than truthiness: `if qs:` would execute the query
    # just to test it, and an intentionally-empty queryset is still valid.
    if qs is not None:
        Model = qs.model
    assert Model is not None, "Either a queryset or model must be specified"

    # Ensure the model has a method to update calculated fields
    assert hasattr(Model, "update_calculated_fields"), f"{Model} has no method 'update_calculated_fields'"

    if qs is None:
        qs = Model.objects.all()  # type: ignore
    assert qs is not None

    if pks:
        qs = qs.filter(pk__in=pks)
    if last_updated:
        # query for None or before the last updated time
        qs = qs.filter(
            models.Q(calculated_fields_updated_at__isnull=True)
            | models.Q(calculated_fields_updated_at__lte=last_updated)
        )

    # Share the updated timestamp for all instances in a bulk update
    updated_timestamp = timezone.now()
    to_update: list[BaseModel] = []
    for instance in qs:
        instance.update_calculated_fields(save=False, updated_timestamp=updated_timestamp)
        to_update.append(instance)

    # Log after collecting, so the count reflects the actual work done
    # (previously this logged before the loop and always reported 0).
    logging.info(f"Updating pre-calculated fields for {len(to_update)} instances")

    if not save:
        # Nothing was persisted; report how many instances were refreshed
        # in memory (previously this path crashed with UnboundLocalError).
        return len(to_update)

    logging.info(f"Saving {len(to_update)} instances, only updating {len(fields)} fields: {fields}")
    updated_count = Model.objects.bulk_update(
        to_update,
        fields,
    )
    if updated_count != len(to_update):
        logging.error(f"Failed to update {len(to_update) - updated_count} instances")

    return updated_count
13 changes: 12 additions & 1 deletion ami/main/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@ class ProjectAdmin(admin.ModelAdmin[Project]):

list_display = ("name", "priority", "active", "created_at", "updated_at")

# Admin action: refresh the observed-taxa records for the selected projects
@admin.action(description="Update observed taxa")
def update_observed_taxa(self, request: HttpRequest, queryset: QuerySet[Project]) -> None:
    """Recalculate observed taxa for each selected project, then notify the admin user."""
    from ami.taxa.models import update_taxa_observed_for_project

    for selected_project in queryset:
        update_taxa_observed_for_project(selected_project)
    self.message_user(request, f"Updated {queryset.count()} projects.")

actions = [update_observed_taxa]

@admin.action(description="Remove duplicate classifications from all detections")
def _remove_duplicate_classifications(self, request: HttpRequest, queryset: QuerySet[Project]) -> None:
task_ids = []
Expand All @@ -66,7 +77,7 @@ def _remove_duplicate_classifications(self, request: HttpRequest, queryset: Quer
task_ids.append(task.id)
self.message_user(request, f"Started {len(task_ids)} tasks to delete classification: {task_ids}")

actions = [_remove_duplicate_classifications]
actions = [_remove_duplicate_classifications, update_observed_taxa]


@admin.register(Deployment)
Expand Down
52 changes: 4 additions & 48 deletions ami/main/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from ami.ml.serializers import AlgorithmSerializer
from ami.users.models import User
from ami.utils.dates import get_image_timestamp_from_filename
from ami.utils.requests import get_active_classification_threshold

from ..models import (
Classification,
Expand Down Expand Up @@ -133,6 +132,8 @@ class DeploymentListSerializer(DefaultSerializer):
project = ProjectNestedSerializer(read_only=True)
device = DeviceNestedSerializer(read_only=True)
research_site = SiteNestedSerializer(read_only=True)
first_date = serializers.DateField(read_only=True, source="first_capture_timestamp")
last_date = serializers.DateField(read_only=True, source="last_capture_timestamp")

class Meta:
model = Deployment
Expand Down Expand Up @@ -459,60 +460,20 @@ class Meta:


class TaxonListSerializer(DefaultSerializer):
# latest_detection = DetectionNestedSerializer(read_only=True)
occurrences = serializers.SerializerMethodField()
occurrence_images = serializers.SerializerMethodField()
parent = TaxonNestedSerializer(read_only=True)
parents = TaxonParentSerializer(many=True, read_only=True, source="parents_json")

class Meta:
model = Taxon
fields = [
"id",
"name",
"rank",
"parent",
"parents",
"details",
"occurrences_count",
"occurrences",
"occurrence_images",
"last_detected",
"best_determination_score",
"created_at",
"updated_at",
]

def get_occurrences(self, obj):
"""
Return URL to the occurrences endpoint filtered by this taxon.
"""

params = {}
params.update(dict(self.context["request"].query_params.items()))
params.update({"determination": obj.pk})

return reverse_with_params(
"occurrence-list",
request=self.context.get("request"),
params=params,
)

def get_occurrence_images(self, obj):
"""
Call the occurrence_images method on the Taxon model, with arguments.
"""

# request = self.context.get("request")
# project_id = request.query_params.get("project") if request else None
project_id = self.context["request"].query_params["project"]
classification_threshold = get_active_classification_threshold(self.context["request"])

return obj.occurrence_images(
# @TODO pass the request to generate media url & filter by current user's access
# request=self.context.get("request"),
project_id=project_id,
classification_threshold=classification_threshold,
)


class CaptureTaxonSerializer(DefaultSerializer):
parent = TaxonNoParentNestedSerializer(read_only=True)
Expand Down Expand Up @@ -669,7 +630,6 @@ class Meta:

class TaxonSerializer(DefaultSerializer):
# latest_detection = DetectionNestedSerializer(read_only=True)
occurrences = TaxonOccurrenceNestedSerializer(many=True, read_only=True)
parent = TaxonNoParentNestedSerializer(read_only=True)
parent_id = serializers.PrimaryKeyRelatedField(queryset=Taxon.objects.all(), source="parent", write_only=True)
# parents = TaxonParentNestedSerializer(many=True, read_only=True, source="parents_json")
Expand All @@ -685,10 +645,6 @@ class Meta:
"parent_id",
"parents",
"details",
"occurrences_count",
"detections_count",
"events_count",
"occurrences",
]


Expand Down
24 changes: 15 additions & 9 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ami.base.filters import NullsLastOrderingFilter
from ami.base.pagination import LimitOffsetPaginationWithPermissions
from ami.base.permissions import IsActiveStaffOrReadOnly
from ami.taxa.models import TaxonObserved
from ami.utils.requests import get_active_classification_threshold
from ami.utils.storages import ConnectionTestResult

Expand Down Expand Up @@ -177,6 +178,10 @@ def get_queryset(self) -> QuerySet:
queryset=SourceImage.objects.order_by("created_at").exclude(upload=None),
)
)
elif self.action == "list":
# Add annotations for deployments
# @TODO use a QuerySet manager and call Deployments.with_counts() etc.
pass

return qs

Expand Down Expand Up @@ -859,6 +864,7 @@ def get_serializer_class(self):
else:
return TaxonSerializer

# @TODO this can now be removed since we are using TaxonObservedViewSet
def filter_taxa_by_observed(self, queryset: QuerySet) -> tuple[QuerySet, bool]:
"""
Filter taxa by when/where it has occurred.
Expand Down Expand Up @@ -898,6 +904,7 @@ def filter_taxa_by_observed(self, queryset: QuerySet) -> tuple[QuerySet, bool]:
# @TODO need to return the models.Q filter used, so we can use it for counts and related occurrences.
return queryset, filter_active

# @TODO this can now be removed since we are using TaxonObservedViewSet
def filter_by_classification_threshold(self, queryset: QuerySet) -> QuerySet:
"""
Filter taxa by their best determination score in occurrences.
Expand All @@ -917,6 +924,7 @@ def filter_by_classification_threshold(self, queryset: QuerySet) -> QuerySet:

return queryset

# @TODO this can now be removed since we are using TaxonObservedViewSet
def get_occurrences_filters(self, queryset: QuerySet) -> tuple[QuerySet, models.Q]:
# @TODO this should check what the user has access to
project_id = self.request.query_params.get("project")
Expand All @@ -942,6 +950,7 @@ def get_occurrences_filters(self, queryset: QuerySet) -> tuple[QuerySet, models.

return taxon_occurrences_query, taxon_occurrences_count_filter

# @TODO this can now be removed since we are using TaxonObservedViewSet
def add_occurrence_counts(self, queryset: QuerySet, occurrences_count_filter: models.Q) -> QuerySet:
qs = queryset.annotate(
occurrences_count=models.Count(
Expand All @@ -953,10 +962,12 @@ def add_occurrence_counts(self, queryset: QuerySet, occurrences_count_filter: mo
)
return qs

# @TODO this can now be removed since we are using TaxonObservedViewSet
def add_filtered_occurrences(self, queryset: QuerySet, occurrences_query: QuerySet) -> QuerySet:
qs = queryset.prefetch_related(Prefetch("occurrences", queryset=occurrences_query))
return qs

# @TODO this can now be removed since we are using TaxonObservedViewSet
def zero_occurrences(self, queryset: QuerySet) -> QuerySet:
"""
Return a queryset with zero occurrences but compatible with the original queryset.
Expand Down Expand Up @@ -1047,13 +1058,10 @@ def get(self, request):
determination_score__gte=confidence_threshold,
event__isnull=False,
).count(),
"taxa_count": Taxon.objects.annotate(occurrences_count=models.Count("occurrences"))
"taxa_count": TaxonObserved.objects.filter(project=project)
.filter(
occurrences_count__gt=0,
occurrences__determination_score__gte=confidence_threshold,
occurrences__project=project,
best_determination_score__gte=confidence_threshold,
)
.distinct()
.count(),
}
else:
Expand All @@ -1066,9 +1074,7 @@ def get(self, request):
"occurrences_count": Occurrence.objects.filter(
determination_score__gte=confidence_threshold, event__isnull=False
).count(),
"taxa_count": Taxon.objects.annotate(occurrences_count=models.Count("occurrences"))
.filter(occurrences_count__gt=0, occurrences__determination_score__gte=confidence_threshold)
.count(),
"taxa_count": TaxonObserved.objects.filter(best_determination_score__gte=confidence_threshold).count(),
"last_updated": timezone.now(),
}

Expand All @@ -1088,7 +1094,7 @@ def get(self, request):


_STORAGE_CONNECTION_STATUS = [
# These come from the ConnetionStatus react component
# These come from the ConnectionStatus react component
# @TODO use ENUM
"NOT_CONNECTED",
"CONNECTING",
Expand Down
15 changes: 7 additions & 8 deletions ami/main/charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,18 +262,17 @@ def event_detections_per_hour(event_pk: int):


def event_top_taxa(event_pk: int, top_n: int = 10):
# Horiziontal bar chart of top taxa
Taxon = apps.get_model("main", "Taxon")
# Horizontal bar chart of top taxa
TaxonObserved = apps.get_model("taxa", "TaxonObserved")
top_taxa = (
Taxon.objects.filter(occurrences__event=event_pk)
.values("name")
# .annotate(num_detections=models.Count("occurrences__detections"))
.annotate(num_detections=models.Count("occurrences"))
.order_by("-num_detections")[:top_n]
TaxonObserved.objects.filter(occurrences__event=event_pk)
.select_related("taxon")
.values("taxon__name", "occurrences_count")
.order_by("-occurrences_count")[:top_n]
)

if top_taxa:
taxa, counts = list(zip(*[(t["name"], t["num_detections"]) for t in top_taxa]))
taxa, counts = list(zip(*[(t["name"], t["occurrences_count"]) for t in top_taxa]))
taxa = [t or "Unknown" for t in taxa]
counts = [c or 0 for c in counts]
else:
Expand Down
Loading
Loading