beam search doesn't work with transformers_cfg #9

Open
minniekabra opened this issue Feb 20, 2024 · 6 comments
Labels
enhancement New feature or request

Comments

@minniekabra

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

if __name__ == "__main__":
    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Load json grammar
    with open("examples/grammars/json.ebnf", "r") as file:
        grammar_str = file.read()
    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

    # Generate
    prefix1 = "This is a valid json string for http request:"
    prefix2 = "This is a valid json string for shopping cart:"
    input_ids = tokenizer([prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"]

    output = model.generate(
        input_ids,
        do_sample=False,
        max_length=50,
        num_beams=1,  # this can't be > 1
        logits_processor=[grammar_processor],
        repetition_penalty=5.0,
        num_return_sequences=1,
    )
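With num_beams=1 this runs fine and the output decodes normally (e.g. via tokenizer.batch_decode(output, skip_special_tokens=True)); setting num_beams above 1 is what triggers the error quoted in the next comment.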
@Saibo-creator
Collaborator

Saibo-creator commented Feb 20, 2024

Thanks for raising this issue; support for beam search is still in progress.

The error message is below

ValueError: All stacks are empty, so the only token accepted is EOS(2), but got 539

Saibo-creator added the enhancement (New feature or request) label on Feb 20, 2024
@Saibo-creator
Collaborator

Here, I describe how to integrate support for beam search with grammar-constrained decoding, in case a volunteer wants to contribute :)

At present, our library uses a logits processor to steer the decoding process. This processor relies on an underlying parser to determine the permissible tokens at each step.

While this works for the various decoding/sampling methods, it does not suit constrained beam search.

Why the constrained logits processor is incompatible with beam search is subtle and relates to the mechanics of beam search itself (see the sketch just below). That detail is not central to this feature, though, since the plan is to build on the Constraint class from Hugging Face.
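To make this concrete, here is a minimal sketch of what such a processor does (illustrative only, not our actual code; initial_state, advance, and allowed_token_mask stand in for the real parser API). It keeps one parser state per batch row and advances it with the token that row emitted in the previous step. Beam search breaks exactly this assumption: between steps it reorders and duplicates rows, so row i at step t+1 need not be a continuation of row i at step t, and the per-row parser states fall out of sync.

import torch
from transformers import LogitsProcessor

class SketchGrammarProcessor(LogitsProcessor):
    # Illustrative only: one parser state per batch row.
    def __init__(self, parser):
        self.parser = parser  # hypothetical incremental parser
        self.states = None    # one parsing state per sequence

    def __call__(self, input_ids, scores):
        batch_size = input_ids.shape[0]
        if self.states is None:
            # First call: start a fresh parse for every row.
            self.states = [self.parser.initial_state() for _ in range(batch_size)]
        else:
            # Advance row i's state with the token row i emitted last step.
            # Under beam search this is wrong: rows may have been reordered
            # or duplicated since the previous call.
            for i in range(batch_size):
                self.states[i] = self.parser.advance(self.states[i], input_ids[i, -1].item())
        for i in range(batch_size):
            # Mask out every token the grammar does not allow next.
            allowed = self.parser.allowed_token_mask(self.states[i])  # bool tensor over vocab
            scores[i] = scores[i].masked_fill(~allowed, float("-inf"))
        return scores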

Credit goes to @chanwkimlab for developing the constrained beam search and providing a robust abstraction along with a comprehensive blog post: https://huggingface.co/blog/constrained-beam-search

The procedure involves:

  1. Creating a GrammarConstraint class and adding tests.
  2. Using GrammarConstraint in place of GrammarConstrainedLogitsProcessor during inference, and testing it.
from typing import List

from transformers.generation.beam_constraints import Constraint

class GrammarConstraint(Constraint):

    def __init__(self, token_ids: List[int]):
        # Skips Constraint.__init__ (which calls self.test()), mirroring
        # HF's built-in constraints.
        super(Constraint, self).__init__()
        ...

    def advance(self):
        # Return the token ids that would move the constraint forward.
        ...

    def does_advance(self, token_id: int):
        # Return whether token_id is accepted by the current grammar state.
        ...

    def update(self, token_id: int):
        # Consume token_id and return (stepped, completed, reset).
        ...

    def reset(self):
        self.completed = False
        self.fulfilled_idx = 0

    def remaining(self):
        # For grammar-constrained decoding, determining the exact number of
        # remaining tokens may be challenging, but it should not pose a
        # significant issue.
        ...

Here are some example implementations of the Constraint class in the HF library: https://github.com/huggingface/transformers/blob/c60749d6a67d223d65a2fb6105c2459f3469a30d/src/transformers/generation/beam_constraints.py#L129
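Once a working GrammarConstraint exists, it would be passed to generate via the constraints argument instead of logits_processor. A hypothetical usage sketch (HF's constrained beam search requires num_beams > 1 together with do_sample=False):

grammar_constraint = GrammarConstraint(...)  # construction details to be decided
output = model.generate(
    input_ids,
    constraints=[grammar_constraint],
    num_beams=4,
    do_sample=False,
    max_new_tokens=50,
)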

That's it!

@HichemAK

Hello! Is this still an active issue, or has a workaround been found?

I can take a shot at coding the GrammarConstraint class.

@Saibo-creator
Collaborator

Saibo-creator commented Aug 13, 2024 via email

@HichemAK

Hello! Unfortunately I couldn't make it work; the constraint feature lacks documentation, and it's difficult to understand how it works behind the scenes. When coding, I tried to follow the same format as the constraints found in the transformers library.

transformers version: 4.44.0

Here is my best attempt:

from transformers.generation.beam_constraints import Constraint
from transformers_cfg.grammar_utils import IncrementalTokenRecognizer, IncrementalGrammarConstraint
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class GrammarConstraint(Constraint):
    def __init__(self, token_recognizer: IncrementalTokenRecognizer):
        # Skip Constraint.__init__ (which calls self.test()), as HF's
        # built-in constraints do.
        super(Constraint, self).__init__()
        self.token_recognizer = token_recognizer
        self.current_state = self.token_recognizer.string_recognizer.get_initial_parsing_state()
        self.valid_tokens = self.token_recognizer.get_next_token_acceptance(self.current_state, device='cpu')
        self.completed = False
        self.seqlen = float('inf')  # a grammar constraint has no fixed token length
        self.tokens = []

    @property
    def text(self):
        # Decoded text of the tokens consumed so far (useful for debugging).
        return self.token_recognizer.tokenizer.decode(self.tokens)
    def advance(self):
        # Return the next set of tokens that would be accepted by the current grammar state
        if self.completed:
            return []
        acceptance = self.valid_tokens
        return acceptance.nonzero(as_tuple=False).squeeze(-1).tolist()

    def does_advance(self, token_id: int):
        # Check if the given token_id is accepted by the current grammar state.
        # HF's beam-constraint machinery expects a plain bool, not a 0-dim tensor.
        return bool(self.valid_tokens[token_id])

    def update(self, token_id: int):
        # Update the state with the given token_id and return the progress indicators
        if self.does_advance(token_id):
            new_state = self.token_recognizer._update_state_with_token_id(token_id, self.current_state)
            self.current_state = new_state
            
            stepped = True
            completed = not bool(new_state.stacks)  # If stacks are empty, the constraint is completed
            self.tokens.append(token_id)
            if not completed:
                self.valid_tokens = self.token_recognizer.get_next_token_acceptance(self.current_state, device='cpu')
            reset = False
        else:
            # The token_id was not accepted, reset the state
            self.reset()
            stepped = False
            completed = False
            reset = True

        self.completed = completed
        return stepped, completed, reset

    def reset(self):
        # Reset the state of this constraint to its initialization
        self.current_state = self.token_recognizer.string_recognizer.get_initial_parsing_state()
        self.valid_tokens = self.token_recognizer.get_next_token_acceptance(self.current_state, device='cpu')
        self.completed = False
        self.tokens.clear()

    def remaining(self):
        # Return the number of remaining steps; this is more complex for a grammar constraint
        # and might not be easily quantifiable. For simplicity, we return 1 if not completed.
        return 0 if self.completed else 1

    def copy(self, stateful=False):
        # Create a new instance of this constraint
        new_constraint = GrammarConstraint(
            self.token_recognizer
        )
        if stateful:
            new_constraint.current_state = self.current_state
            new_constraint.valid_tokens = self.valid_tokens
            new_constraint.completed = self.completed
            new_constraint.tokens = self.tokens.copy()
        return new_constraint


if __name__ == "__main__":
    # Detect if GPU is available, otherwise use CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model_id = "gpt2"

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    # Load json grammar
    with open("tuples.ebnf", "r") as file:
        grammar_str = file.read()

    token_recognizer = IncrementalGrammarConstraint(grammar_str, "root", tokenizer, unicode=True)
    grammar = GrammarConstraint(token_recognizer)

    model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, torch_dtype=torch.bfloat16)  # Load model to defined device
    model.generation_config.pad_token_id = model.generation_config.eos_token_id

    # Generate
    prefix1 = "Tuples:"
    input_ids = tokenizer([prefix1], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"].to(device)
    max_new_tokens = 50
    # grammar.seqlen = max_new_tokens
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        constraints=[grammar],
        num_beams=3,
        do_sample=False
    )
    # decode output
    generations = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(generations)

Here is the content of tuples.ebnf:

root   ::= triple triple

triple ::= "[" object object object "]"

object ::= "A" | "B" | "C"
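For reference, this grammar accepts exactly two triples back to back, so a valid constrained generation looks like [ABC][CAB].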

@Saibo-creator
Collaborator

@HichemAK Thanks for your effort! After diving deeper into beam search, I found that the implementation of constrained beam search in HF is quite convoluted and too tightly coupled to the existing constraints, which makes it not general enough. Implementing it on top of that abstraction is indeed not the best way: the resulting code would be very ugly and inefficient.

It might be better to avoid that approach and work directly with beam search, but that would require modifying the HF codebase.

I’ve sketched out how I plan to implement this beam search. For those interested, feel free to check it out. I’ll likely start working on it myself in the next few days.

GitHub Commit
