From f61bf569f3f91f7b6e606e911aa6781c1af64360 Mon Sep 17 00:00:00 2001 From: Kyle Johnsen Date: Thu, 11 Apr 2024 12:19:37 -0400 Subject: [PATCH] patches: fix Spiking.t_samp_ms and IOProc sample skipping (#47) * fix Spiking.t_samp_ms * fix sample time test bug --- cleo/ephys/spiking.py | 3 ++- cleo/ioproc/base.py | 3 ++- pyproject.toml | 2 +- tests/ephys/test_spiking.py | 15 +++++++++------ tests/ioproc/test_processing.py | 31 +++++++++++++++++++++++-------- 5 files changed, 37 insertions(+), 17 deletions(-) diff --git a/cleo/ephys/spiking.py b/cleo/ephys/spiking.py index 77d558e..e0c8054 100644 --- a/cleo/ephys/spiking.py +++ b/cleo/ephys/spiking.py @@ -65,7 +65,8 @@ def _update_saved_vars(self, t_ms, i, t_samp_ms): if self.probe.save_history: self.i = np.concatenate([self.i, i]) self.t_ms = np.concatenate([self.t_ms, t_ms]) - self.t_samp_ms = np.concatenate([self.t_samp_ms, [t_samp_ms]]) + t_samp_ms_rep = np.full_like(t_ms, t_samp_ms) + self.t_samp_ms = np.concatenate([self.t_samp_ms, t_samp_ms_rep]) def connect_to_neuron_group( self, neuron_group: NeuronGroup, **kwparams diff --git a/cleo/ioproc/base.py b/cleo/ioproc/base.py index 78c0a6a..e0e070a 100644 --- a/cleo/ioproc/base.py +++ b/cleo/ioproc/base.py @@ -193,7 +193,8 @@ def _is_currently_idle(self, query_time_ms): def is_sampling_now(self, query_time_ms): if self.sampling == "fixed": - if np.isclose(query_time_ms % self.sample_period_ms, 0): + resid = query_time_ms % self.sample_period_ms + if np.isclose(resid, 0) or np.isclose(resid, self.sample_period_ms): return True elif self.sampling == "when idle": if np.isclose(query_time_ms % self.sample_period_ms, 0): diff --git a/pyproject.toml b/pyproject.toml index 9a77cc0..271a08f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cleosim" -version = "0.14.1" +version = "0.14.2" description = "Cleo: the Closed-Loop, Electrophysiology, and Optogenetics experiment simulation testbed" authors = [ "Kyle Johnsen ", diff --git a/tests/ephys/test_spiking.py b/tests/ephys/test_spiking.py index c8d5432..370bdd4 100644 --- a/tests/ephys/test_spiking.py +++ b/tests/ephys/test_spiking.py @@ -1,12 +1,10 @@ """Tests for ephys.spiking module""" -import pytest import numpy as np -from brian2 import SpikeGeneratorGroup, ms, mm, Network -import neo import quantities as pq +from brian2 import Network, SpikeGeneratorGroup, mm, ms from cleo import CLSimulator -from cleo.ephys import * +from cleo.ephys import MultiUnitSpiking, Probe, SortedSpiking from cleo.ioproc import RecordOnlyProcessor @@ -81,7 +79,7 @@ def test_MUS_multiple_contacts(): assert np.sum(mus.i == 0) < len(indices) assert np.sum(mus.i == 1) < len(indices) - assert len(mus.i) == len(mus.t_ms) + assert len(mus.i) == len(mus.t_ms) == len(mus.t_samp_ms) def test_MUS_multiple_groups(): @@ -106,6 +104,8 @@ def test_MUS_multiple_groups(): assert 20 < np.sum(mus.i == 0) < 60 # second channel would have caught all spikes from sgg1 and sgg2 assert np.sum(mus.i == 1) == 60 + assert len(mus.t_ms) == len(mus.t_samp_ms) + assert np.all(mus.t_samp_ms == 10) def test_MUS_reset(): @@ -156,7 +156,10 @@ def test_SortedSpiking(): assert all(i == [2, 3, 5]) for i in (0, 1, 4): - assert not i in ss.i + assert i not in ss.i + + assert ss.t_ms.shape == ss.i.shape == ss.t_samp_ms.shape + assert np.all(np.in1d(ss.t_samp_ms, [3, 4, 5, 6])) def _test_reset(spike_signal_class): diff --git a/tests/ioproc/test_processing.py b/tests/ioproc/test_processing.py index 457ac40..47ddf1d 100644 --- a/tests/ioproc/test_processing.py +++ b/tests/ioproc/test_processing.py @@ -52,13 +52,18 @@ def __init__(self, sample_period_ms, **kwargs): super().__init__(sample_period_ms, **kwargs) self.delay = 1.199 self.component = MyProcessingBlock(delay=ConstantDelay(self.delay)) + self.count = 0 def process(self, state_dict: dict, sample_time_ms: float) -> Tuple[dict, float]: - input = state_dict["in"] - out, out_t = self.component.process( - input, sample_time_ms, measurement_time=sample_time_ms - ) - return {"out": out}, out_t + try: + input = state_dict["in"] + out, out_t = self.component.process( + input, sample_time_ms, measurement_time=sample_time_ms + ) + return {"out": out}, out_t + except KeyError: + self.count += 1 + return {}, sample_time_ms def _test_LatencyIOProcessor(myLIOP, t, sampling, inputs, outputs): @@ -117,14 +122,13 @@ class SampleCounter(cleo.IOProcessor): def is_sampling_now(self, t_query_ms) -> np.bool: return t_query_ms % self.sample_period_ms == 0 - def __init__(self): + def __init__(self, sample_period_ms=1): self.count = 0 - self.sample_period_ms = 1 + self.sample_period_ms = sample_period_ms self.latest_ctrl_signal = {} def put_state(self, state_dict: dict, sample_time_ms: float): self.count += 1 - print(sample_time_ms) return ({}, sample_time_ms) def get_ctrl_signals(self, query_time_ms: np.float) -> dict: @@ -141,6 +145,17 @@ def test_no_skip_sampling(): assert sc.count == nsamp +def test_no_skip_sampling_short(): + net = Network() + sim = cleo.CLSimulator(net) + Tsamp = 0.2 * ms + liop = MyLIOP(Tsamp / ms) + sim.set_io_processor(liop) + nsamp = 20 + sim.run(nsamp * Tsamp) + assert liop.count == nsamp + + class WaveformController(LatencyIOProcessor): def process(self, state_dict, t_ms): return {"steady": t_ms, "time-varying": t_ms + 1}, t_ms + 3