This is the source code for paper "CorpusBrain: Pre-train a Generative Retrieval Model for Knowledge-Intensive Language Tasks".
CorpusBrain is a pre-trained generative retrieval model that encodes all information about the corpus in its parameters, without the need for an additional index. Furthermore, CorpusBrain dramatically simplifies the search process and can be optimized in an end-to-end manner by replacing the traditional multi-step search pipeline with a single-step generative model. We show that a strong generative retrieval model can be learned with a set of adequately designed pre-training tasks, and be adopted to improve a variety of downstream retrieval tasks with further fine-tuning.
The KILT knowledge source can be downloaded here: kilt_knowledgesource.json (34.76GiB). It is based on the 2019/08/01 Wikipedia dump.
The BPE prefix tree (trie) built from KILT Wikipedia titles, based on the same 2019/08/01 Wikipedia dump, can be downloaded here: kilt_titles_trie_dict.pkl. The trie contains ~5M titles and is used to generate document identifiers for all the KILT experiments (a sketch for rebuilding such a trie from a custom title set follows the checkpoint link below).
The BART-Large checkpoint can be downloaded here: bart.large.tar.gz
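If you need a trie for a different set of titles, the sketch below rebuilds one with the GENRE Trie class. It is only a minimal illustration: the titles.txt and titles_trie_dict.pkl file names are hypothetical, and the exact encoding convention (leading space, BOS handling) should be checked against the released trie.

import pickle
from genre.fairseq_model import GENRE
from genre.trie import Trie

# Reuse the model's BPE to tokenize titles
model = GENRE.from_pretrained("models/corpus_brain").eval()

# titles.txt (hypothetical): one Wikipedia title per line
with open("titles.txt") as f:
    titles = [line.strip() for line in f if line.strip()]

# Prepend the decoder start id (2) and drop the encoder BOS token,
# following the GENRE convention for constrained decoding (an assumption here)
trie = Trie([[2] + model.encode(" {}".format(t)).tolist()[1:] for t in titles])

# Save the underlying dict so it can be reloaded with Trie.load_from_dict
with open("titles_trie_dict.pkl", "wb") as f:
    pickle.dump(trie.trie_dict, f)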
To construct the pre-training data, run the following command:
bash scripts/preprocess_corpus.sh
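The fairseq pipeline used below expects line-aligned .source/.target text files, where each source line is a (pseudo-)query built from the corpus and the corresponding target line is the title of the Wikipedia page it should retrieve. The snippet below is only a hypothetical illustration of that format; the actual pairs and file names are produced by the script above.

# Hypothetical illustration of one line-aligned training pair:
# a pseudo-query from a page as the source, the page title as the target.
pairs = [
    ("Einstein was born in Ulm, in the Kingdom of Württemberg.", "Albert Einstein"),
]

with open("train.source", "w") as src, open("train.target", "w") as tgt:
    for source, target in pairs:
        src.write(source + "\n")
        tgt.write(target + "\n")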
To tokenize and binarize the data as expected by fairseq, run the following command:
bash scripts/preprocess_fairseq.sh $DATASET_PATH $MODEL_PATH
To pre-train the model, run the following command:
bash scripts/train.sh $DATASET_PATH $NAME
CorpusBrain has the same architecture as BART-Large, so you can fine-tune CorpusBrain like BART on any downstream retrieval task by simply replacing the BART checkpoint with the CorpusBrain checkpoint.
CorpusBrain can be fine-tuned to serve a variety of downstream retrieval tasks in KILT. To download KILT data, run the following command:
mkdir data
python scripts/download_all_kilt_data.py
python scripts/get_triviaqa_input.py
To convert KILT data to fairseq format, run the following command:
python scripts/convert_kilt_to_fairseq_genre.py $input_filename $output_path
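For reference, each line of a KILT file is a JSON record whose input text becomes the fairseq source and whose gold provenance titles become the targets. The snippet below is a simplified, hypothetical sketch of that mapping (the input file name is just an example, and the actual script may handle outputs and provenance differently):

import json

# Simplified sketch of the KILT -> fairseq mapping: write one (query, title)
# pair per gold provenance title (an assumption about the conversion).
with open("nq-train-kilt.jsonl") as fin, \
        open("train.source", "w") as src, open("train.target", "w") as tgt:
    for line in fin:
        record = json.loads(line)
        query = record["input"].replace("\n", " ")
        for output in record.get("output", []):
            for provenance in output.get("provenance", []):
                title = provenance.get("wikipedia_title")
                if title:
                    src.write(query + "\n")
                    tgt.write(title + "\n")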
The remaining tokenization, binarization, and fine-tuning steps are the same as the pre-training steps above.
After importing and loading the model and a prefix tree (trie), you can generate predictions with a simple call:
import pickle
from genre.fairseq_model import GENRE
from genre.trie import Trie
# Load the prefix tree (trie)
with open("../data/kilt_titles_trie_dict.pkl", "rb") as f:
trie = Trie.load_from_dict(pickle.load(f))
# Load the model
model = GENRE.from_pretrained("models/corpus_brain").eval()
# Generate Wikipedia titles
model.sample(
sentences=["Einstein was a German physicist."],
prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)
[[{'text': 'Albert Einstein', 'score': tensor(-0.0708)},
{'text': 'Werner Bruschke', 'score': tensor(-1.5357)},
{'text': 'Werner von Habsburg', 'score': tensor(-1.8696)},
{'text': 'Werner von Moltke', 'score': tensor(-2.2318)},
{'text': 'Werner von Eichstedt', 'score': tensor(-3.0177)}]]
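During beam search, prefix_allowed_tokens_fn restricts the candidate next tokens at every decoding step to those that extend a valid title prefix in the trie, so the model can only output exact Wikipedia titles from the knowledge source; the returned score is the log-probability the beam search assigns to each title.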
This project is licensed under the Apache License 2.0.
If you find our work useful, please consider citing our paper:
@article{chen2022corpusbrain,
title={CorpusBrain: Pre-train a Generative Retrieval Model for Knowledge-Intensive Language Tasks},
author={Chen, Jiangui and Zhang, Ruqing and Guo, Jiafeng and Liu, Yiqun and Fan, Yixing and Cheng, Xueqi},
journal={arXiv preprint arXiv:2208.07652},
year={2022}
}