Skip to content

Commit

Permalink
check correctness of get_preferred_songs function
Browse files Browse the repository at this point in the history
  • Loading branch information
BeritJanssen committed Dec 20, 2023
1 parent 4da380d commit 8098d97
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
8 changes: 4 additions & 4 deletions backend/experiment/rules/musical_preferences.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,11 +284,11 @@ def feedback_info(self):
return info

def get_preferred_songs(self, result_set, n=5):
top_songs = result_set.values('section').annotate(
avg_score=Avg('score')).order_by('-score')[:n]
top_results = result_set.annotate(
avg_score=Avg('score')).order_by('score')[:n]
out_list = []
for s in top_songs:
section = Section.objects.get(pk=s.get('section'))
for result in top_results.all():
section = Section.objects.get(pk=result.section.id)
out_list.append({'artist': section.song.artist,
'name': section.song.name})
return out_list
46 changes: 46 additions & 0 deletions backend/experiment/rules/tests/test_musical_preferences.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from django.test import TestCase
from django.db.models import Avg

from experiment.rules.musical_preferences import MusicalPreferences

from experiment.models import Experiment
from participant.models import Participant
from result.models import Result
from section.models import Playlist
from session.models import Session

class MusicalPreferencesTest(TestCase):
fixtures = ['playlist', 'experiment']

@classmethod
def setUpTestData(cls):
cls.participant = Participant.objects.create()
cls.playlist = Playlist.objects.create(name='MusicalPrefences')
csv = ("SuperArtist,SuperSong,0.0,10.0,bat/artist1.mp3,0,0,0\n"
"SuperArtist,MehSong,0.0,10.0,bat/artist2.mp3,0,0,0\n"
"MehArtist,MehSong,0.0,10.0,bat/artist3.mp3,0,0,0\n"
"AwfulArtist,MehSong,0.0,10.0,bat/artist4.mp3,0,0,0\n"
"AwfulArtist,AwfulSong,0.0,10.0,bat/artist5.mp3,0,0,0\n")
cls.playlist.csv = csv
cls.playlist.update_sections()
cls.experiment = Experiment.objects.create(name='MusicalPreferences', rounds=5)
cls.session = Session.objects.create(
experiment=cls.experiment,
participant=cls.participant,
playlist=cls.playlist
)

def test_preferred_songs(self):
for index, section in enumerate(list(self.playlist.section_set.all())):
Result.objects.create(
question_key='like_song',
score=5-index,
section=section,
session=self.session
)
mp = MusicalPreferences()
preferred_sections = mp.get_preferred_songs(self.session.result_set, 3)
assert preferred_sections[0]['artist'] == 'SuperArtist'
assert preferred_sections[1]['name'] == 'MehSong'
assert preferred_sections[2]['artist'] == 'MehArtist'
assert 'AwfulArtist' not in [p['artist'] for p in preferred_sections]

0 comments on commit 8098d97

Please sign in to comment.