Commit

Add pytests for decoding pipeline (#1155)
* WIP: Add decoding pytests 1

* WIP: add decoding tests 2

* WIP: coverage for v1 schemas

* WIP: fixing impacted tests elsewhere

* ✅ : fix impacted tests

* Revert merge edits
CBroz1 authored Nov 26, 2024
1 parent 4231e51 commit 37ddfc1
Showing 22 changed files with 988 additions and 259 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -38,6 +38,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
 - Add testing for python versions 3.9, 3.10, 3.11, 3.12 #1169
 - Initialize tables in pytests #1181
 - Download test data without credentials, trigger on approved PRs #1180
+- Add coverage of decoding pipeline to pytests #1155
 - Allow python \< 3.13 #1169
 - Remove numpy version restriction #1169
 - Merge table delete removes orphaned master entries #1164
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -157,7 +157,8 @@ omit = [ # which submodules have no tests
"*/cli/*",
# "*/common/*",
"*/data_import/*",
"*/decoding/*",
"*/decoding/v0/*",
# "*/decoding/*",
"*/figurl_views/*",
# "*/lfp/*",
# "*/linearization/*",
48 changes: 16 additions & 32 deletions src/spyglass/decoding/decoding_merge.py
@@ -85,53 +85,41 @@ def cleanup(self, dry_run=False):
     @classmethod
     def fetch_results(cls, key):
         """Fetch the decoding results for a given key."""
-        return cls().merge_get_parent_class(key).fetch_results()
+        return cls().merge_restrict_class(key).fetch_results()
 
     @classmethod
     def fetch_model(cls, key):
         """Fetch the decoding model for a given key."""
-        return cls().merge_get_parent_class(key).fetch_model()
+        return cls().merge_restrict_class(key).fetch_model()
 
     @classmethod
     def fetch_environments(cls, key):
         """Fetch the decoding environments for a given key."""
-        decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
-        return (
-            cls()
-            .merge_get_parent_class(key)
-            .fetch_environments(decoding_selection_key)
-        )
+        restr_parent = cls().merge_restrict_class(key)
+        decoding_selection_key = restr_parent.fetch1("KEY")
+        return restr_parent.fetch_environments(decoding_selection_key)
 
     @classmethod
     def fetch_position_info(cls, key):
         """Fetch the decoding position info for a given key."""
-        decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
-        return (
-            cls()
-            .merge_get_parent_class(key)
-            .fetch_position_info(decoding_selection_key)
-        )
+        restr_parent = cls().merge_restrict_class(key)
+        decoding_selection_key = restr_parent.fetch1("KEY")
+        return restr_parent.fetch_position_info(decoding_selection_key)
 
     @classmethod
     def fetch_linear_position_info(cls, key):
         """Fetch the decoding linear position info for a given key."""
-        decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
-        return (
-            cls()
-            .merge_get_parent_class(key)
-            .fetch_linear_position_info(decoding_selection_key)
-        )
+        restr_parent = cls().merge_restrict_class(key)
+        decoding_selection_key = restr_parent.fetch1("KEY")
+        return restr_parent.fetch_linear_position_info(decoding_selection_key)
 
     @classmethod
     def fetch_spike_data(cls, key, filter_by_interval=True):
         """Fetch the decoding spike data for a given key."""
-        decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
-        return (
-            cls()
-            .merge_get_parent_class(key)
-            .fetch_linear_position_info(
-                decoding_selection_key, filter_by_interval=filter_by_interval
-            )
-        )
+        restr_parent = cls().merge_restrict_class(key)
+        decoding_selection_key = restr_parent.fetch1("KEY")
+        return restr_parent.fetch_spike_data(
+            decoding_selection_key, filter_by_interval=filter_by_interval
+        )
 
     @classmethod
@@ -167,11 +155,7 @@ def create_decoding_view(cls, key, head_direction_name="head_orientation"):
                 head_dir=position_info[head_direction_name],
             )
         else:
-            (
-                position_info,
-                position_variable_names,
-            ) = cls.fetch_linear_position_info(key)
             return create_1D_decode_view(
                 posterior=posterior,
-                linear_position=position_info["linear_position"],
+                linear_position=cls.fetch_linear_position_info(key),
             )
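For orientation, a minimal usage sketch of the refactored accessors, with a hypothetical merge key (the UUID below is illustrative, not from this commit). Each classmethod now builds one restricted parent view via merge_restrict_class and reuses it for both the selection key and the fetch, and fetch_spike_data now calls fetch_spike_data on the parent instead of the copy-pasted fetch_linear_position_info:

from spyglass.decoding.decoding_merge import DecodingOutput

# Hypothetical merge key; substitute a real merge_id from your database.
key = {"merge_id": "00000000-0000-0000-0000-000000000000"}

results = DecodingOutput.fetch_results(key)
model = DecodingOutput.fetch_model(key)
# fetch_position_info returns a (position_info, position_variable_names)
# tuple, which the clusterless.py fix below relies on.
position_info, position_variable_names = DecodingOutput.fetch_position_info(key)
spike_data = DecodingOutput.fetch_spike_data(key, filter_by_interval=True)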
10 changes: 6 additions & 4 deletions src/spyglass/decoding/v1/clusterless.py
@@ -59,10 +59,11 @@ def create_group(
"waveform_features_group_name": group_name,
}
if self & group_key:
raise ValueError(
f"Group {nwb_file_name}: {group_name} already exists",
"please delete the group before creating a new one",
logger.error( # No error on duplicate helps with pytests
f"Group {nwb_file_name}: {group_name} already exists"
+ "please delete the group before creating a new one",
)
return
self.insert1(
group_key,
skip_duplicates=True,
@@ -586,7 +587,8 @@ def get_ahead_behind_distance(self, track_graph=None, time_slice=None):
                 classifier.environments[0].track_graph, *traj_data
             )
         else:
-            position_info = self.fetch_position_info(self.fetch1("KEY")).loc[
+            # `fetch_position_info` returns a tuple
+            position_info = self.fetch_position_info(self.fetch1("KEY"))[0].loc[
                 time_slice
             ]
             map_position = analysis.maximum_a_posteriori_estimate(posterior)
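A hedged sketch of the new duplicate-group behavior. The table name UnitWaveformFeaturesGroup and the create_group signature are assumptions from this module's context; the key values are made up. A second create_group call with the same group name now logs an error and returns instead of raising, which keeps repeated pytest sessions idempotent:

from spyglass.decoding.v1.clusterless import UnitWaveformFeaturesGroup

group_args = dict(
    nwb_file_name="example_.nwb",  # hypothetical session file
    group_name="test_group",
    keys=[],  # merge keys for the waveform features; empty for illustration
)
UnitWaveformFeaturesGroup().create_group(**group_args)  # inserts the group
UnitWaveformFeaturesGroup().create_group(**group_args)  # logs an error, returns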
18 changes: 10 additions & 8 deletions src/spyglass/decoding/v1/core.py
@@ -15,7 +15,7 @@
     restore_classes,
 )
 from spyglass.position.position_merge import PositionOutput  # noqa: F401
-from spyglass.utils import SpyglassMixin, SpyglassMixinPart
+from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger
 
 schema = dj.schema("decoding_core_v1")
 
@@ -56,14 +56,15 @@ class DecodingParameters(SpyglassMixin, dj.Lookup):
     @classmethod
     def insert_default(cls):
         """Insert default decoding parameters"""
-        cls.insert(cls.contents, skip_duplicates=True)
+        cls.super().insert(cls.contents, skip_duplicates=True)
 
     def insert(self, rows, *args, **kwargs):
         """Override insert to convert classes to dict before inserting"""
         for row in rows:
-            row["decoding_params"] = convert_classes_to_dict(
-                vars(row["decoding_params"])
-            )
+            params = row["decoding_params"]
+            if hasattr(params, "__dict__"):
+                params = vars(params)
+            row["decoding_params"] = convert_classes_to_dict(params)
         super().insert(rows, *args, **kwargs)
 
     def fetch(self, *args, **kwargs):
@@ -124,10 +125,11 @@ def create_group(
"position_group_name": group_name,
}
if self & group_key:
raise ValueError(
f"Group {nwb_file_name}: {group_name} already exists",
"please delete the group before creating a new one",
logger.error( # Easier for pytests to not raise error on duplicate
f"Group {nwb_file_name}: {group_name} already exists"
+ "please delete the group before creating a new one"
)
return
self.insert1(
{
**group_key,
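The insert override above now tolerates both plain-dict and object-valued decoding_params. A minimal sketch under assumed field names: only decoding_params appears in the diff, so decoding_param_name and decoding_kwargs are assumptions about the table's schema, and the parameter values are illustrative:

from spyglass.decoding.v1.core import DecodingParameters

row = {
    "decoding_param_name": "my_params",  # assumed primary-key field
    "decoding_params": {"environments": []},  # plain dict: passes hasattr check, no vars() needed
    "decoding_kwargs": {},  # assumed optional field
}
DecodingParameters().insert([row], skip_duplicates=True)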
2 changes: 1 addition & 1 deletion src/spyglass/spikesorting/analysis/v1/unit_annotation.py
@@ -71,7 +71,7 @@ def add_annotation(self, key, **kwargs):
         ).fetch_nwb()[0]
         nwb_field_name = _get_spike_obj_name(nwb_file)
         spikes = nwb_file[nwb_field_name]["spike_times"].to_list()
-        if key["unit_id"] > len(spikes):
+        if key["unit_id"] > len(spikes) and not self._test_mode:
             raise ValueError(
                 f"unit_id {key['unit_id']} is greater than ",
                 f"the number of units in {key['spikesorting_merge_id']}",
2 changes: 2 additions & 0 deletions src/spyglass/utils/dj_merge_tables.py
@@ -737,6 +737,8 @@ def merge_get_parent_class(self, source: str) -> dj.Table:
         source: Union[str, dict, dj.Table]
             Accepts a CamelCase name of the source, or key as a dict, or a part
             table.
+        init: bool, optional
+            Default False. If True, returns an instance of the class.
 
         Returns
         -------
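Per the docstring addition, a hedged sketch of the documented flag. The init parameter itself is presumably added in the unexpanded portion of this diff (the visible signature does not show it), and the part-table name below is illustrative:

from spyglass.decoding.decoding_merge import DecodingOutput

# Assumed usage: by default the parent class object is returned;
# with init=True, an instance of that class is returned instead.
parent_cls = DecodingOutput().merge_get_parent_class("ClusterlessDecodingV1")
parent_obj = DecodingOutput().merge_get_parent_class(
    "ClusterlessDecodingV1", init=True
)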
7 changes: 5 additions & 2 deletions tests/common/test_interval.py
@@ -8,7 +8,9 @@ def interval_list(common):


 def test_plot_intervals(mini_insert, interval_list):
-    fig = interval_list.plot_intervals(return_fig=True)
+    fig = (interval_list & 'interval_list_name LIKE "raw%"').plot_intervals(
+        return_fig=True
+    )
     interval_list_name = fig.get_axes()[0].get_yticklabels()[0].get_text()
     times_fetch = (
         interval_list & {"interval_list_name": interval_list_name}
@@ -19,7 +21,8 @@ def test_plot_intervals(mini_insert, interval_list):


 def test_plot_epoch(mini_insert, interval_list):
-    fig = interval_list.plot_epoch_pos_raw_intervals(return_fig=True)
+    restr_interval = interval_list & "interval_list_name like 'raw%'"
+    fig = restr_interval.plot_epoch_pos_raw_intervals(return_fig=True)
     epoch_label = fig.get_axes()[0].get_yticklabels()[-1].get_text()
     assert epoch_label == "epoch", "plot_epoch failed"

173 changes: 173 additions & 0 deletions tests/conftest.py
@@ -1299,3 +1299,176 @@ def dlc_key(sgp, dlc_selection):
 def populate_dlc(sgp, dlc_key):
     sgp.v1.DLCPosV1().populate(dlc_key)
     yield
+
+
+# ----------------------- FIXTURES, SPIKESORTING TABLES -----------------------
+# ------------------------ Note: Used in decoding tests -----------------------
+
+
+@pytest.fixture(scope="session")
+def spike_v1(common):
+    from spyglass.spikesorting import v1
+
+    yield v1
+
+
+@pytest.fixture(scope="session")
+def pop_rec(spike_v1, mini_dict, team_name):
+    spike_v1.SortGroup.set_group_by_shank(**mini_dict)
+    key = {
+        **mini_dict,
+        "sort_group_id": 0,
+        "preproc_param_name": "default",
+        "interval_list_name": "01_s1",
+        "team_name": team_name,
+    }
+    spike_v1.SpikeSortingRecordingSelection.insert_selection(key)
+    ssr_pk = (
+        (spike_v1.SpikeSortingRecordingSelection & key).proj().fetch1("KEY")
+    )
+    spike_v1.SpikeSortingRecording.populate(ssr_pk)
+
+    yield ssr_pk
+
+
+@pytest.fixture(scope="session")
+def pop_art(spike_v1, mini_dict, pop_rec):
+    key = {
+        "recording_id": pop_rec["recording_id"],
+        "artifact_param_name": "default",
+    }
+    spike_v1.ArtifactDetectionSelection.insert_selection(key)
+    spike_v1.ArtifactDetection.populate()
+
+    yield spike_v1.ArtifactDetection().fetch("KEY", as_dict=True)[0]
+
+
+@pytest.fixture(scope="session")
+def spike_merge(spike_v1):
+    from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
+
+    yield SpikeSortingOutput()
+
+
+@pytest.fixture(scope="session")
+def sorter_dict():
+    return {"sorter": "mountainsort4"}
+
+
+@pytest.fixture(scope="session")
+def pop_sort(spike_v1, pop_rec, pop_art, mini_dict, sorter_dict):
+    key = {
+        **mini_dict,
+        **sorter_dict,
+        "recording_id": pop_rec["recording_id"],
+        "interval_list_name": str(pop_art["artifact_id"]),
+        "sorter_param_name": "franklab_tetrode_hippocampus_30KHz",
+    }
+    spike_v1.SpikeSortingSelection.insert_selection(key)
+    spike_v1.SpikeSorting.populate()
+
+    yield spike_v1.SpikeSorting().fetch("KEY", as_dict=True)[0]
+
+
+@pytest.fixture(scope="session")
+def sorting_objs(spike_v1, pop_sort):
+    sort_nwb = (spike_v1.SpikeSorting & pop_sort).fetch_nwb()
+    sort_si = spike_v1.SpikeSorting.get_sorting(pop_sort)
+    yield sort_nwb, sort_si
+
+
+@pytest.fixture(scope="session")
+def pop_curation(spike_v1, pop_sort):
+    spike_v1.CurationV1.insert_curation(
+        sorting_id=pop_sort["sorting_id"],
+        description="testing sort",
+    )
+
+    yield (spike_v1.CurationV1() & {"parent_curation_id": -1}).fetch(
+        "KEY", as_dict=True
+    )[0]
+
+
+@pytest.fixture(scope="session")
+def pop_metric(spike_v1, pop_sort, pop_curation):
+    _ = pop_curation  # make sure this happens first
+    key = {
+        "sorting_id": pop_sort["sorting_id"],
+        "curation_id": 0,
+        "waveform_param_name": "default_not_whitened",
+        "metric_param_name": "franklab_default",
+        "metric_curation_param_name": "default",
+    }
+
+    spike_v1.MetricCurationSelection.insert_selection(key)
+    spike_v1.MetricCuration.populate(key)
+
+    yield spike_v1.MetricCuration().fetch("KEY", as_dict=True)[0]
+
+
+@pytest.fixture(scope="session")
+def metric_objs(spike_v1, pop_metric):
+    key = {"metric_curation_id": pop_metric["metric_curation_id"]}
+    labels = spike_v1.MetricCuration.get_labels(key)
+    merge_groups = spike_v1.MetricCuration.get_merge_groups(key)
+    metrics = spike_v1.MetricCuration.get_metrics(key)
+    yield labels, merge_groups, metrics
+
+
+@pytest.fixture(scope="session")
+def pop_curation_metric(spike_v1, pop_metric, metric_objs):
+    labels, merge_groups, metrics = metric_objs
+    parent_dict = {"parent_curation_id": 0}
+    spike_v1.CurationV1.insert_curation(
+        sorting_id=(
+            spike_v1.MetricCurationSelection
+            & {"metric_curation_id": pop_metric["metric_curation_id"]}
+        ).fetch1("sorting_id"),
+        **parent_dict,
+        labels=labels,
+        merge_groups=merge_groups,
+        metrics=metrics,
+        description="after metric curation",
+    )
+
+    yield (spike_v1.CurationV1 & parent_dict).fetch("KEY", as_dict=True)[0]
+
+
+@pytest.fixture(scope="session")
+def pop_spike_merge(
+    spike_v1, pop_curation_metric, spike_merge, mini_dict, sorter_dict
+):
+    # TODO: add figurl fixtures when kachery_cloud is initialized
+
+    spike_merge.insert([pop_curation_metric], part_name="CurationV1")
+
+    yield (spike_merge << pop_curation_metric).fetch1("KEY")
+
+
+@pytest.fixture(scope="session")
+def spike_v1_group():
+    from spyglass.spikesorting.analysis.v1 import group
+
+    yield group
+
+
+@pytest.fixture(scope="session")
+def group_name():
+    yield "test_group"
+
+
+@pytest.fixture(scope="session")
+def pop_spikes_group(
+    group_name, spike_v1_group, spike_merge, mini_dict, pop_spike_merge
+):
+    _ = pop_spike_merge  # make sure this happens first
+
+    spike_v1_group.UnitSelectionParams().insert_default()
+    spike_v1_group.SortedSpikesGroup().create_group(
+        **mini_dict,
+        group_name=group_name,
+        keys=spike_merge.proj(spikesorting_merge_id="merge_id").fetch("KEY"),
+        unit_filter_params_name="default_exclusion",
+    )
+    yield spike_v1_group.SortedSpikesGroup().fetch("KEY", as_dict=True)[0]
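To show where this fixture chain lands, a hypothetical decoding-adjacent test (not from this commit; the assertion is illustrative) could consume pop_spikes_group like so:

def test_sorted_spikes_group_created(spike_v1_group, pop_spikes_group):
    """Hypothetical check that the chained fixtures yield a usable group key."""
    group_table = spike_v1_group.SortedSpikesGroup() & pop_spikes_group
    assert len(group_table) == 1, "SortedSpikesGroup entry was not created"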
Empty file added tests/decoding/__init__.py
