diff --git a/DL/recommender_system/.gitignore b/DL/recommender_system/.gitignore new file mode 100644 index 0000000..5084e53 --- /dev/null +++ b/DL/recommender_system/.gitignore @@ -0,0 +1,3 @@ +*.csv +.ipynb_checkpoints +*.pth diff --git a/DL/recommender_system/nn.ipynb b/DL/recommender_system/nn.ipynb new file mode 100644 index 0000000..0c8f839 --- /dev/null +++ b/DL/recommender_system/nn.ipynb @@ -0,0 +1,1954 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "29b5381a-72b4-4f81-9208-2075f7acad85", + "metadata": {}, + "source": [ + "# Recommender System using Neural Network" + ] + }, + { + "cell_type": "markdown", + "id": "cac03d32", + "metadata": {}, + "source": [ + "Set up the project: remove any stale CSV files and extract the dataset archive, which produces the CSV files used below." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "6f480cda-8380-4355-998a-5c59d6203b05", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Archive: ./dataset/archive.zip\n", + " inflating: anime.csv \n", + " inflating: animelist.csv \n", + " inflating: anime_with_synopsis.csv \n", + " inflating: rating_complete.csv \n" + ] + } + ], + "source": [ + "! rm -rf *.csv\n", + "! 
unzip ./dataset/archive.zip\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "db1f7e48", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "total 2.7G\n", + "-rw-r--r-- 1 andre andre 5.5M Jul 13 2021 anime.csv\n", + "-rw-r--r-- 1 andre andre 6.9M Jul 13 2021 anime_with_synopsis.csv\n", + "-rw-r--r-- 1 andre andre 1.9G Jul 13 2021 animelist.csv\n", + "-rw-r--r-- 1 andre andre 781M Jul 13 2021 rating_complete.csv\n", + "drwxr-xr-x 6 andre andre 4.0K Sep 14 12:48 ..\n", + "drwxr-xr-x 2 andre andre 4.0K Sep 25 18:58 dataset\n", + "-rw-r--r-- 1 andre andre 198K Sep 26 15:26 best_anime_model.pth\n", + "-rw-r--r-- 1 andre andre 11K Sep 27 00:50 nn.py\n", + "-rw-r--r-- 1 andre andre 195K Sep 27 01:09 best_model.pth\n", + "drwxr-xr-x 3 andre andre 4.0K Sep 27 01:11 .\n", + "-rw-r--r-- 1 andre andre 193K Sep 27 01:11 nn.ipynb\n" + ] + } + ], + "source": [ + "! ls -ltrha" + ] + }, + { + "cell_type": "markdown", + "id": "52ec2f48", + "metadata": {}, + "source": [ + "Import needed libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dd17f780", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import pandas as pd\n", + "import numpy as np\n", + "from sklearn.preprocessing import StandardScaler\n", + "from sklearn.model_selection import train_test_split\n", + "from torch.utils.data import TensorDataset, DataLoader\n", + "import matplotlib.pyplot as plt\n", + "from tqdm import tqdm\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "aec5c99e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cpu\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/andre/miniconda3/envs/venv/lib/python3.12/site-packages/torch/cuda/__init__.py:128: UserWarning: CUDA 
# %% Select the compute device and make the run reproducible
# One seed for every stochastic step driven by numpy or torch in this notebook.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')


# %% Load the anime tables
def load_shuffled_csv(path, seed=SEED):
    """Read a CSV file and return its rows in a seeded random order.

    Shuffling removes any ordering present in the file; it does not change
    the distribution of the values. The index is rebuilt as 0..n-1 so that
    positional and label-based access agree afterwards.

    Args:
        path: Location of the CSV file, as accepted by ``pd.read_csv``.
        seed: ``random_state`` passed to ``DataFrame.sample``.

    Returns:
        A new DataFrame containing every row of the file, shuffled.
    """
    return (
        pd.read_csv(path)
        .sample(frac=1.0, random_state=seed)
        .reset_index(drop=True)
    )


# The two tables are shuffled independently (they have different lengths), so
# their row order does not line up; they are matched by MAL_ID, never by position.
anime_df = load_shuffled_csv("anime.csv")
anime_synopsis_df = load_shuffled_csv("anime_with_synopsis.csv")
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MAL_IDNameScoreGenresEnglish nameJapanese nameTypeEpisodesAiredPremiered...Score-10Score-9Score-8Score-7Score-6Score-5Score-4Score-3Score-2Score-1
040176Miru Tights: Cosplay Satsuei Tights6.53Ecchi, SchoolUnknownみるタイツ コスプレ撮影 タイツSpecial1Aug 23, 2019Unknown...875.0350.0762.01526.01542.0924.0384.0245.0162.0148.0
113969Thermae Romae x Yoyogi Animation Gakuin Collab...6.29Comedy, Historical, SeinenUnknownテルマエ・ロマエx代々木アニメーション学院企業コラボレーションSpecial1Jul 9, 2012Unknown...35.047.0114.0253.0240.0162.063.029.010.010.0
213459Ribbon-chanUnknownComedyUnknownリボンちゃんTV24Apr 4, 2012 to Mar 27, 2013Spring 2012...7.0UnknownUnknown2.02.04.01.0Unknown2.07.0
315617Jinrui wa Suitai Shimashita Specials7.23Comedy, Fantasy, SeinenHumanity Has Declined Specials人類は衰退しましたSpecial6Sep 19, 2012 to Feb 20, 2013Unknown...451.0885.02432.03038.01388.0588.0130.038.022.019.0
419157Youkai Watch6.54Comedy, Demons, Kids, SupernaturalYo-kai Watch妖怪ウォッチTV214Jan 8, 2014 to Mar 30, 2018Winter 2014...517.0532.01141.01912.01636.01196.0500.0228.0138.0125.0
..................................................................
1755732238Watashi wa, Kairaku Izonshou6.2HentaiUnknown私は、快楽依存症OVA2Feb 26, 2016 to May 20, 2016Unknown...117.096.0197.0329.0363.0216.0110.049.052.052.0
1755833552Mameshiba Bangai-hen5.75Music, ComedyUnknown豆しば番外編Special52008 to Jun 20, 2019Unknown...8.02.06.08.024.047.011.06.02.012.0
175598476Otome Youkai Zakuro7.47Demons, Historical, Military, Romance, Seinen,...Zakuroおとめ妖怪 ざくろTV13Oct 5, 2010 to Dec 28, 2010Fall 2010...3237.05815.012079.012757.05674.02383.0688.0234.099.081.0
17560953Jyu Oh Sei7.26Action, Sci-Fi, Adventure, Mystery, Drama, ShoujoJyu-Oh-Sei:Planet of the Beast King獣王星TV11Apr 14, 2006 to Jun 23, 2006Spring 2006...2193.03886.07188.08062.04360.02140.0934.0302.0172.0148.0
1756139769Kimi ni Sekai6.7Sci-Fi, Music, FantasyUnknown君に世界Music1Apr 20, 2019Unknown...48.042.072.0181.0134.064.017.013.07.04.0
\n", + "

17562 rows × 35 columns

\n", + "
" + ], + "text/plain": [ + " MAL_ID Name Score \\\n", + "0 40176 Miru Tights: Cosplay Satsuei Tights 6.53 \n", + "1 13969 Thermae Romae x Yoyogi Animation Gakuin Collab... 6.29 \n", + "2 13459 Ribbon-chan Unknown \n", + "3 15617 Jinrui wa Suitai Shimashita Specials 7.23 \n", + "4 19157 Youkai Watch 6.54 \n", + "... ... ... ... \n", + "17557 32238 Watashi wa, Kairaku Izonshou 6.2 \n", + "17558 33552 Mameshiba Bangai-hen 5.75 \n", + "17559 8476 Otome Youkai Zakuro 7.47 \n", + "17560 953 Jyu Oh Sei 7.26 \n", + "17561 39769 Kimi ni Sekai 6.7 \n", + "\n", + " Genres \\\n", + "0 Ecchi, School \n", + "1 Comedy, Historical, Seinen \n", + "2 Comedy \n", + "3 Comedy, Fantasy, Seinen \n", + "4 Comedy, Demons, Kids, Supernatural \n", + "... ... \n", + "17557 Hentai \n", + "17558 Music, Comedy \n", + "17559 Demons, Historical, Military, Romance, Seinen,... \n", + "17560 Action, Sci-Fi, Adventure, Mystery, Drama, Shoujo \n", + "17561 Sci-Fi, Music, Fantasy \n", + "\n", + " English name Japanese name \\\n", + "0 Unknown みるタイツ コスプレ撮影 タイツ \n", + "1 Unknown テルマエ・ロマエx代々木アニメーション学院企業コラボレーション \n", + "2 Unknown リボンちゃん \n", + "3 Humanity Has Declined Specials 人類は衰退しました \n", + "4 Yo-kai Watch 妖怪ウォッチ \n", + "... ... ... \n", + "17557 Unknown 私は、快楽依存症 \n", + "17558 Unknown 豆しば番外編 \n", + "17559 Zakuro おとめ妖怪 ざくろ \n", + "17560 Jyu-Oh-Sei:Planet of the Beast King 獣王星 \n", + "17561 Unknown 君に世界 \n", + "\n", + " Type Episodes Aired Premiered ... \\\n", + "0 Special 1 Aug 23, 2019 Unknown ... \n", + "1 Special 1 Jul 9, 2012 Unknown ... \n", + "2 TV 24 Apr 4, 2012 to Mar 27, 2013 Spring 2012 ... \n", + "3 Special 6 Sep 19, 2012 to Feb 20, 2013 Unknown ... \n", + "4 TV 214 Jan 8, 2014 to Mar 30, 2018 Winter 2014 ... \n", + "... ... ... ... ... ... \n", + "17557 OVA 2 Feb 26, 2016 to May 20, 2016 Unknown ... \n", + "17558 Special 5 2008 to Jun 20, 2019 Unknown ... \n", + "17559 TV 13 Oct 5, 2010 to Dec 28, 2010 Fall 2010 ... \n", + "17560 TV 11 Apr 14, 2006 to Jun 23, 2006 Spring 2006 ... 
\n", + "17561 Music 1 Apr 20, 2019 Unknown ... \n", + "\n", + " Score-10 Score-9 Score-8 Score-7 Score-6 Score-5 Score-4 Score-3 \\\n", + "0 875.0 350.0 762.0 1526.0 1542.0 924.0 384.0 245.0 \n", + "1 35.0 47.0 114.0 253.0 240.0 162.0 63.0 29.0 \n", + "2 7.0 Unknown Unknown 2.0 2.0 4.0 1.0 Unknown \n", + "3 451.0 885.0 2432.0 3038.0 1388.0 588.0 130.0 38.0 \n", + "4 517.0 532.0 1141.0 1912.0 1636.0 1196.0 500.0 228.0 \n", + "... ... ... ... ... ... ... ... ... \n", + "17557 117.0 96.0 197.0 329.0 363.0 216.0 110.0 49.0 \n", + "17558 8.0 2.0 6.0 8.0 24.0 47.0 11.0 6.0 \n", + "17559 3237.0 5815.0 12079.0 12757.0 5674.0 2383.0 688.0 234.0 \n", + "17560 2193.0 3886.0 7188.0 8062.0 4360.0 2140.0 934.0 302.0 \n", + "17561 48.0 42.0 72.0 181.0 134.0 64.0 17.0 13.0 \n", + "\n", + " Score-2 Score-1 \n", + "0 162.0 148.0 \n", + "1 10.0 10.0 \n", + "2 2.0 7.0 \n", + "3 22.0 19.0 \n", + "4 138.0 125.0 \n", + "... ... ... \n", + "17557 52.0 52.0 \n", + "17558 2.0 12.0 \n", + "17559 99.0 81.0 \n", + "17560 172.0 148.0 \n", + "17561 7.0 4.0 \n", + "\n", + "[17562 rows x 35 columns]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "anime_df" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "ac373668", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MAL_IDNameScoreGenressypnopsis
01220Hoshizora Kiseki5.8Romance, Sci-FiKozue is a girl who loves astronomy, particula...
138407Ishii Hiroyuki x Saitou Souma Essay-shuuUnknownSpaceadvertisement for a photo and essay collection...
22705Bakusou Kyoudai Let's & Go6.78Adventure, Cars, Sports, ShounenBased on the manga by Tetsuhiro Koshita, Bakus...
318829Hello Kitty no Shiawase no Aoi HotaruUnknownKidsp to camping with Kitty and her class.
49014Kuttsukiboshi6.12Romance, Supernatural, Drama, Shoujo AiTo Kiiko Kawakami, there was nothing in the wo...
..................
1620938009Re:Stage! Dream Days♪6.69Music, School, Slice of Lifeana Shikimiya has just transferred into Mareho...
1621010348Fireball Charming6.8Sci-Fi3D computer animation about a female robot duc...
16211979Street Fighter Zero The Animation6.55Action, Drama, Martial Arts, Shounen, Super Poweru, the current Street Fighter champion, must o...
1621242826Seijo no Maryoku wa Bannou DesuUnknownRomance, FantasySei, a 20-year-old office worker, is whisked a...
1621320047Sakura Trick7.02Slice of Life, Comedy, Romance, School, Seinen...Having been best friends since middle school, ...
\n", + "

16214 rows × 5 columns

\n", + "
" + ], + "text/plain": [ + " MAL_ID Name Score \\\n", + "0 1220 Hoshizora Kiseki 5.8 \n", + "1 38407 Ishii Hiroyuki x Saitou Souma Essay-shuu Unknown \n", + "2 2705 Bakusou Kyoudai Let's & Go 6.78 \n", + "3 18829 Hello Kitty no Shiawase no Aoi Hotaru Unknown \n", + "4 9014 Kuttsukiboshi 6.12 \n", + "... ... ... ... \n", + "16209 38009 Re:Stage! Dream Days♪ 6.69 \n", + "16210 10348 Fireball Charming 6.8 \n", + "16211 979 Street Fighter Zero The Animation 6.55 \n", + "16212 42826 Seijo no Maryoku wa Bannou Desu Unknown \n", + "16213 20047 Sakura Trick 7.02 \n", + "\n", + " Genres \\\n", + "0 Romance, Sci-Fi \n", + "1 Space \n", + "2 Adventure, Cars, Sports, Shounen \n", + "3 Kids \n", + "4 Romance, Supernatural, Drama, Shoujo Ai \n", + "... ... \n", + "16209 Music, School, Slice of Life \n", + "16210 Sci-Fi \n", + "16211 Action, Drama, Martial Arts, Shounen, Super Power \n", + "16212 Romance, Fantasy \n", + "16213 Slice of Life, Comedy, Romance, School, Seinen... \n", + "\n", + " sypnopsis \n", + "0 Kozue is a girl who loves astronomy, particula... \n", + "1 advertisement for a photo and essay collection... \n", + "2 Based on the manga by Tetsuhiro Koshita, Bakus... \n", + "3 p to camping with Kitty and her class. \n", + "4 To Kiiko Kawakami, there was nothing in the wo... \n", + "... ... \n", + "16209 ana Shikimiya has just transferred into Mareho... \n", + "16210 3D computer animation about a female robot duc... \n", + "16211 u, the current Street Fighter champion, must o... \n", + "16212 Sei, a 20-year-old office worker, is whisked a... \n", + "16213 Having been best friends since middle school, ... 
# %% Define the neural network
class AnimeRecommendationNN(nn.Module):
    """Fully connected regressor mapping a feature vector to one scalar.

    The network stacks four hidden blocks of
    ``Linear -> BatchNorm1d -> LeakyReLU -> Dropout`` with widths
    256, 128, 64 and 32, followed by a single-unit linear head.

    Attribute names (``fc1``..``fc5``, ``batch_norm1``..``batch_norm4``) are
    part of the ``state_dict`` keys, so they must stay stable for saved
    checkpoints to remain loadable.

    Args:
        input_size: Number of features in each input row.
        dropout: Drop probability applied after every hidden activation.
        negative_slope: LeakyReLU slope for negative inputs.
    """

    def __init__(self, input_size, dropout=0.3, negative_slope=0.01):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 32)
        self.fc5 = nn.Linear(32, 1)
        # Kept under the name ``relu`` for compatibility, although it is a LeakyReLU.
        self.relu = nn.LeakyReLU(negative_slope=negative_slope)
        self.dropout = nn.Dropout(dropout)
        self.batch_norm1 = nn.BatchNorm1d(256)
        self.batch_norm2 = nn.BatchNorm1d(128)
        self.batch_norm3 = nn.BatchNorm1d(64)
        self.batch_norm4 = nn.BatchNorm1d(32)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Float tensor of shape ``(batch, input_size)``.

        Returns:
            Tensor of shape ``(batch, 1)``. Callers comparing it against 1-D
            targets should align the shapes first, otherwise the loss may
            broadcast to ``(batch, batch)``.
        """
        hidden_blocks = (
            (self.fc1, self.batch_norm1),
            (self.fc2, self.batch_norm2),
            (self.fc3, self.batch_norm3),
            (self.fc4, self.batch_norm4),
        )
        for linear, norm in hidden_blocks:
            x = self.dropout(self.relu(norm(linear(x))))
        return self.fc5(x)


# %% Feature engineering helpers
def clean_numeric_column(series, missing_markers=('Unknown', '')):
    """Convert a column of mixed text/number values to a numeric Series.

    Every value is first rendered as text. Values equal to one of
    ``missing_markers`` become NaN, and anything else that cannot be parsed
    as a number is coerced to NaN as well.

    Args:
        series: Input pandas Series of any dtype.
        missing_markers: Exact strings that denote a missing value. Useful
            for sentinels that would otherwise parse as numbers (e.g. ``'-1'``).

    Returns:
        A numeric Series aligned with ``series``; unparseable or missing
        entries are NaN.
    """
    as_text = series.astype(str)
    # ``mask`` blanks the marker entries without going through ``replace``.
    as_text = as_text.mask(as_text.isin(list(missing_markers)))
    return pd.to_numeric(as_text, errors='coerce')
# %% Feature engineering
def engineer_features(df):
    """Build the numeric per-anime feature table.

    Adds two derived columns and cleans the raw numeric ones:

    * ``genre_count``: number of comma-separated entries in ``Genres``
      (0 when ``Genres`` is missing).
    * ``synopsis_length``: character length of ``sypnopsis`` (the dataset's
      own spelling), 0 when the synopsis is missing.
    * ``Score``, ``Episodes``, ``Members``, ``Popularity``, ``Ranked``:
      parsed with ``clean_numeric_column``; missing values are replaced by
      the column median.

    Args:
        df: DataFrame containing at least ``MAL_ID``, ``Genres``,
            ``sypnopsis`` and the five numeric columns listed above.
            The input is not modified.

    Returns:
        A DataFrame with columns ``MAL_ID``, ``Score``, ``Episodes``,
        ``Members``, ``Popularity``, ``Ranked``, ``genre_count``,
        ``synopsis_length`` (in that order).
    """
    df = df.copy()
    numeric_columns = ['Score', 'Episodes', 'Members', 'Popularity', 'Ranked']

    # A missing Genres value would otherwise stay NaN and flow into the model.
    # NOTE(review): a literal "Unknown" genre string still counts as 1 — confirm
    # whether the dataset uses that marker in this column.
    df['genre_count'] = (df['Genres'].str.count(',') + 1).fillna(0)
    df['synopsis_length'] = df['sypnopsis'].str.len().fillna(0)

    for col in numeric_columns:
        cleaned = clean_numeric_column(df[col])
        # Median imputation, assigned back rather than done in place.
        df[col] = cleaned.fillna(cleaned.median())

    return df[['MAL_ID', *numeric_columns, 'genre_count', 'synopsis_length']]


# %% Build the feature table
# Left join keeps every anime; titles without a synopsis get NaN, which
# engineer_features turns into synopsis_length 0.
anime_features = engineer_features(
    pd.merge(anime_df, anime_synopsis_df[['MAL_ID', 'sypnopsis']], on='MAL_ID', how='left')
)

# %% Sanity check: everything except the id must be numeric
non_numeric_check = anime_features.drop('MAL_ID', axis=1).select_dtypes(exclude=[np.number])
if not non_numeric_check.empty:
    print("Warning: Non-numeric data found in columns:", non_numeric_check.columns)
    print("Sample of non-numeric data:")
    print(non_numeric_check.head())
    raise ValueError("Please check your data preprocessing steps.")

# %% Split, then normalize features
# The target must not be part of the inputs: previously 'Score' was both a
# column of X and the value of y, so the network could simply copy it.
# NOTE(review): 'Ranked' may itself be derived from the score on the source
# site — confirm, and drop it from the inputs too if so.
TARGET_COLUMN = 'Score'
feature_columns = [
    col for col in anime_features.columns if col not in ('MAL_ID', TARGET_COLUMN)
]
features_raw = anime_features[feature_columns].to_numpy(dtype=np.float64)
y = anime_features[TARGET_COLUMN].values

X_train_raw, X_val_raw, y_train, y_val = train_test_split(
    features_raw, y, test_size=0.2, random_state=42
)

# Fit the scaler on the training rows only so that validation statistics do
# not leak into preprocessing; reuse the fitted scaler everywhere else.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_val = scaler.transform(X_val_raw)

# Scaled features for every anime, in the row order of anime_features.
anime_features_scaled = scaler.transform(features_raw)
X = anime_features_scaled

# %% Wrap the arrays in datasets and loaders
# Targets are 1-D tensors here while the model emits (batch, 1); the training
# loop has to align the two shapes before computing the loss.
train_dataset = TensorDataset(torch.FloatTensor(X_train), torch.FloatTensor(y_train))
val_dataset = TensorDataset(torch.FloatTensor(X_val), torch.FloatTensor(y_val))

batch_size = 512
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# %% Training setup
# input_size follows the feature matrix, so it shrinks by one now that the
# target is excluded; checkpoints saved from the older layout will not load.
input_size = X_train.shape[1]
model = AnimeRecommendationNN(input_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
# Halve the learning rate after 5 epochs without improvement in the monitored
# metric, down to a floor of 1e-6.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5, min_lr=1e-6)
criterion = nn.MSELoss()

# %% Training-loop state
num_epochs = 100
patience = 15  # early-stopping budget, in epochs without improvement
best_val_loss = float('inf')
no_improve = 0

train_losses = []
val_losses = []
"stderr", + "output_type": "stream", + "text": [ + "Epochs: 4%|▍ | 4/100 [00:00<00:14, 6.71it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [3/100], Train Loss: 22.8870, Val Loss: 20.3565\n", + "Epoch [4/100], Train Loss: 18.0055, Val Loss: 15.7408\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 6%|▌ | 6/100 [00:00<00:12, 7.81it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [5/100], Train Loss: 13.4245, Val Loss: 11.4652\n", + "Epoch [6/100], Train Loss: 9.5039, Val Loss: 7.3833\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 8%|▊ | 8/100 [00:01<00:12, 7.19it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [7/100], Train Loss: 6.0797, Val Loss: 4.0210\n", + "Epoch [8/100], Train Loss: 3.4364, Val Loss: 1.7301\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 10%|█ | 10/100 [00:01<00:11, 7.98it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [9/100], Train Loss: 1.8735, Val Loss: 0.3324\n", + "Epoch [10/100], Train Loss: 1.3259, Val Loss: 0.0896\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 12%|█▏ | 12/100 [00:01<00:10, 8.41it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [11/100], Train Loss: 1.1793, Val Loss: 0.0788\n", + "Epoch [12/100], Train Loss: 1.1421, Val Loss: 0.0493\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 14%|█▍ | 14/100 [00:01<00:09, 8.67it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [13/100], Train Loss: 1.0728, Val Loss: 0.0493\n", + "Epoch [14/100], Train Loss: 1.0055, Val Loss: 0.0373\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 15%|█▌ | 15/100 [00:02<00:09, 8.75it/s]" + ] + }, + { + 
"name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [15/100], Train Loss: 0.9540, Val Loss: 0.0340\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 17%|█▋ | 17/100 [00:02<00:11, 7.42it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [16/100], Train Loss: 0.9143, Val Loss: 0.0470\n", + "Epoch [17/100], Train Loss: 0.8634, Val Loss: 0.0281\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 19%|█▉ | 19/100 [00:02<00:10, 7.40it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [18/100], Train Loss: 0.8607, Val Loss: 0.0362\n", + "Epoch [19/100], Train Loss: 0.8161, Val Loss: 0.0189\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 21%|██ | 21/100 [00:02<00:09, 7.97it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [20/100], Train Loss: 0.7852, Val Loss: 0.0166\n", + "Epoch [21/100], Train Loss: 0.7763, Val Loss: 0.0257\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 23%|██▎ | 23/100 [00:03<00:08, 8.60it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [22/100], Train Loss: 0.7352, Val Loss: 0.0183\n", + "Epoch [23/100], Train Loss: 0.7398, Val Loss: 0.0281\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 25%|██▌ | 25/100 [00:03<00:08, 9.11it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [24/100], Train Loss: 0.7096, Val Loss: 0.0287\n", + "Epoch [25/100], Train Loss: 0.7122, Val Loss: 0.0163\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 27%|██▋ | 27/100 [00:03<00:08, 8.19it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [26/100], Train Loss: 0.6918, Val Loss: 0.0232\n", + "Epoch [27/100], Train Loss: 0.6863, Val Loss: 
0.0228\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 29%|██▉ | 29/100 [00:03<00:08, 8.61it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [28/100], Train Loss: 0.6635, Val Loss: 0.0170\n", + "Epoch [29/100], Train Loss: 0.6679, Val Loss: 0.0171\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 30%|███ | 30/100 [00:03<00:08, 8.75it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [30/100], Train Loss: 0.6446, Val Loss: 0.0272\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 32%|███▏ | 32/100 [00:04<00:09, 7.44it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [31/100], Train Loss: 0.6492, Val Loss: 0.0337\n", + "Epoch [32/100], Train Loss: 0.6390, Val Loss: 0.0221\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 34%|███▍ | 34/100 [00:04<00:08, 8.14it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [33/100], Train Loss: 0.6460, Val Loss: 0.0324\n", + "Epoch [34/100], Train Loss: 0.6318, Val Loss: 0.0180\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 36%|███▌ | 36/100 [00:04<00:07, 8.85it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [35/100], Train Loss: 0.6359, Val Loss: 0.0235\n", + "Epoch [36/100], Train Loss: 0.6204, Val Loss: 0.0253\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 38%|███▊ | 38/100 [00:04<00:06, 9.30it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [37/100], Train Loss: 0.6167, Val Loss: 0.0258\n", + "Epoch [38/100], Train Loss: 0.6130, Val Loss: 0.0227\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 40%|████ | 40/100 [00:05<00:06, 9.36it/s]" + ] + }, + { + "name": 
"stdout", + "output_type": "stream", + "text": [ + "Epoch [39/100], Train Loss: 0.6214, Val Loss: 0.0161\n", + "Epoch [40/100], Train Loss: 0.6209, Val Loss: 0.0237\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 42%|████▏ | 42/100 [00:05<00:07, 8.24it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [41/100], Train Loss: 0.6044, Val Loss: 0.0238\n", + "Epoch [42/100], Train Loss: 0.6016, Val Loss: 0.0171\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 44%|████▍ | 44/100 [00:05<00:06, 8.61it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [43/100], Train Loss: 0.6160, Val Loss: 0.0218\n", + "Epoch [44/100], Train Loss: 0.6252, Val Loss: 0.0178\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 45%|████▌ | 45/100 [00:05<00:06, 8.85it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [45/100], Train Loss: 0.6161, Val Loss: 0.0247\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 47%|████▋ | 47/100 [00:05<00:07, 7.50it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [46/100], Train Loss: 0.5884, Val Loss: 0.0257\n", + "Epoch [47/100], Train Loss: 0.6001, Val Loss: 0.0157\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 49%|████▉ | 49/100 [00:06<00:06, 8.01it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [48/100], Train Loss: 0.6033, Val Loss: 0.0279\n", + "Epoch [49/100], Train Loss: 0.6040, Val Loss: 0.0180\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 51%|█████ | 51/100 [00:06<00:05, 8.17it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [50/100], Train Loss: 0.6089, Val Loss: 0.0159\n", + "Epoch [51/100], Train Loss: 0.5913, Val 
Loss: 0.0169\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 53%|█████▎ | 53/100 [00:06<00:05, 8.48it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [52/100], Train Loss: 0.6020, Val Loss: 0.0197\n", + "Epoch [53/100], Train Loss: 0.5963, Val Loss: 0.0158\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 55%|█████▌ | 55/100 [00:06<00:05, 8.78it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [54/100], Train Loss: 0.5993, Val Loss: 0.0205\n", + "Epoch [55/100], Train Loss: 0.5912, Val Loss: 0.0193\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 57%|█████▋ | 57/100 [00:07<00:04, 9.02it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [56/100], Train Loss: 0.5841, Val Loss: 0.0203\n", + "Epoch [57/100], Train Loss: 0.5937, Val Loss: 0.0202\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 59%|█████▉ | 59/100 [00:07<00:04, 9.17it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [58/100], Train Loss: 0.6112, Val Loss: 0.0150\n", + "Epoch [59/100], Train Loss: 0.5767, Val Loss: 0.0172\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 60%|██████ | 60/100 [00:07<00:05, 7.97it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [60/100], Train Loss: 0.5896, Val Loss: 0.0179\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 62%|██████▏ | 62/100 [00:07<00:05, 7.05it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [61/100], Train Loss: 0.5894, Val Loss: 0.0130\n", + "Epoch [62/100], Train Loss: 0.5865, Val Loss: 0.0131\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 64%|██████▍ | 64/100 [00:08<00:04, 8.05it/s]" 
+ ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [63/100], Train Loss: 0.5908, Val Loss: 0.0180\n", + "Epoch [64/100], Train Loss: 0.5949, Val Loss: 0.0164\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 66%|██████▌ | 66/100 [00:08<00:03, 8.66it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [65/100], Train Loss: 0.5909, Val Loss: 0.0197\n", + "Epoch [66/100], Train Loss: 0.5723, Val Loss: 0.0174\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 68%|██████▊ | 68/100 [00:08<00:03, 8.79it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [67/100], Train Loss: 0.5814, Val Loss: 0.0222\n", + "Epoch [68/100], Train Loss: 0.5995, Val Loss: 0.0202\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 70%|███████ | 70/100 [00:08<00:03, 8.57it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [69/100], Train Loss: 0.5859, Val Loss: 0.0246\n", + "Epoch [70/100], Train Loss: 0.5812, Val Loss: 0.0132\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 72%|███████▏ | 72/100 [00:08<00:03, 8.19it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [71/100], Train Loss: 0.5821, Val Loss: 0.0183\n", + "Epoch [72/100], Train Loss: 0.5841, Val Loss: 0.0206\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 74%|███████▍ | 74/100 [00:09<00:02, 8.78it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [73/100], Train Loss: 0.5871, Val Loss: 0.0192\n", + "Epoch [74/100], Train Loss: 0.5782, Val Loss: 0.0155\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epochs: 75%|███████▌ | 75/100 [00:09<00:03, 8.00it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + 
"Epoch [75/100], Train Loss: 0.5961, Val Loss: 0.0175\n", + "Epoch [76/100], Train Loss: 0.5817, Val Loss: 0.0208\n", + "Early stopping triggered after 76 epochs\n", + "Training completed.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "/tmp/ipykernel_307086/1420535760.py:64: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
# %% Cell 20 - train with gradient clipping, LR scheduling and early stopping
# Uses objects created in earlier cells: model, criterion, optimizer, scheduler,
# train_loader, val_loader, device, num_epochs, patience, best_val_loss,
# no_improve, train_losses and val_losses.
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    # ---- Training phase ----
    model.train()
    total_train_loss = 0.0
    train_batch_count = 0

    for batch_X, batch_y in train_loader:
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)

        optimizer.zero_grad()
        outputs = model(batch_X)
        # squeeze(-1) removes only the trailing output dimension; a bare
        # squeeze() would also drop the batch dimension of a one-sample batch.
        # Assumes the model returns (batch, 1) -- its last layer has out_features=1.
        loss = criterion(outputs.squeeze(-1), batch_y)
        loss.backward()

        # Clip the global gradient norm before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        total_train_loss += loss.item()
        train_batch_count += 1

    # Mean loss per batch; a single accumulator feeds both the log line and the history.
    avg_train_loss = total_train_loss / train_batch_count
    train_losses.append(avg_train_loss)

    # ---- Validation phase ----
    model.eval()
    total_val_loss = 0.0
    val_batch_count = 0

    with torch.no_grad():
        for batch_X, batch_y in val_loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            outputs = model(batch_X)
            val_loss = criterion(outputs.squeeze(-1), batch_y)
            total_val_loss += val_loss.item()
            val_batch_count += 1

    avg_val_loss = total_val_loss / val_batch_count
    val_losses.append(avg_val_loss)

    # NOTE(review): the training loss is measured in train() mode and the
    # validation loss in eval() mode, so the two curves are not directly
    # comparable if the model contains dropout / batch-norm layers.
    tqdm.write(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

    # Learning rate scheduling driven by the validation loss.
    scheduler.step(avg_val_loss)

    # Early stopping: checkpoint on improvement, otherwise count stale epochs.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        no_improve = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        no_improve += 1
        # ">=" also stops when the counter starts above `patience`
        # (for example when this cell is re-run without resetting it).
        if no_improve >= patience:
            # tqdm.write keeps the message from breaking the progress bar.
            tqdm.write(f"Early stopping triggered after {epoch+1} epochs")
            break

print("Training completed.")

# Load the best checkpoint for inference.
# weights_only=True restricts unpickling to tensors/state dicts (the checkpoint
# is a plain state_dict), and map_location makes the load independent of the
# device the checkpoint was written on.
# NOTE(review): if no epoch ever improved on best_val_loss, this loads whatever
# 'best_model.pth' was left on disk by a previous run.
model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
model.eval()


# %% Cell 21 - single-item inference helper
def predict_anime_score(anime_features_dict):
    """Predict a score for one anime described by a dict of raw feature values.

    Args:
        anime_features_dict: Mapping of feature name -> raw (unscaled) value.
            Features missing from the mapping default to 0.

    Returns:
        The model output for the single input row, as a Python float.

    Uses the notebook-level ``model``, ``scaler``, ``device`` and
    ``anime_features`` objects.
    """
    # Prefer the column names and order the scaler was fitted with, so the
    # values line up with the right mean/scale and sklearn does not warn
    # about missing feature names.
    fitted_names = getattr(scaler, 'feature_names_in_', None)
    if fitted_names is not None:
        feature_order = list(fitted_names)
    else:
        # Scaler was fitted on an unnamed array: fall back to the feature
        # table's column order, dropping the identifier column.
        feature_order = [col for col in anime_features.columns if col != 'MAL_ID']

    features_list = [anime_features_dict.get(feature, 0) for feature in feature_order]
    if fitted_names is not None:
        scaler_input = pd.DataFrame([features_list], columns=feature_order)
    else:
        scaler_input = [features_list]

    model.eval()
    with torch.no_grad():
        features = torch.FloatTensor(scaler.transform(scaler_input)).to(device)
        prediction = model(features)
    return prediction.item()
"metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Predicted score for the new anime: 1.48\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/andre/miniconda3/envs/venv/lib/python3.12/site-packages/sklearn/base.py:493: UserWarning: X does not have valid feature names, but StandardScaler was fitted with feature names\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "new_anime_features = {\n", + " 'Score': 0,\n", + " 'Episodes': 12,\n", + " 'Members': 1000000,\n", + " 'Popularity': 1000,\n", + " 'Ranked': 500,\n", + " 'genre_count': 3,\n", + " 'synopsis_length': 150\n", + "}\n", + "\n", + "predicted_score = predict_anime_score(new_anime_features)\n", + "print(f\"Predicted score for the new anime: {predicted_score:.2f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "224e985a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation MSE: 0.0130\n", + "Validation RMSE: 0.1140\n" + ] + } + ], + "source": [ + "model.eval()\n", + "val_predictions = []\n", + "val_true_scores = []\n", + "\n", + "with torch.no_grad():\n", + " for batch_X, batch_y in val_loader:\n", + " batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n", + " outputs = model(batch_X)\n", + " val_predictions.extend(outputs.squeeze().cpu().numpy())\n", + " val_true_scores.extend(batch_y.cpu().numpy())\n", + "\n", + "val_predictions = np.array(val_predictions)\n", + "val_true_scores = np.array(val_true_scores)\n", + "\n", + "# Calculate MSE and RMSE\n", + "mse = np.mean((val_predictions - val_true_scores) ** 2)\n", + "rmse = np.sqrt(mse)\n", + "\n", + "print(f\"Validation MSE: {mse:.4f}\")\n", + "print(f\"Validation RMSE: {rmse:.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "36e94621", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(10, 6))\n", + "plt.scatter(val_true_scores, val_predictions, alpha=0.5)\n", + "plt.plot([val_true_scores.min(), val_true_scores.max()], [val_true_scores.min(), val_true_scores.max()], 'r--', lw=2)\n", + "plt.xlabel(\"Actual Scores\")\n", + "plt.ylabel(\"Predicted Scores\")\n", + "plt.title(\"Predicted vs Actual Anime Scores\")\n", + "plt.show()\n", + "\n", + "plt.figure(figsize=(10, 6))\n", + "plt.plot(train_losses, label='Training Loss')\n", + "plt.plot(val_losses, label='Validation Loss')\n", + "plt.xlabel(\"Epochs\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.title(\"Training and Validation Loss\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "da17265a", + "metadata": {}, + "source": [ + "get result" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "21bdb5a9", + "metadata": {}, + "outputs": [], + "source": [ + "def get_recommendations(input_anime, top_n=10):\n", + " model.eval()\n", + " with torch.no_grad():\n", + " input_with_id = pd.merge(input_anime, anime_df[['MAL_ID', 'Name']], left_on='Title', right_on='Name', how='left')\n", + " input_features = pd.merge(input_with_id, anime_features, on='MAL_ID', how='left')\n", + " feature_columns = ['Score', 'Episodes', 'Members', 'Popularity', 'Ranked', 'genre_count', 'synopsis_length']\n", + " input_features = input_features[feature_columns]\n", + "\n", + " for col in feature_columns:\n", + " if input_features[col].isnull().any():\n", + " input_features[col].fillna(input_features[col].median(), inplace=True)\n", + " \n", + " input_scaled = scaler.transform(input_features)\n", + " input_tensor = torch.FloatTensor(input_scaled).to(device)\n", + " predicted_ratings = model(input_tensor).cpu().numpy().flatten()\n", + "\n", + " input_anime['Predicted_Rating'] = predicted_ratings\n", + " recommendations = input_anime.sort_values('Predicted_Rating', 
ascending=False).head(top_n)\n", + " return recommendations" + ] + }, + { + "cell_type": "markdown", + "id": "31322025", + "metadata": {}, + "source": [ + "User input" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "27c34fa3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Title Rating\n", + "0 Boku dake ga Inai Machi 10.0\n", + "1 Violet Evergarden 9.5\n", + "2 Goblin Slayer 6.0\n", + "3 Berserk 8.0\n", + "4 Shingeki no Kyojin 7.0\n", + "5 Tokyo Ghoul 6.5\n", + "6 Orange 6.0\n", + "7 Death Parade 8.0\n", + "8 Death Note 7.5\n", + "9 Bungou Stray Dogs 7.5\n", + "10 Tenki no Ko 8.0\n", + "11 Kimi no Na wa. 8.0\n", + "12 Kimi no Suizou wo Tabetai 8.5\n", + "13 Mononoke Hime 7.5\n", + "14 Sen to Chihiro no Kamikakushi 7.5\n", + "15 Koe no Katachi 8.5\n", + "16 Ao Haru Ride 5.5\n", + "17 Toki wo Kakeru Shoujo 7.0\n", + "18 Another 7.5\n", + "19 Kimetsu no Yaiba 7.0\n", + "20 Shigatsu wa Kimi no Uso 8.0\n", + "21 Byousoku 5 Centimeter 6.0\n", + "22 Kokoro ga Sakebitagatterunda. 
7.5\n", + "23 Schick x Evangelion 5.0\n" + ] + } + ], + "source": [ + "userInput = [\n", + " {'Title': 'Boku dake ga Inai Machi', 'Rating': 10.0},\n", + " {'Title': 'Violet Evergarden', 'Rating': 9.5},\n", + " {'Title': 'Goblin Slayer', 'Rating': 6.0},\n", + " {'Title': 'Berserk', 'Rating': 8.0},\n", + " {'Title': 'Shingeki no Kyojin', 'Rating': 7.0},\n", + " {'Title': 'Tokyo Ghoul', 'Rating': 6.5},\n", + " {'Title': 'Orange', 'Rating': 6.0},\n", + " {'Title': 'Death Parade', 'Rating': 8.0},\n", + " {'Title': 'Death Note', 'Rating': 7.5},\n", + " {'Title': 'Bungou Stray Dogs', 'Rating': 7.5},\n", + " {'Title': 'Tenki no Ko', 'Rating': 8.0},\n", + " {'Title': 'Kimi no Na wa.', 'Rating': 8.0},\n", + " {'Title': 'Kimi no Suizou wo Tabetai', 'Rating': 8.5},\n", + " {'Title': 'Mononoke Hime', 'Rating': 7.5},\n", + " {'Title': 'Sen to Chihiro no Kamikakushi', 'Rating': 7.5},\n", + " {'Title': 'Koe no Katachi', 'Rating': 8.5},\n", + " {'Title': 'Ao Haru Ride', 'Rating': 5.5},\n", + " {'Title': 'Toki wo Kakeru Shoujo', 'Rating': 7.0},\n", + " {'Title': 'Another', 'Rating': 7.5},\n", + " {'Title': 'Kimetsu no Yaiba', 'Rating': 7.0},\n", + " {'Title': 'Shigatsu wa Kimi no Uso', 'Rating': 8.0},\n", + " {'Title': 'Byousoku 5 Centimeter', 'Rating': 6.0},\n", + " {'Title': 'Kokoro ga Sakebitagatterunda.', 'Rating': 7.5},\n", + " {'Title': 'Schick x Evangelion', 'Rating': 5.0}\n", + "]\n", + "\n", + "inputAnime = pd.DataFrame(userInput)\n", + "print(inputAnime)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "e457680c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Top 10 Recommended Anime:\n", + " Title Predicted_Rating\n", + "15 Koe no Katachi 8.676968\n", + "11 Kimi no Na wa. 
8.668777\n", + "8 Death Note 8.655032\n", + "4 Shingeki no Kyojin 8.560674\n", + "14 Sen to Chihiro no Kamikakushi 8.536744\n", + "20 Shigatsu wa Kimi no Uso 8.446907\n", + "13 Mononoke Hime 8.424243\n", + "1 Violet Evergarden 8.375561\n", + "19 Kimetsu no Yaiba 8.362001\n", + "12 Kimi no Suizou wo Tabetai 8.318801\n" + ] + } + ], + "source": [ + "recommendations = get_recommendations(inputAnime)\n", + "print(\"\\nTop 10 Recommended Anime:\")\n", + "print(recommendations[['Title', 'Predicted_Rating']])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}