
Commit 9645019

Merge remote-tracking branch 'upstream/main'
dchourasia committed May 3, 2024
2 parents 4e9957f + dd29d49 commit 9645019
Showing 2 changed files with 30 additions and 4 deletions.
scripts/run_evaluation.py: 7 changes (6 additions, 1 deletion)
@@ -57,6 +57,11 @@ def parse_and_validate_args():
         action="store_true",
     )
     parser.add_argument("--purge_results", action=argparse.BooleanOptionalAction)
+    parser.add_argument(
+        "--use_flash_attn",
+        help="Whether to load the model using Flash Attention 2",
+        action="store_true",
+    )
     parsed_args = parser.parse_args()
 
     print(f"Multiclass / multioutput delimiter: {parsed_args.delimiter}")
@@ -441,7 +446,7 @@ def export_experiment_info(
 
 if __name__ == "__main__":
     args = parse_and_validate_args()
-    tuned_model = TunedCausalLM.load(args.model)
+    tuned_model = TunedCausalLM.load(args.model, use_flash_attn=args.use_flash_attn)
     eval_data = datasets.load_dataset(
         "json", data_files=args.data_path, split=args.split
     )
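
Aside (not from the repository): a minimal, self-contained sketch of how the new store_true flag behaves once parsed. It defaults to False and becomes True only when --use_flash_attn is passed on the command line.

    import argparse

    # Hypothetical standalone illustration of the flag added in the diff above.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_flash_attn",
        help="Whether to load the model using Flash Attention 2",
        action="store_true",
    )

    print(parser.parse_args([]).use_flash_attn)                    # False: flag omitted
    print(parser.parse_args(["--use_flash_attn"]).use_flash_attn)  # True: flag present
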
scripts/run_inference.py: 27 changes (24 additions, 3 deletions)
@@ -142,7 +142,10 @@ def __init__(self, model, tokenizer, device):
 
     @classmethod
     def load(
-        cls, checkpoint_path: str, base_model_name_or_path: str = None
+        cls,
+        checkpoint_path: str,
+        base_model_name_or_path: str = None,
+        use_flash_attn: bool = False,
     ) -> "TunedCausalLM":
         """Loads an instance of this model.
@@ -152,6 +155,8 @@ def load(
                 adapter_config.json.
             base_model_name_or_path: str [Default: None]
                 Override for the base model to be used.
+            use_flash_attn: bool [Default: False]
+                Whether to load the model using flash attention.
 
         By default, the paths for the base model and tokenizer are contained within the adapter
         config of the tuned model. Note that in this context, a path may refer to a model to be
@@ -173,14 +178,24 @@ def load(
         try:
             with AdapterConfigPatcher(checkpoint_path, overrides):
                 try:
-                    model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
+                    model = AutoPeftModelForCausalLM.from_pretrained(
+                        checkpoint_path,
+                        attn_implementation="flash_attention_2"
+                        if use_flash_attn
+                        else None,
+                        torch_dtype=torch.bfloat16 if use_flash_attn else None,
+                    )
                 except OSError as e:
                     print("Failed to initialize checkpoint model!")
                     raise e
         except FileNotFoundError:
             print("No adapter config found! Loading as a merged model...")
             # Unable to find the adapter config; fall back to loading as a merged model
-            model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
+            model = AutoModelForCausalLM.from_pretrained(
+                checkpoint_path,
+                attn_implementation="flash_attention_2" if use_flash_attn else None,
+                torch_dtype=torch.bfloat16 if use_flash_attn else None,
+            )
 
         device = "cuda" if torch.cuda.is_available() else None
         print(f"Inferred device: {device}")
@@ -246,6 +261,11 @@ def main():
         type=int,
         default=20,
     )
+    parser.add_argument(
+        "--use_flash_attn",
+        help="Whether to load the model using Flash Attention 2",
+        action="store_true",
+    )
     group = parser.add_mutually_exclusive_group(required=True)
     group.add_argument("--text", help="Text to run inference on")
     group.add_argument(
@@ -261,6 +281,7 @@ def main():
     loaded_model = TunedCausalLM.load(
         checkpoint_path=args.model,
         base_model_name_or_path=args.base_model_name_or_path,
+        use_flash_attn=args.use_flash_attn,
     )
 
     # Run inference on the text; if multiple were provided, process them all
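
Hypothetical usage of the extended loader (the import assumes the scripts/ directory is importable, and the checkpoint path is a placeholder):

    from run_inference import TunedCausalLM  # assumes scripts/ is on PYTHONPATH

    loaded_model = TunedCausalLM.load(
        checkpoint_path="path/to/tuned/checkpoint",  # placeholder path
        use_flash_attn=True,  # load with Flash Attention 2 + bfloat16
    )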
