diff --git a/parrot/parrot.py b/parrot/parrot.py index f8ec738..e232f56 100644 --- a/parrot/parrot.py +++ b/parrot/parrot.py @@ -1,6 +1,12 @@ class Parrot(): - def __init__(self, model_tag="prithivida/parrot_paraphraser_on_T5", use_gpu=False): + def __init__( + self, + model_tag="prithivida/parrot_paraphraser_on_T5", + adequacy_model="prithivida/parrot_adequacy_model", + fluency_model="prithivida/parrot_fluency_model", + diversity_model="paraphrase-distilroberta-base-v2", + use_gpu=False): from transformers import AutoTokenizer from transformers import AutoModelForSeq2SeqLM import pandas as pd @@ -9,9 +15,9 @@ def __init__(self, model_tag="prithivida/parrot_paraphraser_on_T5", use_gpu=Fals from parrot.filters import Diversity self.tokenizer = AutoTokenizer.from_pretrained(model_tag) self.model = AutoModelForSeq2SeqLM.from_pretrained(model_tag) - self.adequacy_score = Adequacy() - self.fluency_score = Fluency() - self.diversity_score= Diversity() + self.adequacy_score = Adequacy(model_tag=adequacy_model) + self.fluency_score = Fluency(model_tag=fluency_model) + self.diversity_score= Diversity(model_tag=diversity_model) def rephrase(self, input_phrase, use_gpu=False, diversity_ranker="levenshtein", do_diverse=False, style=1, max_length=32, adequacy_threshold = 0.90, fluency_threshold = 0.90): if use_gpu: