-
Notifications
You must be signed in to change notification settings - Fork 0
/
sentimentmodel.py
84 lines (73 loc) · 2.88 KB
/
sentimentmodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
TrainingArguments,
Trainer,
)
from typing import List
import torch
import re
class SentimentModel:
    """German sentiment classifier wrapping a Hugging Face sequence classifier.

    By default loads ``oliverguhr/german-sentiment-bert``. Input texts are
    normalised with :meth:`clean_text` before tokenisation, and predictions
    are returned as label strings from the model config's ``id2label``.
    """

    # Digit -> German number word. Each word carries a leading space so that
    # "10" becomes " eins null" rather than "einsnull". None of the
    # replacements contain digits, so a single translate pass is equivalent
    # to chaining str.replace once per digit.
    _DIGIT_WORDS = str.maketrans(
        {
            "0": " null",
            "1": " eins",
            "2": " zwei",
            "3": " drei",
            "4": " vier",
            "5": " fünf",
            "6": " sechs",
            "7": " sieben",
            "8": " acht",
            "9": " neun",
        }
    )

    def __init__(self, model_name: str = "oliverguhr/german-sentiment-bert"):
        """Load the model and tokenizer for *model_name*.

        The model is moved to ``"cuda"`` when available, otherwise ``"cpu"``;
        the chosen device string is stored in ``self.device``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model = self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Keep only ASCII letters, German umlauts / eszett and plain spaces.
        self.clean_chars = re.compile(r"[^A-Za-züöäÖÜÄß ]", re.MULTILINE)
        # Removes "http" (optionally followed by "s") and everything up to the
        # next whitespace, i.e. URLs -- but also any other token containing
        # "http".
        self.clean_http_urls = re.compile(r"https*\S+", re.MULTILINE)
        # Removes "@" and everything up to the next whitespace (mentions).
        self.clean_at_mentions = re.compile(r"@\S+", re.MULTILINE)

    def transfer_learning(
        self, train_dataset, eval_dataset, output_dir: str = "transfer_learning"
    ):
        """Fine-tune the wrapped model with a default Hugging Face ``Trainer``.

        Args:
            train_dataset: Passed to ``Trainer`` as ``train_dataset``;
                presumably tokenized examples with labels -- TODO confirm.
            eval_dataset: Passed to ``Trainer`` as ``eval_dataset``.
            output_dir: Directory handed to ``TrainingArguments`` for
                checkpoints/logs. Defaults to the previously hard-coded
                ``"transfer_learning"``.

        Returns:
            Whatever ``trainer.train()`` returns.
        """
        training_args = TrainingArguments(output_dir)
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )
        return trainer.train()

    def predict_sentiment(self, texts: List[str]) -> List[str]:
        """Classify each text and return its sentiment label.

        Args:
            texts: Raw input texts; each is normalised with ``clean_text``.

        Returns:
            One label per input text, in input order, looked up in
            ``self.model.config.id2label``. An empty input yields ``[]``.
        """
        if not texts:
            # Nothing to tokenize; avoids errors on an empty batch.
            return []
        cleaned = [self.clean_text(text) for text in texts]
        # add_special_tokens inserts [CLS]/[SEP] (or the model's equivalents);
        # truncation caps each sequence at the model's maximum input length;
        # padding makes the batch rectangular.
        encoded = self.tokenizer(
            cleaned,
            padding=True,
            add_special_tokens=True,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(self.device)
        # Without the mask the model attends to [PAD] tokens, making a text's
        # prediction depend on the length of other texts in the same batch.
        attention_mask = encoded["attention_mask"].to(self.device)
        # Disable dropout: training (e.g. via transfer_learning / Trainer) may
        # leave the model in training mode.
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # outputs[0] holds the logits; pick the highest-scoring class per text.
        label_ids = torch.argmax(outputs[0], dim=1)
        return [
            self.model.config.id2label[label_id] for label_id in label_ids.tolist()
        ]

    def replace_numbers(self, text: str) -> str:
        """Spell out every ASCII digit as a German number word.

        Each word is preceded by a space, e.g. ``"a10b"`` -> ``"a eins nullb"``.
        """
        return text.translate(self._DIGIT_WORDS)

    def clean_text(self, text: str) -> str:
        """Normalise *text* for the model.

        Steps: collapse all whitespace to single spaces, drop URLs and
        @-mentions, spell out digits, keep only letters (incl. umlauts/ß) and
        spaces, collapse whitespace again and lowercase.
        """
        # Normalise every whitespace character (newline, tab, \r, NBSP, ...)
        # to a space first; otherwise clean_chars would delete e.g. a tab and
        # glue the neighbouring words together.
        text = " ".join(text.split())
        text = self.clean_http_urls.sub("", text)
        text = self.clean_at_mentions.sub("", text)
        text = self.replace_numbers(text)
        text = self.clean_chars.sub("", text)  # use only text chars
        # Collapse the gaps left by removed tokens into single spaces.
        return " ".join(text.split()).lower()