diff --git a/derotation/analysis/derotation_pipeline.py b/derotation/analysis/derotation_pipeline.py index 06383e6..35c5121 100644 --- a/derotation/analysis/derotation_pipeline.py +++ b/derotation/analysis/derotation_pipeline.py @@ -154,22 +154,11 @@ def process_analog_signals(self): self.full_rotation, self.k ) self.rot_blocks_idx = self.correct_start_and_end_rotation_signal( - self.inter_rotation_interval_min_len, start, end - ) - self.rotation_on = self.create_signed_rotation_array( - len(self.full_rotation), - self.rot_blocks_idx["start"], - self.rot_blocks_idx["end"], - self.direction, + start, end ) + self.rotation_on = self.create_signed_rotation_array() - self.rotation_ticks_peaks = self.drop_ticks_outside_of_rotation( - self.rotation_ticks_peaks, - self.rot_blocks_idx["start"], - self.rot_blocks_idx["end"], - self.total_clock_time, - self.number_of_rotations, - ) + self.drop_ticks_outside_of_rotation() self.check_number_of_rotations() if not self.is_number_of_ticks_correct() and self.adjust_increment: @@ -177,12 +166,7 @@ def process_analog_signals(self): ( self.corrected_increments, self.ticks_per_rotation, - ) = self.adjust_rotation_increment( - self.rotation_ticks_peaks, - self.rot_blocks_idx["start"], - self.rot_blocks_idx["end"], - self.rot_deg, - ) + ) = self.adjust_rotation_increment() else: self.corrected_increments = ( self.adjust_rotation_increment_for_incremental_changes() @@ -276,9 +260,8 @@ def get_start_end_times_with_threshold( return start, end - @staticmethod def correct_start_and_end_rotation_signal( - inter_rotation_interval_min_len: int, + self, start: np.ndarray, end: np.ndarray, ) -> dict: @@ -287,11 +270,12 @@ def correct_start_and_end_rotation_signal( periods that are not plausible given the experimental setup. The two surrounding on periods are merged. + Used the inter_rotation_interval_min_len parameter from the config + file: the minimum length of the time in between two rotations. + It is important to remove artifacts. + Parameters ---------- - inter_rotation_interval_min_len : int - Minimum length of the time in between two rotations. - It is important to remove artifacts. start : np.ndarray The start times of the on periods of rotation signal. end : np.ndarray @@ -306,7 +290,7 @@ def correct_start_and_end_rotation_signal( logging.info("Cleaning start and end rotation signal...") shifted_end = np.roll(end, 1) - mask = start - shifted_end > inter_rotation_interval_min_len + mask = start - shifted_end > self.inter_rotation_interval_min_len mask[0] = True # first rotation is always a full rotation shifted_mask = np.roll(mask, -1) new_start = start[mask] @@ -314,25 +298,13 @@ def correct_start_and_end_rotation_signal( return {"start": new_start, "end": new_end} - @staticmethod - def create_signed_rotation_array( - len_full_rotation: int, starts: np.ndarray, ends: np.ndarray, direction - ) -> np.ndarray: + def create_signed_rotation_array(self) -> np.ndarray: """Reconstructs an array that has the same length as the full rotation signal. It is 0 when the motor is off, and it is 1 or -1 when the motor is on, depending on the direction of rotation. 1 is clockwise, -1 is counter clockwise. - - Parameters - ---------- - len_full_rotation : int - Length of the full rotation signal. - starts : np.ndarray - The start times of the on periods of rotation signal. - ends : np.ndarray - The end times of the on periods of rotation signal. - direction : _type_ - The direction of rotation of the motor. + Uses the start and end times of the on periods of rotation signal, and + the direction of rotation to reconstruct the array. Returns ------- @@ -341,39 +313,19 @@ def create_signed_rotation_array( """ logging.info("Creating signed rotation array...") - rotation_on = np.zeros(len_full_rotation) + rotation_on = np.zeros(self.total_clock_time) for i, (start, end) in enumerate( zip( - starts, - ends, + self.rot_blocks_idx["start"], + self.rot_blocks_idx["end"], ) ): - rotation_on[start:end] = direction[i] + rotation_on[start:end] = self.direction[i] return rotation_on - @staticmethod - def drop_ticks_outside_of_rotation( - rotation_ticks_peaks: np.ndarray, - starts: np.ndarray, - ends: np.ndarray, - full_length: int, - number_of_rotations: int, - ) -> np.ndarray: - """_summary_ - - Parameters - ---------- - rotation_ticks_peaks : np.ndarray - The clock times of the rotation ticks peaks. - starts : np.ndarray - The start times of the on periods of rotation signal. - ends : np.ndarray - The end times of the on periods of rotation signal. - full_length : int - The length of the analog signals, in clock time. - number_of_rotations : int - The number of rotations. + def drop_ticks_outside_of_rotation(self) -> np.ndarray: + """Drops the rotation ticks that are outside of the rotation periods. Returns ------- @@ -384,33 +336,33 @@ def drop_ticks_outside_of_rotation( logging.info("Dropping ticks outside of the rotation period...") - len_before = len(rotation_ticks_peaks) + len_before = len(self.rotation_ticks_peaks) - rolled_starts = np.roll(starts, -1) - rolled_starts[-1] = full_length + rolled_starts = np.roll(self.rot_blocks_idx["start"], -1) + rolled_starts[-1] = self.total_clock_time inter_roatation_interval = [ idx - for i in range(number_of_rotations) + for i in range(self.number_of_rotations) for idx in range( - ends[i], + self.rot_blocks_idx["end"][i], rolled_starts[i], ) ] - rotation_ticks_peaks = np.delete( - rotation_ticks_peaks, - np.where(np.isin(rotation_ticks_peaks, inter_roatation_interval)), + self.rotation_ticks_peaks = np.delete( + self.rotation_ticks_peaks, + np.where( + np.isin(self.rotation_ticks_peaks, inter_roatation_interval) + ), ) - len_after = len(rotation_ticks_peaks) + len_after = len(self.rotation_ticks_peaks) logging.info( f"Ticks dropped: {len_before - len_after}.\n" + f"Ticks remaining: {len_after}" ) - return rotation_ticks_peaks - def check_number_of_rotations(self): """Checks that the number of rotations is as expected. @@ -460,13 +412,7 @@ def is_number_of_ticks_correct(self) -> bool: ) return False - @staticmethod - def adjust_rotation_increment( - rotation_ticks_peaks: np.ndarray, - starts: np.ndarray, - ends: np.ndarray, - rot_deg: int, - ) -> Tuple[np.ndarray, np.ndarray]: + def adjust_rotation_increment(self) -> Tuple[np.ndarray, np.ndarray]: """It calculates the new rotation increment for each rotation, given the number of ticks in each rotation. It also outputs the number of ticks in each rotation. @@ -492,19 +438,19 @@ def adjust_rotation_increment( def get_peaks_in_rotation(start, end): return np.where( np.logical_and( - rotation_ticks_peaks > start, - rotation_ticks_peaks < end, + self.rotation_ticks_peaks > start, + self.rotation_ticks_peaks < end, ) )[0].shape[0] ticks_per_rotation = [ get_peaks_in_rotation(start, end) for start, end in zip( - starts, - ends, + self.rot_blocks_idx["start"], + self.rot_blocks_idx["end"], ) ] - new_increments = [rot_deg / t for t in ticks_per_rotation] + new_increments = [self.rot_deg / t for t in ticks_per_rotation] logging.info(f"New increment example: {new_increments[0]:.3f}") diff --git a/tests/conftest.py b/tests/conftest.py index 078db79..0066b28 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ import numpy as np import pytest +from derotation.analysis.derotation_pipeline import DerotationPipeline + @pytest.fixture(autouse=True) def random(): @@ -115,3 +117,29 @@ def rotation_ticks( ) ticks = np.sort(ticks) return ticks + + +@pytest.fixture +def derotation_pipeline( + rotation_ticks, + start_end_times, + full_length, + number_of_rotations, + full_rotation, + direction, +): + pipeline = DerotationPipeline.__new__(DerotationPipeline) + + pipeline.inter_rotation_interval_min_len = 50 + pipeline.rotation_ticks_peaks = rotation_ticks + pipeline.rot_blocks_idx = { + "start": start_end_times[0], + "end": start_end_times[1], + } + pipeline.number_of_rotations = number_of_rotations + pipeline.direction = direction + pipeline.total_clock_time = full_length + pipeline.full_rotation = full_rotation + pipeline.rot_deg = 360 + + return pipeline diff --git a/tests/test_unit/test_adjust_rotation_increment.py b/tests/test_unit/test_adjust_rotation_increment.py index 0ca6611..7296974 100644 --- a/tests/test_unit/test_adjust_rotation_increment.py +++ b/tests/test_unit/test_adjust_rotation_increment.py @@ -4,21 +4,12 @@ def test_adjust_rotation_increment_360( - rotation_ticks, - start_end_times, + derotation_pipeline: DerotationPipeline, ): - start, end = start_end_times - rot_deg = 360 - ( new_increments, ticks_per_rotation, - ) = DerotationPipeline.adjust_rotation_increment( - rotation_ticks, - start, - end, - rot_deg, - ) + ) = derotation_pipeline.adjust_rotation_increment() new_increments = np.round(new_increments, 0) @@ -32,21 +23,14 @@ def test_adjust_rotation_increment_360( def test_adjust_rotation_increment_5( - rotation_ticks, - start_end_times, + derotation_pipeline: DerotationPipeline, ): - start, end = start_end_times - rot_deg = 5 + derotation_pipeline.rot_deg = 5 ( new_increments, - ticks_per_rotation, - ) = DerotationPipeline.adjust_rotation_increment( - rotation_ticks, - start, - end, - rot_deg, - ) + _, + ) = derotation_pipeline.adjust_rotation_increment() new_increments = np.round(new_increments, 3) diff --git a/tests/test_unit/test_create_signed_rotation_array.py b/tests/test_unit/test_create_signed_rotation_array.py index 4961697..b5329e6 100644 --- a/tests/test_unit/test_create_signed_rotation_array.py +++ b/tests/test_unit/test_create_signed_rotation_array.py @@ -4,15 +4,11 @@ def test_create_signed_rotation_array_interleaved( - full_length, start_end_times, direction_interleaved + derotation_pipeline: DerotationPipeline, + start_end_times: tuple, ): start, end = start_end_times - rotation_on = DerotationPipeline.create_signed_rotation_array( - full_length, - start, - end, - direction_interleaved, - ) + rotation_on = derotation_pipeline.create_signed_rotation_array() for idx in range(0, len(start), 2): assert np.all(rotation_on[start[idx] : end[idx]] == 1) @@ -20,15 +16,13 @@ def test_create_signed_rotation_array_interleaved( def test_create_signed_rotation_array_incremental( - full_length, start_end_times, direction_incremental + derotation_pipeline: DerotationPipeline, + start_end_times: tuple, + direction_incremental: np.ndarray, ): + derotation_pipeline.direction = direction_incremental start, end = start_end_times - rotation_on = DerotationPipeline.create_signed_rotation_array( - full_length, - start, - end, - direction_incremental, - ) + rotation_on = derotation_pipeline.create_signed_rotation_array() for idx in range(0, 5): assert np.all(rotation_on[start[idx] : end[idx]] == 1) diff --git a/tests/test_unit/test_drop_ticks.py b/tests/test_unit/test_drop_ticks.py index 5534d1e..5f50de2 100644 --- a/tests/test_unit/test_drop_ticks.py +++ b/tests/test_unit/test_drop_ticks.py @@ -2,11 +2,8 @@ def test_drop_ticks_generated_randomly( - rotation_ticks, start_end_times, full_length, number_of_rotations + derotation_pipeline: DerotationPipeline, ): - start, end = start_end_times - cleaned_ticks = DerotationPipeline.drop_ticks_outside_of_rotation( - rotation_ticks, start, end, full_length, number_of_rotations - ) + derotation_pipeline.drop_ticks_outside_of_rotation() - assert len(cleaned_ticks) == 362 + assert len(derotation_pipeline.rotation_ticks_peaks) == 362 diff --git a/tests/test_unit/test_finding_correct_start_end_times_with_threshold.py b/tests/test_unit/test_finding_correct_start_end_times_with_threshold.py index d5c23fc..7c1e94c 100644 --- a/tests/test_unit/test_finding_correct_start_end_times_with_threshold.py +++ b/tests/test_unit/test_finding_correct_start_end_times_with_threshold.py @@ -1,10 +1,16 @@ +import numpy as np + from derotation.analysis.derotation_pipeline import DerotationPipeline def test_finding_correct_start_end_times_with_threshold( - full_rotation, k, rotation_len, number_of_rotations + derotation_pipeline: DerotationPipeline, + full_rotation: np.ndarray, + k: int, + number_of_rotations: int, + rotation_len: int, ): - start, end = DerotationPipeline.get_start_end_times_with_threshold( + start, end = derotation_pipeline.get_start_end_times_with_threshold( full_rotation, k ) diff --git a/tests/test_unit/test_removing_brief_off_periods.py b/tests/test_unit/test_removing_brief_off_periods.py index dc23c41..62d8491 100644 --- a/tests/test_unit/test_removing_brief_off_periods.py +++ b/tests/test_unit/test_removing_brief_off_periods.py @@ -4,17 +4,15 @@ def test_removing_brief_off_periods( - start_end_times, - start_end_times_with_bug, + start_end_times: tuple, + start_end_times_with_bug: tuple, + derotation_pipeline: DerotationPipeline, ): - inter_rotation_interval_min_len = 50 start_buggy, end_buggy = start_end_times_with_bug start, end = start_end_times - corrected = DerotationPipeline.correct_start_and_end_rotation_signal( - inter_rotation_interval_min_len, - start_buggy, - end_buggy, + corrected = derotation_pipeline.correct_start_and_end_rotation_signal( + start_buggy, end_buggy ) assert len(corrected["start"]) == len(corrected["end"]) @@ -25,5 +23,5 @@ def test_removing_brief_off_periods( assert np.any( corrected["end"] - corrected["start"] - >= inter_rotation_interval_min_len + >= derotation_pipeline.inter_rotation_interval_min_len )