Skip to content

Commit

Permalink
Added: Pick a random variant per trial for the CongoSameDiff Experime…
Browse files Browse the repository at this point in the history
…nt (#920)

* refactor: Improve getting practice trials count

* feat: Pick a random variant per trial based on the group number

* refactor: Refactor CongoSameDiff validation logic
  • Loading branch information
drikusroor authored Apr 5, 2024
1 parent e41bf1e commit 349903a
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 34 deletions.
89 changes: 62 additions & 27 deletions backend/experiment/rules/congosamediff.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@

import random
import re
from django.utils.translation import gettext_lazy as _
from experiment.actions.final import Final
from experiment.models import Experiment
Expand Down Expand Up @@ -26,20 +28,7 @@ def first_round(self, experiment: Experiment):
'''

# Do a validity check on the experiment
# All sections need to have a group value
sections = experiment.playlists.first().section_set.all()
for section in sections:
if not section.group:
file_name = section.song.name if section.song else 'No name'
raise ValueError(f'Section {file_name} should have a group value')

# It also needs at least one section with the tag 'practice'
if not sections.filter(tag__contains='practice').exists():
raise ValueError('At least one section should have the tag "practice"')

# It should also contain at least one section without the tag 'practice'
if not sections.exclude(tag__contains='practice').exists():
raise ValueError('At least one section should not have the tag "practice"')
self.validate(experiment)

# 1. Playlist
playlist = Playlist(experiment.playlists.all())
Expand All @@ -59,8 +48,8 @@ def next_round(self, session: Session):

next_round_number = session.get_current_round()

# total number of trials
total_trials_count = session.playlist.section_set.count() + 1 # +1 for the post-practice round
# practice trials + post-practice question + non-practice trials
total_trials_count = self.get_total_trials_count(session)

practice_done = session.result_set.filter(
question_key='practice_done',
Expand All @@ -71,16 +60,14 @@ def next_round(self, session: Session):
if next_round_number > total_trials_count:
return self.get_final_round(session)

# count of practice rounds (excluding the post-practice round)
practice_trials_count = session.playlist.section_set.filter(
tag__contains='practice'
).count()

# load the practice trials
practice_trials_subset = session.playlist.section_set.filter(
tag__contains='practice'
)

# count of practice rounds (excluding the post-practice round)
practice_trials_count = practice_trials_subset.count()

# if the user hasn't completed the practice trials
# return the next practice trial
if next_round_number <= practice_trials_count:
Expand All @@ -99,7 +86,7 @@ def next_round(self, session: Session):
return self.get_next_trial(
session,
practice_trials_subset,
1,
1, # first practice trial
True
)

Expand All @@ -110,17 +97,26 @@ def next_round(self, session: Session):
if next_round_number == practice_trials_count + 1 and not practice_done:
return self.get_practice_done_view(session)

# load the non-practice trials
real_trials_subset = session.playlist.section_set.exclude(
# group number of the trial to be played
group_number = next_round_number - practice_trials_count - 1

# load the non-practice trial variants for the group number
real_trial_variants = session.playlist.section_set.exclude(
tag__contains='practice'
).filter(
group=group_number
)

# pick a variant from the variants randomly (#919)
variants_count = real_trial_variants.count()
random_variants_index = random.randint(0, variants_count - 1)

# if the next_round_number is greater than the no. of practice trials,
# return a non-practice trial
return self.get_next_trial(
session,
real_trials_subset,
next_round_number - practice_trials_count - 1,
real_trial_variants,
random_variants_index + 1,
False
)

Expand Down Expand Up @@ -167,7 +163,7 @@ def get_next_trial(
section_group = section.group if section.group else 'no_group'

# define a key, by which responses to this trial can be found in the database
key = f'samediff_trial_{section_group}'
key = f'samediff_trial_{section_group}_{section_name}'

question = ChoiceQuestion(
explainer=f'{practice_label} ({trial_index}/{subset_count}) | {section_name} | {section_tag} | {section_group}',
Expand Down Expand Up @@ -209,3 +205,42 @@ def get_final_round(self, session: Session):
session=session,
final_text=_('Thank you for participating!'),
)

def get_total_trials_count(self, session: Session):
practice_trials_subset = session.playlist.section_set.filter(
tag__contains='practice'
)
practice_trials_count = practice_trials_subset.count()
total_exp_variants = session.playlist.section_set.exclude(
tag__contains='practice'
)
total_unique_exp_trials_count = total_exp_variants.values('group').distinct().count()
total_trials_count = practice_trials_count + total_unique_exp_trials_count + 1
return total_trials_count

def validate(self, experiment: Experiment):

errors = []

# All sections need to have a group value
sections = experiment.playlists.first().section_set.all()
for section in sections:
file_name = section.song.name if section.song else 'No name'
# every section.group should consist of a number
regex_pattern = r'^\d+$'
if not section.group or not re.search(regex_pattern, section.group):
errors.append(f'Section {file_name} should have a group value containing only digits')
# the section song name should not be empty
if not section.song.name:
errors.append(f'Section {file_name} should have a name that will be used for the result key')

# It also needs at least one section with the tag 'practice'
if not sections.filter(tag__contains='practice').exists():
errors.append('At least one section should have the tag "practice"')

# It should also contain at least one section without the tag 'practice'
if not sections.exclude(tag__contains='practice').exists():
errors.append('At least one section should not have the tag "practice"')

if errors:
raise ValueError('The experiment playlist is not valid: \n- ' + '\n- '.join(errors))
78 changes: 71 additions & 7 deletions backend/experiment/rules/tests/test_congosamediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,16 @@ class CongoSameDiffTest(TestCase):
@classmethod
def setUpTestData(self):
self.section_csv = (
"Dave,m1_contour,0.0,20.0,samediff/melody_1_contour.wav,practice_contour,m1\n"
"Dave,m1_interval,0.0,20.0,samediff/melody_1_interval.wav,practice_interval,m1\n"
"Dave,m1_same,0.0,20.0,samediff/melody_1_same.wav,same,m1\n"
"Dave,m1_scale,0.0,20.0,samediff/melody_1_scale.wav,scale,m1\n"
"Dave,m1_contour_practice,0.0,20.0,samediff/melody_1_contour.wav,practice,1\n"
"Dave,m2_same_practice,0.0,20.0,samediff/melody_1_same.wav,practice,1\n"
"Dave,m1_same,0.0,20.0,samediff/melody_1_same.wav,'',1\n"
"Dave,m1_scale,0.0,20.0,samediff/melody_1_scale.wav,'',1\n"
"Dave,m1_contour,0.0,20.0,samediff/melody_1_contour.wav,'',1\n"
"Dave,m1_interval,0.0,20.0,samediff/melody_1_interval.wav,'',1\n"
"Dave,m1_same,0.0,20.0,samediff/melody_1_same.wav,'',2\n"
"Dave,m1_scale,0.0,20.0,samediff/melody_1_scale.wav,'',2\n"
"Dave,m1_contour,0.0,20.0,samediff/melody_1_contour.wav,'',2\n"
"Dave,m1_interval,0.0,20.0,samediff/melody_1_interval.wav,'',2\n"
)
self.playlist = PlaylistModel.objects.create(name='CongoSameDiff')
self.playlist.csv = self.section_csv
Expand Down Expand Up @@ -61,7 +67,7 @@ def test_next_round_final_round(self):
)

self.session.get_current_round = lambda: 6

final_action = congo_same_diff.next_round(self.session)

assert isinstance(final_action, Final)
Expand Down Expand Up @@ -118,6 +124,23 @@ def test_throw_exception_if_trial_without_group(self):
with self.assertRaisesRegex(ValueError, "Section no_group should have a group value"):
congo_same_diff.first_round(experiment)

def test_throw_exception_if_trial_group_not_int(self):
congo_same_diff = CongoSameDiff()
experiment = Experiment(id=1, name='CongoSameDiff', slug='congosamediff_first_round', rounds=4)
experiment.save()
playlist = PlaylistModel.objects.create(name='CongoSameDiff')
Section.objects.create(
playlist=playlist,
start_time=0.0,
duration=20.0,
song=Song.objects.create(artist='group_not_int', name='group_not_int'),
tag='practice_contour',
group='not_int_42'
)
experiment.playlists.set([playlist])
with self.assertRaisesRegex(ValueError, "Section group_not_int should have a group value containing only digits"):
congo_same_diff.first_round(experiment)

def test_throw_exception_if_no_practice_rounds(self):
congo_same_diff = CongoSameDiff()
experiment = Experiment(id=1, name='CongoSameDiff', slug='congosamediff_first_round', rounds=4)
Expand All @@ -129,7 +152,7 @@ def test_throw_exception_if_no_practice_rounds(self):
duration=20.0,
song=Song.objects.create(artist='no_practice', name='no_practice'),
tag='',
group='m1'
group='1'
)
experiment.playlists.set([playlist])
with self.assertRaisesRegex(ValueError, 'At least one section should have the tag "practice"'):
Expand All @@ -146,8 +169,49 @@ def test_throw_exception_if_no_normal_rounds(self):
duration=20.0,
song=Song.objects.create(artist='only_practice', name='only_practice'),
tag='practice_contour',
group='m1'
group='42'
)
experiment.playlists.set([playlist])
with self.assertRaisesRegex(ValueError, 'At least one section should not have the tag "practice"'):
congo_same_diff.first_round(experiment)

def test_throw_combined_exceptions_if_multiple_errors(self):
congo_same_diff = CongoSameDiff()
experiment = Experiment(id=1, name='CongoSameDiff', slug='congosamediff_first_round', rounds=4)
experiment.save()
playlist = PlaylistModel.objects.create(name='CongoSameDiff')
Section.objects.create(
playlist=playlist,
start_time=0.0,
duration=20.0,
song=Song.objects.create(artist='no_group', name='no_group'),
tag='practice_contour',
group=''
)
Section.objects.create(
playlist=playlist,
start_time=0.0,
duration=20.0,
song=Song.objects.create(artist='group_not_int', name='group_not_int'),
tag='practice_contour',
group='not_int_42'
)
Section.objects.create(
playlist=playlist,
start_time=0.0,
duration=20.0,
song=Song.objects.create(artist='only_practice', name='only_practice'),
tag='practice_contour',
group='42'
)
experiment.playlists.set([playlist])
with self.assertRaisesRegex(ValueError, "The experiment playlist is not valid: \n- Section group_not_int should have a group value containing only digits\n- Section no_group should have a group value containing only digits\n- At least one section should not have the tag \"practice\""):
congo_same_diff.first_round(experiment)

def test_get_total_trials_count(self):
congo_same_diff = CongoSameDiff()
total_trials_count = congo_same_diff.get_total_trials_count(self.session)

# practice trials + post-practice question + non-practice trials
# 2 + 1 + 2 = 5
assert total_trials_count == 5

0 comments on commit 349903a

Please sign in to comment.