You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

# EBNF grammar for a subset of JSON, inlined so the example runs without a
# separate grammar file.
# It MUST be a raw string (r"""..."""). In an ordinary string literal Python
# would rewrite "\"" into a bare `"""`, and would turn \t and \n into real
# control characters. That corrupts the `string` and `ws` rules before the
# grammar parser ever sees them.
JSON_GRAMMAR = r"""
# Grammar for subset of JSON
# String doesn't support unicode and escape yet
# If you don't need to generate unicode and escape, you can use this grammar
# We are working to support unicode and escape

root   ::= object

object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}"

value  ::= object | array | string | number | ("true" | "false" | "null") ws

array  ::= "[" ws ( value ("," ws value)* )? "]" ws

string ::= "\"" [ \t!#-\[\]-~]* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

ws ::= ([ \t\n] ws)?
"""

MODEL_ID = "mistralai/Mistral-7B-v0.1"

# Prompts are batched together; each one gets its own JSON continuation.
PROMPTS = [
    "This is a valid json string for http request:",
    "This is a valid json string for shopping cart:",
]


def main() -> None:
    """Generate one grammar-constrained JSON completion per prompt and print them.

    Loads the tokenizer and model once, builds a logits processor from
    ``JSON_GRAMMAR`` (start symbol ``root``), and greedily decodes up to 60
    new tokens for every prompt in ``PROMPTS``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load tokenizer and model (the model is loaded exactly once).
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Reuse EOS as the pad token so the prompts can be padded into one batch.
    tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only models should be left-padded for batched generation.
    # Then each prompt's last real token sits at the end of its row, and new
    # tokens follow it directly instead of following pad tokens.
    tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)
    model.generation_config.pad_token_id = model.generation_config.eos_token_id

    grammar = IncrementalGrammarConstraint(JSON_GRAMMAR, "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

    # Keep the full encoding (input_ids AND attention_mask) and move it to the
    # model's device.
    inputs = tokenizer(
        PROMPTS, add_special_tokens=False, return_tensors="pt", padding=True
    ).to(device)

    output = model.generate(
        input_ids=inputs["input_ids"],
        # Pad tokens are EOS tokens here. Without the mask, generate() cannot
        # tell the padding apart from real prompt tokens.
        attention_mask=inputs["attention_mask"],
        do_sample=False,
        max_new_tokens=60,
        logits_processor=[grammar_processor],
        repetition_penalty=1.1,
        num_return_sequences=1,
    )

    # Decode prompt + continuation for every row, dropping pad/EOS tokens.
    generations = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(generations)
    # Sample output reported for this example. It is illustrative only; the
    # exact text depends on model weights and library versions:
    # ['This is a valid json string for http request:{ "request": { "method": "GET", "headers": [], "content": "Content","type": "application" }}',
    #  'This is a valid json string for shopping cart:{ "name": "MyCart", "price": 0, "value": 1 }']


if __name__ == "__main__":
    main()
The text was updated successfully, but these errors were encountered:
Saibo-creator changed the title from "Grammar should be put as string in the example scripts to help run directly" to "Refactor Request: Grammar should be put as string in the example scripts to help run directly" on Jun 4, 2024.
The text was updated successfully, but these errors were encountered: