Commit 09cfd6e: Update README

QianRuan committed Nov 4, 2024
1 parent 0242461

Showing 33 changed files with 3,125 additions and 90 deletions.

README.md: 120 additions & 90 deletions
<p align="center">
<img src='logo.png' width='200'>
</p>
# Are Large Language Models Good Classifiers? A Study on Edit Intent Classification in Scientific Document Revisions
This is the official code repository for the paper "Are Large Language Models Good Classifiers? A Study on Edit Intent Classification in Scientific Document Revisions", presented at EMNLP 2024 main conference. It contains the scripts for the fine-tuning approaches outlined in the paper.

[![Arxiv](https://img.shields.io/badge/Arxiv-2410.02028-red?style=flat-square&logo=arxiv&logoColor=white)](https://arxiv.org/abs/2410.02028)
[![License](https://img.shields.io/github/license/UKPLab/llm_classifier)](https://opensource.org/licenses/Apache-2.0)
[![Python Versions](https://img.shields.io/badge/Python-3.9-blue.svg?style=flat&logo=python&logoColor=white)](https://www.python.org/)
[![CI](https://github.com/UKPLab/llm_classifier/actions/workflows/main.yml/badge.svg)](https://github.com/UKPLab/llm_classifier/actions/workflows/main.yml)
Please find the paper [here](https://arxiv.org/abs/2410.02028), and star the repository to stay updated with the latest information.

In case of questions, please contact [Qian Ruan](mailto:ruan@ukp.tu-darmstadt.de). Don't hesitate to send an e-mail or report an issue if something is broken (and it shouldn't be) or if you have further questions.

## Abstract
Classification is a core NLP task architecture with many potential applications. While large language models (LLMs) have brought substantial advancements in text generation, their potential for enhancing classification tasks remains underexplored. To address this gap, we propose a framework for thoroughly investigating fine-tuning LLMs for classification, including both generation- and encoding-based approaches. We instantiate this framework in edit intent classification (EIC), a challenging and underexplored classification task. Our extensive experiments and systematic comparisons with various training approaches and a representative selection of LLMs yield new insights into their application for EIC. We investigate the generalizability of these findings on five further classification tasks. To demonstrate the proposed methods and address the data shortage for empirical edit analysis, we use our best-performing EIC model to create Re3-Sci2.0, a new large-scale dataset of 1,780 scientific document revisions with over 94k labeled edits. The quality of the dataset is assessed through human evaluation. The new dataset enables an in-depth empirical study of human editing behavior in academic writing.
![](/resource/overview.pdf)

*Figure 1. In this work, we (1) present a general framework to explore the classification capabilities of LLMs, conducting extensive experiments and systematic comparisons on the EIC task; (2) use the best model to create the Re3-Sci2.0 dataset, which comprises 1,780 scientific document revisions (a-b), associated reviews (c, d), and 94,482 edits annotated with action and intent labels (e, f), spanning various scholarly domains; and (3) provide a first in-depth empirical analysis of human editing behavior using this new dataset.*

## Approaches
![](/resource/approaches.pdf)

*Figure 2. Proposed approaches with a systematic investigation of the key components: input types (red), language models (green), and transformation functions (yellow). See §3 and §4 of the paper for details.*

## Quickstart
1. Download the project from GitHub.
```bash
git clone https://github.com/UKPLab/llm_classifier
```
2. Set up the environment
```bash
python -m venv .llm_classifier
source ./.llm_classifier/bin/activate
pip install -r requirements.txt
```


### Fine-tuning LLMs
Check the `finetune_EIC_<X>.py` scripts for the complete workflow of each approach: Gen, SeqC, XNet, and SNet. You can customize the arguments between the `<settings>` and `</settings>` markers. Refer to the paper for more details.

For example, to fine-tune an LLM with the SeqC approach:

1. Basic Settings

```python
############################################################################
# basic settings
# <settings>
task_name = 'edit_intent_classification'
method = 'finetuning_llm_seqc' # select an approach from ['finetuning_llm_gen', 'finetuning_llm_seqc', 'finetuning_llm_snet', 'finetuning_llm_xnet']
train_type = 'train' # name of the training data in data/Re3-Sci/tasks/edit_intent_classification
val_type = 'val' # name of the validation data in data/Re3-Sci/tasks/edit_intent_classification
test_type = 'test' # name of the test data in data/Re3-Sci/tasks/edit_intent_classification
# </settings>
############################################################################
```
2. Load Data

```python
from tasks.task_data_loader import TaskDataLoader
task_data_loader = TaskDataLoader(task_name=task_name, train_type=train_type, val_type=val_type, test_type=test_type)
train_ds, val_ds, test_ds= task_data_loader.load_data()
labels, label2id, id2label = task_data_loader.get_labels()
```
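
As a quick sanity check (illustrative; not part of the original script), the loader's outputs can be inspected directly:

```python
# Illustrative: inspect the loaded splits and label mappings
print(len(train_ds), len(val_ds), len(test_ds))  # number of edits per split
print(labels)    # EIC label set: Grammar, Clarity, Fact/Evidence, Claim, Other
print(label2id)  # e.g. {'Grammar': 0, 'Clarity': 1, ...} (exact ids depend on the loader)
```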

3. Load Model

```python
# load model from path
# <settings>
model_path = 'path/to/model'
emb_type = None # transformation function for the XNet and SNet approaches, select from ['diff', 'diffABS', 'n-diffABS', 'n-o', 'n-diffABS-o']; None for SeqC and Gen
input_type = 'text_st_on' # input type for the model, select from ['text_nl_on', 'text_st_on', 'inst_text_st_on', 'inst_text_nl_on'] for natural language input, structured input, instruction + structured input, and instruction + natural language input, respectively
# </settings>
from tasks.task_model_loader import TaskModelLoader
model_loader = TaskModelLoader(task_name=task_name, method=method).model_loader
model, tokenizer = model_loader.load_model_from_path(model_path, labels=labels, label2id=label2id, id2label=id2label, emb_type=emb_type, input_type=input_type)
```
4. Preprocess dataset

```python
# <settings>
max_length = 1024
# </settings>
from tasks.task_data_preprocessor import TaskDataPreprocessor
data_preprocessor = TaskDataPreprocessor(task_name=task_name, method=method).data_preprocessor
train_ds = data_preprocessor.preprocess_data(train_ds, label2id, tokenizer, max_length=max_length, input_type=input_type)
val_ds = data_preprocessor.preprocess_data(val_ds, label2id, tokenizer, max_length=max_length, input_type=input_type)
test_ds = data_preprocessor.preprocess_data(test_ds, label2id, tokenizer, max_length=max_length, input_type=input_type)
```
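
The exact fields added here depend on the approach; for SeqC one would expect tokenized model inputs (an assumption, shown for orientation only):

```python
# Illustrative (assumption): SeqC preprocessing should yield tokenized fields
print(train_ds[0].keys())  # e.g. input_ids, attention_mask, and an encoded label
```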
5. Fine-tune model

```python
# fine-tune model
# <settings>
lora_r = 128 # LoRA rank parameter
lora_alpha = 128 # Alpha parameter for LoRA scaling
lora_dropout = 0.1 # Dropout probability for LoRA layers
learning_rate = 2e-4 # Learning rate
per_device_train_batch_size = 32 # Batch size per GPU for training
train_epochs = 2 # Number of epochs to train
recreate_dir = True # whether to (re)create the output directory for the model
# </settings>
# create model dir to save the fine-tuned model
from finetune_EIC_SeqC import create_model_dir
output_dir = create_model_dir(task_name, method, model_path, lora_r, lora_alpha, lora_dropout, learning_rate,
per_device_train_batch_size, train_epochs, train_type, test_type,
max_length, emb_type, input_type, recreate_dir=recreate_dir)
# fine-tune
from tasks.task_model_finetuner import TaskModelFinetuner
model_finetuner = TaskModelFinetuner(task_name=task_name, method=method).model_finetuner
model_finetuner.fine_tune(model, tokenizer, train_ds=train_ds, val_ds=val_ds, lora_r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                          learning_rate=learning_rate, per_device_train_batch_size=per_device_train_batch_size, train_epochs=train_epochs, output_dir=output_dir)
```
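
For orientation, the three `lora_*` values are standard PEFT LoRA hyperparameters. A minimal sketch of how they would typically be assembled, assuming the repository uses Hugging Face `peft` under the hood (the actual wiring lives in `tasks/`):

```python
# Sketch (assumption): how the settings above map onto a Hugging Face peft config
from peft import LoraConfig

lora_config = LoraConfig(
    r=lora_r,                   # rank of the low-rank update matrices
    lora_alpha=lora_alpha,      # scaling factor; the update is scaled by lora_alpha / r
    lora_dropout=lora_dropout,  # dropout on the LoRA branch
    task_type="SEQ_CLS",        # sequence classification, matching the SeqC approach
)
```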
6. Evaluate

```python
# evaluate the fine-tuned model
from tasks.task_evaluater import TaskEvaluater
evaluater = TaskEvaluater(task_name=task_name, method=method).evaluater
# NOTE: response_key is the response marker returned by the Gen preprocessor's
# preprocess_data(); for SeqC it can presumably be left as None
evaluater.evaluate(test_ds, model_dir=output_dir, labels=labels, label2id=label2id, id2label=id2label, emb_type=emb_type, input_type=input_type, response_key=response_key)
```

## Citation

Please use the following citation:

```
@misc{ruan2024llmclassifier,
      title={Are Large Language Models Good Classifiers? A Study on Edit Intent Classification in Scientific Document Revisions},
      author={Qian Ruan and Ilia Kuznetsov and Iryna Gurevych},
      year={2024},
      eprint={2410.02028},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2410.02028},
}
```

## Disclaimer
This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.

<https://intertext.ukp-lab.de/>

<https://www.ukp.tu-darmstadt.de>

<https://www.tu-darmstadt.de>
New empty file: tasks/__init__.py

New file: tasks/edit_intent_classification/finetuning_llm_gen/__init__.py (11 additions)
from .model_loader import ModelLoader
from .data_preprocessor import DataPreprocessor
from .model_finetuner import ModelFinetuner
from .evaluater import Evaluater

__all__ = [
    "ModelLoader",
    "DataPreprocessor",
    "ModelFinetuner",
    "Evaluater",
]
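
With these exports in place, the four Gen components can be imported from outside the package, e.g.:

```python
# Usage sketch: import the Gen components exposed by this __init__.py
from tasks.edit_intent_classification.finetuning_llm_gen import (
    ModelLoader, DataPreprocessor, ModelFinetuner, Evaluater,
)
```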
New file: tasks/edit_intent_classification/finetuning_llm_gen/data_preprocessor.py (96 additions)
# Initialize static strings for the prompt template
# natural language input (nl)
INSTRUCTION_KEY = "### Instruction:"
INSTRUCTION_KEY_END = ''
INPUT_KEY = "INPUT:"
INPUT_KEY_END = ''
NEW_START = 'NEW:'
NEW_END = ''
OLD_START = 'OLD:'
OLD_END = ''
RESPONSE_KEY = 'RESPONSE:'
END_KEY = '### End'


# structured input (st)
INSTRUCTION_KEY_ST = "<instruction>"
INSTRUCTION_KEY_END_ST = '</instruction>'
INPUT_KEY_ST = '<input>'
INPUT_KEY_END_ST = '</input>'
NEW_START_ST = '<new>'
NEW_END_ST = '</new>'
OLD_START_ST = '<old>'
OLD_END_ST = '</old>'
RESPONSE_KEY_ST = "<response>"
END_KEY_ST = "</response>"


TASK_PROMPT = "Classify the intent of the following sentence edit. The possible labels are: Grammar, Clarity, Fact/Evidence, Claim, Other. "

PROMPT_ST_DIC = {'nl': [INSTRUCTION_KEY, INSTRUCTION_KEY_END, INPUT_KEY, INPUT_KEY_END, OLD_START, OLD_END, NEW_START, NEW_END, RESPONSE_KEY, END_KEY],
                 'st': [INSTRUCTION_KEY_ST, INSTRUCTION_KEY_END_ST, INPUT_KEY_ST, INPUT_KEY_END_ST, OLD_START_ST, OLD_END_ST, NEW_START_ST, NEW_END_ST, RESPONSE_KEY_ST, END_KEY_ST]}

class DataPreprocessor:
    def __init__(self) -> None:
        print('Preprocessing the data...Gen')

    def preprocess_data(self, dataset, max_length=1024, input_type='text_st_on', is_train: bool = True):
        """
        :param max_length (int): Maximum number of tokens to emit from the tokenizer
        :param input_type (str): Type of input text
        :return: the mapped dataset together with the response marker of the chosen template
        """
        # e.g. 'text_st_on' -> 'st' (structured), 'text_nl_on' -> 'nl' (natural language)
        self.prompt_st_type = input_type.split('_')[-2]
        instruction_key, instruction_key_end, input_key, input_key_end, old_start, old_end, new_start, new_end, response_key, end_key = PROMPT_ST_DIC[self.prompt_st_type]

        # Add prompt to each sample
        print("Preprocessing dataset...")
        if is_train:
            dataset = dataset.map(self.create_prompt_formats_train, keep_in_memory=True)
        else:
            dataset = dataset.map(self.create_prompt_formats_test, keep_in_memory=True)

        # Shuffle dataset
        seed = 42
        dataset = dataset.shuffle(seed=seed)
        return dataset, response_key

    def create_prompt_formats_train(self, sample):
        """
        Creates a formatted prompt template for a prompt in the dataset
        :param sample: sample from the dataset
        """
        instruction_key, instruction_key_end, input_key, input_key_end, old_start, old_end, new_start, new_end, response_key, end_key = PROMPT_ST_DIC[self.prompt_st_type]
        task_prompt = TASK_PROMPT
        # Combine a prompt with the static strings
        instruction = f"{instruction_key} {task_prompt} {instruction_key_end}"

        text_src = sample['text_src'] if sample['text_src'] is not None else ''
        text_tgt = sample['text_tgt'] if sample['text_tgt'] is not None else ''
        input_context = f"{input_key}\n {old_start} {text_tgt} {old_end}\n {new_start} {text_src} {new_end}\n{input_key_end}"
        # Training prompts include the gold label followed by an end marker
        response = f"{response_key}{sample['label']}"
        end = f"{end_key}"
        # Create a list of prompt template elements
        parts = [part for part in [instruction, input_context, response, end] if part]
        # Join prompt template elements into a single string to create the prompt template
        formatted_prompt = "\n".join(parts)
        # Store the formatted prompt template in a new key "text"
        sample["text"] = formatted_prompt
        return sample

    def create_prompt_formats_test(self, sample):
        """
        Creates a formatted prompt template for a prompt in the dataset
        :param sample: sample from the dataset
        """
        instruction_key, instruction_key_end, input_key, input_key_end, old_start, old_end, new_start, new_end, response_key, end_key = PROMPT_ST_DIC[self.prompt_st_type]
        task_prompt = TASK_PROMPT
        instruction = f"{instruction_key} {task_prompt} {instruction_key_end}"
        text_src = sample['text_src'] if sample['text_src'] is not None else ''
        text_tgt = sample['text_tgt'] if sample['text_tgt'] is not None else ''
        input_context = f"{input_key}\n {old_start} {text_tgt} {old_end}\n {new_start} {text_src} {new_end}\n{input_key_end}"
        # Test prompts stop at the response marker so the model generates the label
        response = f"{response_key}"
        parts = [part for part in [instruction, input_context, response] if part]
        formatted_prompt = "\n".join(parts)
        sample["text"] = formatted_prompt
        return sample
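
For orientation, here is what the structured ('st') template produces for one invented sample (illustrative only; `prompt_st_type` is normally set inside `preprocess_data`):

```python
# Illustrative only: format a made-up sample with the structured ('st') template
preprocessor = DataPreprocessor()
preprocessor.prompt_st_type = 'st'  # normally derived from input_type in preprocess_data()
sample = {
    'text_src': 'We evaluate five LLMs on the EIC task.',  # rendered inside <new> below
    'text_tgt': 'We evaluate LLMs on the EIC task.',       # rendered inside <old> below
    'label': 'Fact/Evidence',
}
print(preprocessor.create_prompt_formats_train(sample)['text'])
# <instruction> Classify the intent of the following sentence edit. The possible labels are: Grammar, Clarity, Fact/Evidence, Claim, Other.  </instruction>
# <input>
#  <old> We evaluate LLMs on the EIC task. </old>
#  <new> We evaluate five LLMs on the EIC task. </new>
# </input>
# <response>Fact/Evidence
# </response>
```

The test-time variant stops at the response marker, leaving the label for the model to generate, which is presumably how the Gen approach casts classification as generation.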
