Skip to content

Commit

Permalink
Reader (#25)
Browse files Browse the repository at this point in the history
* train-test-queries

* update files and add license

* add gold passages

* update gold passage and readme

* changes

* add bemba gold passages

* update readme

* additional gold passages

* add swahili test gold span
  • Loading branch information
ToluClassics authored Jul 8, 2023
1 parent 8a2486b commit 1a98d76
Show file tree
Hide file tree
Showing 9 changed files with 327 additions and 87 deletions.
19 changes: 15 additions & 4 deletions baselines/reader/train_seq_2_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,15 @@ def preprocess_squad_batch(
def generate_input(_question, _context):
    """Format a question/context pair into the seq2seq reader input string.

    Leading whitespace is stripped from both fields; the result has the
    form ``question: <question> context: <context>``.
    """
    parts = ["question:", _question.lstrip(), "context:", _context.lstrip()]
    return " ".join(parts)

inputs = [generate_input(question, context) for question, context in zip(questions, contexts)]
inputs = []
for question, context in zip(questions, contexts):
try:
a = generate_input(question, context)
inputs.append(a)
except Exception as e:
continue


targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets

Expand Down Expand Up @@ -601,8 +609,11 @@ def post_processing_function(
# Let's loop over all the examples!
for example_index, example in enumerate(examples):
# This is the index of the feature associated to the current example.
feature_index = feature_per_example[example_index]
predictions[example["id"]] = decoded_preds[feature_index]
try:
feature_index = feature_per_example[example_index]
predictions[example["id"]] = decoded_preds[feature_index]
except KeyError:
continue

# Format the result to the format the metric expects.
if data_args.version_2_with_negative:
Expand All @@ -612,7 +623,7 @@ def post_processing_function(
else:
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]

references = [{"id": ex["id"], "answers": ex[answer_column]} for ex in examples]
references = [{"id": ex["id"], "answers": ex[answer_column]} for ex in examples if ex["context"] is not None]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)

# Initialize our Trainer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,3 @@
{"answer_pivot": {"answer_start": [-1], "text": ["yes"]}, "context": "The military history of South Africa chronicles a vast time period and complex events from the dawn of history until the present time. It covers civil wars and wars of aggression and of self-defence both within South Africa and against it. It includes the history of battles fought in the territories of modern South Africa in neighbouring territories, in both world wars and in modern international conflicts. Prehistory.", "id": "232", "question_lang": "Bushe icalo ca South afrika calilwapo inkondo?", "question_translated": "Has South africa ever fought a war?", "title": "Military history of South Africa", "answer_lang": "Emukwai"}
{"answer_pivot": {"answer_start": [241], "text": ["1964"]}, "context": "Lusaka lost some of its status to Salisbury (now Harare in Zimbabwe) when the latter became the capital of the merged Federation of Rhodesia and Nyasaland in 1953, but regained it when it was named the capital of newly independent Zambia in 1964.", "id": "233", "question_lang": "Nililali Lusaka yasalilwe ukuba umusumba uukalamba umwikala kateka?", "question_translated": "In which year was Lusaka chosen as capital of zambia?", "title": "Lusaka", "answer_lang": "1964"}
{"answer_pivot": {"answer_start": [0], "text": ["Southern Province"]}, "context": "Southern Province and Eastern Province are the two primary breadbaskets of Zambia. Southern Province produces more than 600,000 metric tons of maize each year from a combination of commercial, which are unique to Southern Province, and smallholder farms. Despite poor rains in recent years and a strong El Nino weather cycle in 2016, Zambian maize output has been predicted to continue to grow.", "id": "234", "question_lang": "Bushe citungu nshi mu Zambia eko balima imbuto ya mataba iyingi?", "question_translated": "Which province in Zambia is the main producer of maize?", "title": " Southern Province,", "answer_lang": "Southern Province"}

Large diffs are not rendered by default.

295 changes: 295 additions & 0 deletions data/gold_passages/swa/gold_span_passages.afriqa.swa.en.test.json

Large diffs are not rendered by default.

Large diffs are not rendered by default.

49 changes: 0 additions & 49 deletions scripts/generate_translation_gold_span_file.py

This file was deleted.

26 changes: 14 additions & 12 deletions scripts/reader_generative_qa.sh
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
# Positional CLI arguments:
#   $1 - translation type, used to select which evaluation files to load
#   $2 - directory containing the QA data files
#   $3 - model name or local checkpoint path for the reader
translation_type=$1
data_file_path=$2
model_name_or_path=$3

#==================================================================================================#
# Generative QA inference
#==================================================================================================#

for lang in bem fon hau ibo kin swa twi wol yor zul
# #==================================================================================================#
# # Generative QA inference
# #==================================================================================================#

for lang in bem yor zul hau ibo kin twi swa fon
do
for split in test
do

model_name_or_path=Atnafu/generative_reader_nq_squad_v2
validation_file=${data_file_path}/${split}.${lang}.${translation_type}.json
model_name_or_path=$model_name_or_path
# validation_file=${data_file_path}/${split}.${lang}.${translation_type}.json
validation_file=data/gold_passages/${lang}/gold_span_passages.afriqa.${lang}.en.${split}.json
output_dir=models
batch_size=8
num_train_epochs=10
Expand Down Expand Up @@ -39,18 +42,17 @@ do
done
done

#==================================================================================================#
# Multilingual Generative QA using a finetuned mt5-base with the in-language queries
#==================================================================================================#
# #==================================================================================================#
# # Multilingual Generative QA using a finetuned mt5-base with the in-language queries
# #==================================================================================================#


for lang in bem fon hau ibo kin swa twi wol yor zul
for lang in bem yor zul hau ibo kin twi swa fon
do
for split in test
do

model_name_or_path=Atnafu/generative_reader_nq_squad_v2
validation_file=${data_file_path}/${split}.${lang}.${translation_type}.json
validation_file=data/gold_passages/${lang}/gold_span_passages.afriqa.${lang}.en.${split}.json
output_dir=models
batch_size=8
num_train_epochs=10
Expand Down
3 changes: 2 additions & 1 deletion scripts/train_reader_extractive.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@

# Positional CLI arguments:
#   $1 - translation type, used to build the validation file name
#   $2 - directory containing the QA data files
#   $3 - model name or local checkpoint path for the extractive reader
translation_type=$1
data_file_path=$2
model_name_or_path=$3

for lang in kin
do
for split in test
do
model_name_or_path=ToluClassics/extractive_reader_afroxlmr_squad_v2
model_name_or_path=$model_name_or_path
validation_file=${data_file_path}/${split}.${lang}.${translation_type}.json
output_dir=models
batch_size=16
Expand Down
8 changes: 1 addition & 7 deletions scripts/train_reader_generative.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@

model_name_or_path=Atnafu/mt5-base-squad2-fin
# dataset_name=Tevatron/wikipedia-nq+
# Run configuration for the generative reader training script.
model_name_or_path=$1   # HF model id or local checkpoint path (positional arg)
dataset_name=squad_v2
output_dir=models
# No whitespace around '=': the original 'batch_size= 32' assigned an empty
# string to batch_size and then tried to execute '32' as a command.
batch_size=32
Expand All @@ -13,21 +11,17 @@ CUDA_VISIBLE_DEVICES=4 python3 baselines/reader/train_seq_2_seq.py \
--dataset_name $dataset_name \
--do_train \
--do_eval \
--per_device_train_batch_size $batch_size \
--per_device_eval_batch_size $batch_size \
--learning_rate 3e-5 \
--num_train_epochs $num_train_epochs \
--max_seq_length $max_seq_length \
--doc_stride 128 \
--output_dir $output_dir/$model_name_or_path \
--save_steps $save_steps \
--overwrite_output_dir \
--push_to_hub \
--context_column context \
--predict_with_generate True \
--question_column question \
--answer_column answers \
--push_to_hub_model_id=extractive_reader_nq_squad_v2 \
--weight_decay 0.01 \
--eval_steps 1000 \
--logging_steps 1000 \
Expand Down

0 comments on commit 1a98d76

Please sign in to comment.