Skip to content

Commit

Permalink
patches: fix Spiking.t_samp_ms and IOProc sample skipping (#47)
Browse files Browse the repository at this point in the history
* fix Spiking.t_samp_ms

* fix sample time test bug
  • Loading branch information
kjohnsen authored Apr 11, 2024
1 parent 1b4553d commit f61bf56
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 17 deletions.
3 changes: 2 additions & 1 deletion cleo/ephys/spiking.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def _update_saved_vars(self, t_ms, i, t_samp_ms):
if self.probe.save_history:
self.i = np.concatenate([self.i, i])
self.t_ms = np.concatenate([self.t_ms, t_ms])
self.t_samp_ms = np.concatenate([self.t_samp_ms, [t_samp_ms]])
t_samp_ms_rep = np.full_like(t_ms, t_samp_ms)
self.t_samp_ms = np.concatenate([self.t_samp_ms, t_samp_ms_rep])

def connect_to_neuron_group(
self, neuron_group: NeuronGroup, **kwparams
Expand Down
3 changes: 2 additions & 1 deletion cleo/ioproc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ def _is_currently_idle(self, query_time_ms):

def is_sampling_now(self, query_time_ms):
if self.sampling == "fixed":
if np.isclose(query_time_ms % self.sample_period_ms, 0):
resid = query_time_ms % self.sample_period_ms
if np.isclose(resid, 0) or np.isclose(resid, self.sample_period_ms):
return True
elif self.sampling == "when idle":
if np.isclose(query_time_ms % self.sample_period_ms, 0):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cleosim"
version = "0.14.1"
version = "0.14.2"
description = "Cleo: the Closed-Loop, Electrophysiology, and Optogenetics experiment simulation testbed"
authors = [
"Kyle Johnsen <kyle@kjohnsen.org>",
Expand Down
15 changes: 9 additions & 6 deletions tests/ephys/test_spiking.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""Tests for ephys.spiking module"""
import pytest
import numpy as np
from brian2 import SpikeGeneratorGroup, ms, mm, Network
import neo
import quantities as pq
from brian2 import Network, SpikeGeneratorGroup, mm, ms

from cleo import CLSimulator
from cleo.ephys import *
from cleo.ephys import MultiUnitSpiking, Probe, SortedSpiking
from cleo.ioproc import RecordOnlyProcessor


Expand Down Expand Up @@ -81,7 +79,7 @@ def test_MUS_multiple_contacts():
assert np.sum(mus.i == 0) < len(indices)
assert np.sum(mus.i == 1) < len(indices)

assert len(mus.i) == len(mus.t_ms)
assert len(mus.i) == len(mus.t_ms) == len(mus.t_samp_ms)


def test_MUS_multiple_groups():
Expand All @@ -106,6 +104,8 @@ def test_MUS_multiple_groups():
assert 20 < np.sum(mus.i == 0) < 60
# second channel would have caught all spikes from sgg1 and sgg2
assert np.sum(mus.i == 1) == 60
assert len(mus.t_ms) == len(mus.t_samp_ms)
assert np.all(mus.t_samp_ms == 10)


def test_MUS_reset():
Expand Down Expand Up @@ -156,7 +156,10 @@ def test_SortedSpiking():
assert all(i == [2, 3, 5])

for i in (0, 1, 4):
assert not i in ss.i
assert i not in ss.i

assert ss.t_ms.shape == ss.i.shape == ss.t_samp_ms.shape
assert np.all(np.in1d(ss.t_samp_ms, [3, 4, 5, 6]))


def _test_reset(spike_signal_class):
Expand Down
31 changes: 23 additions & 8 deletions tests/ioproc/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,18 @@ def __init__(self, sample_period_ms, **kwargs):
super().__init__(sample_period_ms, **kwargs)
self.delay = 1.199
self.component = MyProcessingBlock(delay=ConstantDelay(self.delay))
self.count = 0

def process(self, state_dict: dict, sample_time_ms: float) -> Tuple[dict, float]:
input = state_dict["in"]
out, out_t = self.component.process(
input, sample_time_ms, measurement_time=sample_time_ms
)
return {"out": out}, out_t
try:
input = state_dict["in"]
out, out_t = self.component.process(
input, sample_time_ms, measurement_time=sample_time_ms
)
return {"out": out}, out_t
except KeyError:
self.count += 1
return {}, sample_time_ms


def _test_LatencyIOProcessor(myLIOP, t, sampling, inputs, outputs):
Expand Down Expand Up @@ -117,14 +122,13 @@ class SampleCounter(cleo.IOProcessor):
def is_sampling_now(self, t_query_ms) -> np.bool:
return t_query_ms % self.sample_period_ms == 0

def __init__(self):
def __init__(self, sample_period_ms=1):
self.count = 0
self.sample_period_ms = 1
self.sample_period_ms = sample_period_ms
self.latest_ctrl_signal = {}

def put_state(self, state_dict: dict, sample_time_ms: float):
self.count += 1
print(sample_time_ms)
return ({}, sample_time_ms)

def get_ctrl_signals(self, query_time_ms: np.float) -> dict:
Expand All @@ -141,6 +145,17 @@ def test_no_skip_sampling():
assert sc.count == nsamp


def test_no_skip_sampling_short():
net = Network()
sim = cleo.CLSimulator(net)
Tsamp = 0.2 * ms
liop = MyLIOP(Tsamp / ms)
sim.set_io_processor(liop)
nsamp = 20
sim.run(nsamp * Tsamp)
assert liop.count == nsamp


class WaveformController(LatencyIOProcessor):
def process(self, state_dict, t_ms):
return {"steady": t_ms, "time-varying": t_ms + 1}, t_ms + 3
Expand Down

0 comments on commit f61bf56

Please sign in to comment.