Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using brian units - done main and LFP but LFP still fails 10 tests #43

Closed
wants to merge 10 commits into from
18 changes: 9 additions & 9 deletions cleo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from cleo.registry import registry_for_sim
from cleo.utilities import add_to_neo_segment, analog_signal, brian_safe_name

import cleo

class NeoExportable(ABC):
"""Mixin class for classes that can be exported to Neo objects"""
Expand Down Expand Up @@ -158,7 +158,7 @@ class IOProcessor(ABC):
class more useful, since delay handling is already defined.
"""

sample_period_ms: float = 1
sample_period: float = 1 * ms
"""Determines how frequently the processor takes samples"""

latest_ctrl_signal: dict = field(factory=dict, init=False, repr=False)
Expand All @@ -180,7 +180,7 @@ def is_sampling_now(self, time) -> bool:
pass

@abstractmethod
def put_state(self, state_dict: dict, sample_time_ms: float) -> None:
def put_state(self, state_dict: dict, sample_time: float) -> None:
"""Deliver network state to the :class:`IOProcessor`.

Parameters
Expand Down Expand Up @@ -273,7 +273,7 @@ class Stimulator(InterfaceDevice, NeoExportable):
"""The current value of the stimulator device"""
default_value: Any = 0
"""The default value of the device---used on initialization and on :meth:`~reset`"""
t_ms: list[float] = field(factory=list, init=False, repr=False)
t: list[float] = field(factory=list, init=False, repr=False)
"""Times stimulator was updated, stored if :attr:`~cleo.InterfaceDevice.save_history`"""
values: list[Any] = field(factory=list, init=False, repr=False)
"""Values taken by the stimulator at each :meth:`~update` call,
Expand All @@ -286,10 +286,10 @@ def __attrs_post_init__(self):
def _init_saved_vars(self):
if self.save_history:
if self.sim:
t0 = self.sim.network.t / ms
t0 = self.sim.network.t
else:
t0 = 0
self.t_ms = [t0]
t0 = 0*ms
self.t = [t0]
self.values = [self.value]

def update(self, ctrl_signal) -> None:
Expand All @@ -307,7 +307,7 @@ def update(self, ctrl_signal) -> None:
"""
self.value = ctrl_signal
if self.save_history:
self.t_ms.append(self.sim.network.t / ms)
self.t.append(self.sim.network.t)
self.values.append(self.value)

def reset(self, **kwargs) -> None:
Expand All @@ -316,7 +316,7 @@ def reset(self, **kwargs) -> None:
self._init_saved_vars()

def to_neo(self):
signal = analog_signal(self.t_ms, self.values, "dimensionless")
signal = cleo.utilities.analog_signal(self.t, self.values, "dimensionless")
signal.name = self.name
signal.description = "Exported from Cleo stimulator device"
signal.annotate(export_datetime=datetime.datetime.now())
Expand Down
44 changes: 24 additions & 20 deletions cleo/ephys/lfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
import numpy as np
import quantities as pq
from attrs import define, field
from brian2 import NeuronGroup, mm, ms
from brian2 import NeuronGroup, mm, ms, second
from brian2.monitors.spikemonitor import SpikeMonitor
from nptyping import NDArray
from tklfp import TKLFP

import cleo.utilities
from cleo.base import NeoExportable
from cleo.ephys.probes import Signal, Probe
import cleo.utilities
import neo
from cleo.ephys.probes import Signal


