Skip to content

Commit

Permalink
Added: Implement factorial design based on groups with numbers and variants with uppercase letters (#940)
Browse files (browse the repository at this point in the history)

* feat: Implement factorial design based on groups with numbers and variants with uppercase letters

* refactor: Update error message in CongoSameDiff class

* feat: Refactor CongoSameDiff class to ensure consistent number of variants in each group
  • Loading branch information
drikusroor authored Apr 11, 2024
1 parent bbdd593 commit 24fb854
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 16 deletions.
101 changes: 93 additions & 8 deletions backend/experiment/rules/congosamediff.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@

import random
import re
import itertools
import string
from django.utils.translation import gettext_lazy as _
from experiment.actions.final import Final
from experiment.models import Experiment
Expand All @@ -13,19 +15,19 @@


class CongoSameDiff(Base):
''' A micro-PROMS inspired experiment that tests the participant's ability to distinguish between different sounds. '''
""" A micro-PROMS inspired experiment that tests the participant's ability to distinguish between different sounds. """
ID = 'CONGOSAMEDIFF'
contact_email = 'aml.tunetwins@gmail.com'

def __init__(self):
    # The constructor performs no work: no attributes are set here.
    # NOTE(review): super().__init__() is not called — confirm that Base
    # needs no initialisation of its own.
    pass

def first_round(self, experiment: Experiment):
''' Provide the first rounds of the experiment,
""" Provide the first rounds of the experiment,
before session creation
The first_round must return at least one Info or Explainer action
Consent and Playlist are often desired, but optional
'''
"""

# Do a validity check on the experiment
self.validate(experiment)
Expand Down Expand Up @@ -100,23 +102,40 @@ def next_round(self, session: Session):
# group number of the trial to be played
group_number = next_round_number - practice_trials_count - 1

# load the non-practice trial variants for the group number
# load the non-practice group variants for the group number
real_trial_variants = session.playlist.section_set.exclude(
tag__contains='practice'
).filter(
group=group_number
)

# pick a variant from the variants randomly (#919)
variants_count = real_trial_variants.count()
random_variants_index = random.randint(0, variants_count - 1)
# the number of patterns is the number of variants raised to the power of the number of groups (variants_amount ** groups_amount)
groups_amount = session.playlist.section_set.values('group').distinct().count()
variants_amount = real_trial_variants.count()
patterns = self.get_patterns(groups_amount, variants_amount)

# get the participant's group variant
participant_id = session.participant.participant_id_url
participant_group_variant = self.get_participant_group_variant(
participant_id,
group_number,
patterns
)

# get the index of the participant's group variant in the real_trial_variants
# aka the index of the variant whose tag matches the participant's group variant
real_trial_variants_list = list(real_trial_variants)
pattern_group_variants_index = [
i for i, variant in enumerate(real_trial_variants_list)
if variant.tag == participant_group_variant
][0]

# if the next_round_number is greater than the no. of practice trials,
# return a non-practice trial
return self.get_next_trial(
session,
real_trial_variants,
random_variants_index + 1,
pattern_group_variants_index + 1,
False
)

Expand Down Expand Up @@ -242,5 +261,71 @@ def validate(self, experiment: Experiment):
if not sections.exclude(tag__contains='practice').exists():
errors.append('At least one section should not have the tag "practice"')

# Every non-practice group should have the same number of variants
# that should be labeled with a single uppercase letter
groups = sections.values('group').distinct()
variants = sections.exclude(tag__contains='practice').values('tag')
unique_variants = set([variant['tag'] for variant in variants])
variants_count = len(unique_variants)
for group in groups:
group_variants = sections.filter(group=group['group']).exclude(tag__contains='practice').values('tag').distinct()

for variant in group_variants:
if not re.search(r'^[A-Z]$', variant['tag']):
errors.append(f'Group {group["group"]} should have variants with a single uppercase letter (A-Z), but has {variant["tag"]}')

if group_variants.count() != variants_count:
group_variants_stringified = ', '.join([variant['tag'] for variant in group_variants])
total_variants_stringified = ', '.join(unique_variants)
errors.append(f'Group {group["group"]} should have the same number of variants as the total amount of variants ({variants_count}; {total_variants_stringified}) but has {group_variants.count()} ({group_variants_stringified})')

if errors:
raise ValueError('The experiment playlist is not valid: \n- ' + '\n- '.join(errors))

def get_patterns(self, groups_amount: int, variants_amount: int) -> list:
    """
    Build every way of assigning one variant label to each group.

    Variant labels are the first `variants_amount` uppercase ASCII letters
    ('A', 'B', ...; at most 26 are available). Each pattern is a tuple with
    one label per group, and the patterns are listed in the order
    itertools.product yields them (the last position varies fastest).

    Args:
        groups_amount (int): The number of groups, i.e. the length of each pattern.
        variants_amount (int): The number of variant labels to choose from.

    Returns:
        list: Every pattern as a tuple of labels.
        For example, 2 groups and 2 variants give
        `[('A', 'A'), ('A', 'B'), ('B', 'A'), ('B', 'B')]`
    """
    # A slice of the alphabet is enough: iterating a string yields its letters
    labels = string.ascii_uppercase[:variants_amount]

    # The cartesian product of the labels with itself, once per group
    return [*itertools.product(labels, repeat=groups_amount)]

def get_participant_group_variant(self, participant_id: int, group_number: int, patterns: list, ):
    """
    Look up the variant a participant gets for a given group.

    Both the participant ID and the group number are treated as 1-based
    positions. The participant ID (converted with int()) picks a pattern,
    wrapping around when it exceeds the number of patterns; the group number
    then picks one entry of that pattern.

    For example, with 2 groups and 2 variants the patterns are:
    [('A', 'A'), ('A', 'B'), ('B', 'A'), ('B', 'B')].
    Participant ID 3 selects the third pattern, ('B', 'A'), and group
    number 2 selects its second entry, 'A'.
    Participant IDs 4 and 8 both select the fourth pattern, ('B', 'B'),
    and participant ID 5 wraps around to the first pattern again.

    Parameters:
        participant_id (int): The ID of the participant, used as a 1-based, wrapping position in patterns.
        group_number (int): The group number, used as a 1-based position in the chosen pattern.
        patterns (list): A list of patterns generated using get_patterns.

    Returns:
        The variant for the participant.
    """
    numeric_id = int(participant_id)

    # Shift to a 0-based position before wrapping, so an ID that is an exact
    # multiple of the pattern count lands on the last pattern
    pattern_position = (numeric_id - 1) % len(patterns)
    chosen_pattern = patterns[pattern_position]

    return chosen_pattern[group_number - 1]
72 changes: 64 additions & 8 deletions backend/experiment/rules/tests/test_congosamediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ def setUpTestData(self):
self.section_csv = (
"Dave,m1_contour_practice,0.0,20.0,samediff/melody_1_contour.wav,practice,1\n"
"Dave,m2_same_practice,0.0,20.0,samediff/melody_1_same.wav,practice,1\n"
"Dave,m1_same,0.0,20.0,samediff/melody_1_same.wav,'',1\n"
"Dave,m1_scale,0.0,20.0,samediff/melody_1_scale.wav,'',1\n"
"Dave,m1_contour,0.0,20.0,samediff/melody_1_contour.wav,'',1\n"
"Dave,m1_interval,0.0,20.0,samediff/melody_1_interval.wav,'',1\n"
"Dave,m1_same,0.0,20.0,samediff/melody_1_same.wav,'',2\n"
"Dave,m1_scale,0.0,20.0,samediff/melody_1_scale.wav,'',2\n"
"Dave,m1_contour,0.0,20.0,samediff/melody_1_contour.wav,'',2\n"
"Dave,m1_interval,0.0,20.0,samediff/melody_1_interval.wav,'',2\n"
"Dave,m1_same,0.0,20.0,samediff/melody_1_same.wav,A,1\n"
"Dave,m1_scale,0.0,20.0,samediff/melody_1_scale.wav,B,1\n"
"Dave,m1_contour,0.0,20.0,samediff/melody_1_contour.wav,C,1\n"
"Dave,m1_interval,0.0,20.0,samediff/melody_1_interval.wav,D,1\n"
"Dave,m1_same,0.0,20.0,samediff/melody_1_same.wav,A,2\n"
"Dave,m1_scale,0.0,20.0,samediff/melody_1_scale.wav,B,2\n"
"Dave,m1_contour,0.0,20.0,samediff/melody_1_contour.wav,C,2\n"
"Dave,m1_interval,0.0,20.0,samediff/melody_1_interval.wav,D,2\n"
)
self.playlist = PlaylistModel.objects.create(name='CongoSameDiff')
self.playlist.csv = self.section_csv
Expand Down Expand Up @@ -215,3 +215,59 @@ def test_get_total_trials_count(self):
# practice trials + post-practice question + non-practice trials
# 2 + 1 + 2 = 5
assert total_trials_count == 5

def test_get_patterns(self):
    """get_patterns(3, 2) returns all 8 three-letter patterns over 'A' and 'B', in product order."""
    expected_patterns = [
        ('A', 'A', 'A'),
        ('A', 'A', 'B'),
        ('A', 'B', 'A'),
        ('A', 'B', 'B'),
        ('B', 'A', 'A'),
        ('B', 'A', 'B'),
        ('B', 'B', 'A'),
        ('B', 'B', 'B'),
    ]

    actual_patterns = CongoSameDiff().get_patterns(3, 2)

    # Check the size first so a wrong count is reported before a long list diff
    assert len(actual_patterns) == 8
    assert actual_patterns == expected_patterns

def test_get_patterns_bigger(self):
    """get_patterns(4, 4) returns 256 patterns, running from all 'A' to all 'D'."""
    rule = CongoSameDiff()
    all_patterns = rule.get_patterns(4, 4)

    assert len(all_patterns) == 256

    # Only the two ends of the list are spot-checked
    first_pattern = all_patterns[0]
    assert first_pattern == ('A', 'A', 'A', 'A')
    last_pattern = all_patterns[255]
    assert last_pattern == ('D', 'D', 'D', 'D')

def test_get_participant_group_variant(self):
    """Participant IDs select a pattern (wrapping past the end); group numbers select an entry of it."""
    rule = CongoSameDiff()
    patterns = [('A', 'A'), ('A', 'B'), ('B', 'A'), ('B', 'B')]

    # (participant ID, group number, expected variant)
    cases = [
        (1, 1, 'A'),
        (1, 2, 'A'),
        (6, 1, 'A'),
        (6, 2, 'B'),
        (7, 1, 'B'),
        (7, 2, 'A'),
    ]

    for participant_id, group_number, expected_variant in cases:
        variant = rule.get_participant_group_variant(participant_id, group_number, patterns)
        assert variant == expected_variant

0 comments on commit 24fb854

Please sign in to comment.