Skip to content

Commit

Permalink
refactor: Add rotating patterns to musicality battery experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
drikusroor committed Apr 12, 2024
1 parent 0a483f8 commit e0ed746
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 92 deletions.
70 changes: 28 additions & 42 deletions backend/experiment/rules/congosamediff.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

import re
import itertools
import math
import string
from django.utils.translation import gettext_lazy as _
from experiment.actions.utils import final_action_with_optional_button
Expand Down Expand Up @@ -36,7 +36,7 @@ def first_round(self, experiment: Experiment):

# 2. Explainer
explainer = Explainer(
instruction='Welcome to this Same Diff experiment',
instruction='Welcome to this Musicality Battery experiment',
steps=[],
)

Expand Down Expand Up @@ -111,14 +111,14 @@ def next_round(self, session: Session):
# patterns amount is the number of groups times the number of variants in each group
groups_amount = session.playlist.section_set.values('group').distinct().count()
variants_amount = real_trial_variants.count()
patterns = self.get_patterns(groups_amount, variants_amount)

# get the participant's group variant
participant_id = session.participant.participant_id_url
participant_group_variant = self.get_participant_group_variant(
participant_id,
int(participant_id),
group_number,
patterns
groups_amount,
variants_amount
)

# get the index of the participant's group variant in the real_trial_variants
Expand Down Expand Up @@ -201,7 +201,7 @@ def get_next_trial(
)
form = Form([question])
playback = PlayButton([section], play_once=False)
experiment_name = session.experiment.name if session.experiment else 'SameDiff Experiment'
experiment_name = session.experiment.name if session.experiment else 'Musicality Battery Experiment'
view = Trial(
playback=playback,
feedback_form=form,
Expand Down Expand Up @@ -281,50 +281,36 @@ def validate(self, experiment: Experiment):
if errors:
raise ValueError('The experiment playlist is not valid: \n- ' + '\n- '.join(errors))

def get_patterns(self, groups_amount: int, variants_amount: int) -> list:
"""
Generate patterns based on the given number of groups and variants.
def get_participant_group_variant(self, participant_id: int, group_number: int, groups_amount: int, variants_amount: int) -> str:

Args:
groups_amount (int): The number of groups.
variants_amount (int): The number of variants.
if participant_id <= 0:
raise ValueError(f"Participant id ({participant_id}) should be larger than 0")

Returns:
list: A list of all possible patterns generated using itertools.product.
For example, `[('A', 'A'), ('A', 'B'), ('B', 'A'), ('B', 'B')]`
"""
patterns = []
if group_number <= 0:
raise ValueError(f"Group number ({group_number}) should be larger than 0")

if groups_amount <= 0:
raise ValueError(f"Groups amount ({groups_amount}) should be larger than 0")

if variants_amount <= 0:
raise ValueError(f"Variants amount ({variants_amount}) should be larger than 0")

# Generate variant labels (e.g., ['A', 'B', 'C'])
variants = list(string.ascii_uppercase)[:variants_amount]

# Generate all possible patterns using itertools.product
patterns = list(itertools.product(variants, repeat=groups_amount))
total_patterns = len(variants)

return patterns
participant_index = participant_id - 1

def get_participant_group_variant(self, participant_id: int, group_number: int, patterns: list, ):
"""
Returns the variant for a participant based on their ID, patterns, and group number.
For example, if there are 2 groups and 2 variants, the patterns would be:
[('A', 'A'), ('A', 'B'), ('B', 'A'), ('B', 'B')].
The participant ID is used to select a pattern from the list.
The group number is used to select the group variant from the chosen pattern.
Let's say the participant ID is 3 and the group number is 2.
The participant ID modulo the number of patterns (3 % 4 = 3) would select the pattern ('B', 'A').
The group number (2) would then select the second variant ('A') from the pattern ('B', 'A').
Parameters:
participant_id (int): The ID of the participant, which serves as an index to choose a pattern.
group_number (int): The group number, which serves as an index for the chosen pattern.
patterns (list): A list of patterns generated using get_patterns.
Returns:
The variant for the participant.
group_index = group_number - 1

"""
# Determine if the pattern should be reversed (every 4th, 5th, 6th participant)
reversed_pattern = participant_index % (variants_amount * 2) >= variants_amount

patterns_index = int(participant_id) % len(patterns) - 1
group_index = group_number - 1
# Calculate the participant's group variant
if reversed_pattern:
variant_index = (participant_index - group_index) % total_patterns
else:
variant_index = (participant_index + group_index) % total_patterns

return patterns[patterns_index][group_index]
return variants[variant_index]
149 changes: 99 additions & 50 deletions backend/experiment/rules/tests/test_congosamediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,58 +216,107 @@ def test_get_total_trials_count(self):
# 2 + 1 + 2 = 5
assert total_trials_count == 5

def test_get_patterns(self):
congo_same_diff = CongoSameDiff()
patterns = congo_same_diff.get_patterns(3, 2)
patterns_length = len(patterns)

assert patterns_length == 8
assert patterns == [
('A', 'A', 'A'),
('A', 'A', 'B'),
('A', 'B', 'A'),
('A', 'B', 'B'),
('B', 'A', 'A'),
('B', 'A', 'B'),
('B', 'B', 'A'),
('B', 'B', 'B'),
]

def test_get_patterns_bigger(self):
congo_same_diff = CongoSameDiff()
patterns = congo_same_diff.get_patterns(4, 4)
patterns_length = len(patterns)

assert patterns_length == 256
assert patterns[0] == ('A', 'A', 'A', 'A')
assert patterns[255] == ('D', 'D', 'D', 'D')

def test_get_participant_group_variant(self):
csd = CongoSameDiff()

# Test with small number of groups and variants
self.assertEqual(csd.get_participant_group_variant(1, 1, 2, 2), 'A')
self.assertEqual(csd.get_participant_group_variant(1, 2, 2, 2), 'B')
self.assertEqual(csd.get_participant_group_variant(2, 1, 2, 2), 'B')
self.assertEqual(csd.get_participant_group_variant(2, 2, 2, 2), 'A')

# Test with more variants than groups
self.assertEqual(csd.get_participant_group_variant(1, 1, 2, 3), 'A')
self.assertEqual(csd.get_participant_group_variant(1, 2, 2, 3), 'B')
self.assertEqual(csd.get_participant_group_variant(2, 1, 2, 3), 'B')
self.assertEqual(csd.get_participant_group_variant(2, 2, 2, 3), 'C')
self.assertEqual(csd.get_participant_group_variant(3, 1, 2, 3), 'C')
self.assertEqual(csd.get_participant_group_variant(3, 2, 2, 3), 'A')

# Test for participant 1 to 6 to match the expected sequence and reverses
self.assertEqual(csd.get_participant_group_variant(1, 1, 3, 3), 'A')
self.assertEqual(csd.get_participant_group_variant(1, 2, 3, 3), 'B')
self.assertEqual(csd.get_participant_group_variant(1, 3, 3, 3), 'C')

self.assertEqual(csd.get_participant_group_variant(2, 1, 3, 3), 'B')
self.assertEqual(csd.get_participant_group_variant(2, 2, 3, 3), 'C')
self.assertEqual(csd.get_participant_group_variant(2, 3, 3, 3), 'A')

self.assertEqual(csd.get_participant_group_variant(3, 1, 3, 3), 'C')
self.assertEqual(csd.get_participant_group_variant(3, 2, 3, 3), 'A')
self.assertEqual(csd.get_participant_group_variant(3, 3, 3, 3), 'B')

self.assertEqual(csd.get_participant_group_variant(4, 1, 3, 3), 'A')
self.assertEqual(csd.get_participant_group_variant(4, 2, 3, 3), 'C')
self.assertEqual(csd.get_participant_group_variant(4, 3, 3, 3), 'B')

self.assertEqual(csd.get_participant_group_variant(5, 1, 3, 3), 'B')
self.assertEqual(csd.get_participant_group_variant(5, 2, 3, 3), 'A')
self.assertEqual(csd.get_participant_group_variant(5, 3, 3, 3), 'C')

self.assertEqual(csd.get_participant_group_variant(6, 1, 3, 3), 'C')
self.assertEqual(csd.get_participant_group_variant(6, 2, 3, 3), 'B')
self.assertEqual(csd.get_participant_group_variant(6, 3, 3, 3), 'A')

# Test for participant 7 to 12 to match the expected sequence and reverses
self.assertEqual(csd.get_participant_group_variant(7, 1, 3, 3), 'A')
self.assertEqual(csd.get_participant_group_variant(7, 2, 3, 3), 'B')
self.assertEqual(csd.get_participant_group_variant(7, 3, 3, 3), 'C')

self.assertEqual(csd.get_participant_group_variant(8, 1, 3, 3), 'B')
self.assertEqual(csd.get_participant_group_variant(8, 2, 3, 3), 'C')
self.assertEqual(csd.get_participant_group_variant(8, 3, 3, 3), 'A')

self.assertEqual(csd.get_participant_group_variant(9, 1, 3, 3), 'C')
self.assertEqual(csd.get_participant_group_variant(9, 2, 3, 3), 'A')
self.assertEqual(csd.get_participant_group_variant(9, 3, 3, 3), 'B')

self.assertEqual(csd.get_participant_group_variant(10, 1, 3, 3), 'A')
self.assertEqual(csd.get_participant_group_variant(10, 2, 3, 3), 'C')
self.assertEqual(csd.get_participant_group_variant(10, 3, 3, 3), 'B')

self.assertEqual(csd.get_participant_group_variant(11, 1, 3, 3), 'B')
self.assertEqual(csd.get_participant_group_variant(11, 2, 3, 3), 'A')
self.assertEqual(csd.get_participant_group_variant(11, 3, 3, 3), 'C')

self.assertEqual(csd.get_participant_group_variant(12, 1, 3, 3), 'C')
self.assertEqual(csd.get_participant_group_variant(12, 2, 3, 3), 'B')
self.assertEqual(csd.get_participant_group_variant(12, 3, 3, 3), 'A')

# Test with more groups than variants
self.assertEqual(csd.get_participant_group_variant(1, 1, 4, 3), 'A')
self.assertEqual(csd.get_participant_group_variant(1, 2, 4, 3), 'B')
self.assertEqual(csd.get_participant_group_variant(1, 3, 4, 3), 'C')
self.assertEqual(csd.get_participant_group_variant(1, 4, 4, 3), 'A')

self.assertEqual(csd.get_participant_group_variant(2, 1, 4, 3), 'B')
self.assertEqual(csd.get_participant_group_variant(2, 2, 4, 3), 'C')
self.assertEqual(csd.get_participant_group_variant(2, 3, 4, 3), 'A')
self.assertEqual(csd.get_participant_group_variant(2, 4, 4, 3), 'B')

self.assertEqual(csd.get_participant_group_variant(4, 1, 4, 3), 'A')
self.assertEqual(csd.get_participant_group_variant(4, 2, 4, 3), 'C')
self.assertEqual(csd.get_participant_group_variant(4, 3, 4, 3), 'B')
self.assertEqual(csd.get_participant_group_variant(4, 4, 4, 3), 'A')

self.assertEqual(csd.get_participant_group_variant(7, 1, 4, 3), 'A')
self.assertEqual(csd.get_participant_group_variant(7, 2, 4, 3), 'B')
self.assertEqual(csd.get_participant_group_variant(7, 3, 4, 3), 'C')
self.assertEqual(csd.get_participant_group_variant(7, 4, 4, 3), 'A')

def test_edge_cases(self):
congo_same_diff = CongoSameDiff()

patterns = [('A', 'A'), ('A', 'B'), ('B', 'A'), ('B', 'B')]

# Test participant ID 1 and group number 1
variant = congo_same_diff.get_participant_group_variant(1, 1, patterns)
assert variant == 'A'
# Test edge cases
self.assertEqual(congo_same_diff.get_participant_group_variant(1, 4, 4, 3), 'A') # Group number exceeds variants
self.assertEqual(congo_same_diff.get_participant_group_variant(12, 1, 2, 3), 'C') # Reversed, with fewer groups than variants

# Test participant ID 1 and group number 2
variant = congo_same_diff.get_participant_group_variant(1, 2, patterns)
assert variant == 'A'

# Test participant ID 6 and group number 1
variant = congo_same_diff.get_participant_group_variant(6, 1, patterns)
assert variant == 'A'

# Test participant ID 6 and group number 2
variant = congo_same_diff.get_participant_group_variant(6, 2, patterns)
assert variant == 'B'

# Test participant ID 7 and group number 1
variant = congo_same_diff.get_participant_group_variant(7, 1, patterns)
assert variant == 'B'
def test_invalid_parameters(self):
congo_same_diff = CongoSameDiff()

# Test participant ID 7 and group number 2
variant = congo_same_diff.get_participant_group_variant(7, 2, patterns)
assert variant == 'A'

# Test scenarios with invalid parameters (should raise exceptions or handle gracefully)
with self.assertRaises(ValueError): # Assuming your method raises ValueError for invalid inputs
congo_same_diff.get_participant_group_variant(-1, 1, 3, 3) # Negative participant ID
congo_same_diff.get_participant_group_variant(1, -1, 3, 3) # Negative group number
congo_same_diff.get_participant_group_variant(1, 1, -1, 3) # Negative groups amount
congo_same_diff.get_participant_group_variant(1, 1, 3, -1) # Negative variants amount

0 comments on commit e0ed746

Please sign in to comment.