Expand Down Expand Up @@ -41,7 +44,7 @@ class TKLFPSignal(Signal, NeoExportable):
to be considered, by default 1e-3.
This determines the buffer length of past spikes, since the uLFP from a long-past
spike becomes negligible and is ignored."""
t_ms: NDArray[(Any,), float] = field(init=False, repr=False)
t: NDArray[(Any,), float] = field(init=False, repr=False)
"""Times at which LFP is recorded, in ms, stored if
:attr:`~cleo.InterfaceDevice.save_history` on :attr:`~Signal.probe`"""
lfp_uV: NDArray[(Any, Any), float] = field(init=False, repr=False)
Expand All @@ -53,7 +56,7 @@ class TKLFPSignal(Signal, NeoExportable):
_monitors: list[SpikeMonitor] = field(init=False, factory=list, repr=False)
_mon_spikes_already_seen: list[int] = field(init=False, factory=list, repr=False)
_i_buffers: list[list[np.ndarray]] = field(init=False, factory=list, repr=False)
_t_ms_buffers: list[list[np.ndarray]] = field(init=False, factory=list, repr=False)
_t_buffers: list[list[np.ndarray]] = field(init=False, factory=list, repr=False)
_buffer_positions: list[int] = field(init=False, factory=list, repr=False)

def _post_init_for_probe(self):
Expand All @@ -65,12 +68,12 @@ def _post_init_for_probe(self):

def _init_saved_vars(self):
if self.probe.save_history:
self.t_ms = np.empty((0,))
self.t = np.empty((0,))*ms
self.lfp_uV = np.empty((0, self.probe.n))

def _update_saved_vars(self, t_ms, lfp_uV):
def _update_saved_vars(self, t, lfp_uV):
if self.probe.save_history:
self.t_ms = np.concatenate([self.t_ms, [t_ms]])
self.t = np.concatenate([self.t, [t]])
self.lfp_uV = np.vstack([self.lfp_uV, lfp_uV])

def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams):
Expand Down Expand Up @@ -103,7 +106,7 @@ def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams):
# prep buffers
self._tklfps.append(tklfp)
self._i_buffers.append([np.array([], dtype=int, ndmin=1)] * buf_len)
self._t_ms_buffers.append([np.array([], dtype=float, ndmin=1)] * buf_len)
self._t_buffers.append([np.array([], dtype=float, ndmin=1)] * buf_len)
self._buffer_positions.append(0)

# prep SpikeMonitor
Expand All @@ -114,13 +117,13 @@ def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams):

def get_state(self) -> np.ndarray:
tot_tklfp = 0
now_ms = self.probe.sim.network.t / ms
now = self.probe.sim.network.t
# loop over neuron groups (monitors, tklfps)
for i_mon in range(len(self._monitors)):
self._update_spike_buffer(i_mon)
tot_tklfp += self._tklfp_for_monitor(i_mon, now_ms)
tot_tklfp += self._tklfp_for_monitor(i_mon, now)
out = np.reshape(tot_tklfp, (-1,)) # return 1D array (vector)
self._update_saved_vars(now_ms, out)
self._update_saved_vars(now, out)
return out

def reset(self, **kwargs) -> None:
Expand All @@ -133,7 +136,7 @@ def _reset_buffer(self, i_mon):
mon = self._monitors[i_mon]
buf_len = len(self._i_buffers[i_mon])
self._i_buffers[i_mon] = [np.array([], dtype=int, ndmin=1)] * buf_len
self._t_ms_buffers[i_mon] = [np.array([], dtype=float, ndmin=1)] * buf_len
self._t_buffers[i_mon] = [np.array([], dtype=float, ndmin=1)] * buf_len
self._buffer_positions[i_mon] = 0

def _update_spike_buffer(self, i_mon):
Expand All @@ -143,24 +146,25 @@ def _update_spike_buffer(self, i_mon):

# insert new spikes into buffer (overwriting anything previous)
self._i_buffers[i_mon][buf_pos] = mon.i[n_prev:]
self._t_ms_buffers[i_mon][buf_pos] = mon.t[n_prev:] / ms
self._t_buffers[i_mon][buf_pos] = mon.t[n_prev:]

self._mon_spikes_already_seen[i_mon] = mon.num_spikes
# update buffer position
buf_len = len(self._i_buffers[i_mon])
self._buffer_positions[i_mon] = (buf_pos + 1) % buf_len

def _tklfp_for_monitor(self, i_mon, now_ms):
def _tklfp_for_monitor(self, i_mon, now):
i = np.concatenate(self._i_buffers[i_mon])
t_ms = np.concatenate(self._t_ms_buffers[i_mon])
return self._tklfps[i_mon].compute(i, t_ms, [now_ms])
print(self._t_buffers)
t = np.concatenate(self._t_buffers[i_mon])*second
return self._tklfps[i_mon].compute(i, t / ms, [now / ms])

def _get_buffer_length(self, tklfp, **kwparams):
# need sampling period
sample_period_ms = kwparams.get("sample_period_ms", None)
if sample_period_ms is None:
sample_period = kwparams.get("sample_period_ms", None)*ms
if sample_period is None:
try:
sample_period_ms = self.probe.sim.io_processor.sample_period_ms
sample_period = self.probe.sim.io_processor.sample_period
except AttributeError: # probably means sim doesn't have io_processor
raise Exception(
"TKLFP needs to know the sampling period. Either set the simulator's "
Expand All @@ -169,14 +173,14 @@ def _get_buffer_length(self, tklfp, **kwparams):
", tklfp_type=..., sample_period_ms=...)"
)
return np.ceil(
tklfp.compute_min_window_ms(self.uLFP_threshold_uV) / sample_period_ms
tklfp.compute_min_window_ms(self.uLFP_threshold_uV) / (sample_period / ms)
).astype(int)

def to_neo(self) -> neo.AnalogSignal:
# inherit docstring
try:
signal = cleo.utilities.analog_signal(
self.t_ms,
self.t,
self.lfp_uV,
"uV",
)
Expand Down
41 changes: 33 additions & 8 deletions cleo/ioproc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

from abc import ABC, abstractmethod
from collections import deque
from typing import Any, Tuple

from brian2 import ms
import numpy as np
from attrs import define, field
from typing import Any, Tuple


from cleo.base import IOProcessor
from cleo.ioproc.delays import Delay
Expand Down Expand Up @@ -130,6 +131,8 @@ class LatencyIOProcessor(IOProcessor):

"fixed" sampling means samples are taken on a fixed schedule,
with no exceptions.



"when idle" sampling means no samples are taken before the previous
sample's output has been delivered. A sample is taken ASAP
Expand Down Expand Up @@ -167,6 +170,28 @@ def _validate_processing(self, attribute, value):
raise ValueError("Invalid processing scheme:", value)

out_buffer: deque[Tuple[dict, float]] = field(factory=deque, init=False, repr=False)
"""
"serial" computes the output time by adding the delay for a sample
onto the output time of the previous sample, rather than the sampling
time. Note this may be of limited
utility because it essentially means the *entire* round trip
cannot be in parallel at all. More realistic is that simply
each block or phase of computation must be serial. If anyone
cares enough about this, it will have to be implemented in the
future.

Note
----
It doesn't make much sense to combine parallel computation
with "when idle" sampling, because "when idle" sampling only produces
one sample at a time to process.

Raises
------
ValueError
For invalid `sampling` or `processing` kwargs
"""


def put_state(self, state_dict: dict, sample_time_ms: float):
self.t_samp_ms.append(sample_time_ms)
Expand All @@ -193,10 +218,10 @@ def _is_currently_idle(self, query_time_ms):

def is_sampling_now(self, query_time_ms):
if self.sampling == "fixed":
if np.isclose(query_time_ms % self.sample_period_ms, 0):
if np.isclose(query_time_ms % (self.sample_period / ms), 0):
return True
elif self.sampling == "when idle":
if query_time_ms % self.sample_period_ms == 0:
if query_time_ms % (self.sample_period / ms) == 0:
if self._is_currently_idle(query_time_ms):
self._needs_off_schedule_sample = False
return True
Expand Down Expand Up @@ -237,8 +262,8 @@ class RecordOnlyProcessor(LatencyIOProcessor):

Use this if all you are doing is recording."""

def __init__(self, sample_period_ms, **kwargs):
super().__init__(sample_period_ms, **kwargs)
def __init__(self, sample_period, **kwargs):
super().__init__(sample_period, **kwargs)

def process(self, state_dict: dict, sample_time_ms: float) -> Tuple[dict, float]:
return ({}, sample_time_ms)
def process(self, state_dict: dict, sample_time: float) -> Tuple[dict, float]:
return ({}, sample_time / ms)
2 changes: 1 addition & 1 deletion cleo/light/light.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def _alpha_cmap_for_wavelength(self, intensity):
)

def to_neo(self):
signal = analog_signal(self.t_ms, self.values, "mW/mm**2")
signal = analog_signal(self.t, self.values, "mW/mm**2")
signal.name = self.name
signal.description = "Exported from Cleo Light device"
signal.annotate(export_datetime=datetime.datetime.now())
Expand Down
Loading