
Feature(MInference): add unittest (#31)
Co-authored-by: Yucheng Li <liyucheng09@gmail.com>
Co-authored-by: Chengruidong Zhang <chengzhang@microsoft.com>
3 people authored Jul 12, 2024
1 parent d2d8747 commit 1dd709f
Showing 3 changed files with 73 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -17,4 +17,4 @@ style:
isort -rc $(CHECK_DIRS)

test:
-@${PYTHON} -m pytest -n auto --dist=loadfile -s -v ./tests/
+@${PYTHON} -m pytest -n 1 --dist=loadfile -s -v ./tests/
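
Pinning pytest-xdist to a single worker (-n 1 instead of -n auto) presumably keeps the GPU-bound tests from competing for the same device. A minimal sketch of the equivalent direct invocation, assuming pytest and pytest-xdist are installed:

python -m pytest -n 1 --dist=loadfile -s -v ./tests/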
10 changes: 2 additions & 8 deletions experiments/benchmarks/benchmark_e2e.py
@@ -2,18 +2,11 @@
# Licensed under The MIT License [see LICENSE for details]

import argparse
-import sys
import time
-from collections import defaultdict

import torch
-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    GenerationConfig,
-    LlamaForCausalLM,
-)
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from minference import MInference

@@ -108,6 +101,7 @@ def run_benchmark(model_name: str):
"hf",
"streaming",
"minference",
"minference_with_dense",
"inf_llm",
],
)
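
With the new choice in place, the benchmark can be run against it roughly as follows; the option name --attn_type is an assumption, since the argument definition sits above the visible hunk:

python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense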
70 changes: 70 additions & 0 deletions tests/test_e2e.py
@@ -0,0 +1,70 @@
# Copyright (c) 2024 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import unittest

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from minference import MInference


class MInferenceE2ETester(unittest.TestCase):
"""
End-to-end test for MInference.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# parameters
model_name = "gradientai/Llama-3-8B-Instruct-262k"
trust_remote_code = False
attn_type = "minference"
kv_cache_cpu = True
self.attn_type = attn_type

# init model and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
model_name, trust_remote_code=trust_remote_code
)

model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto",
trust_remote_code=trust_remote_code,
)
attn_kwargs = {}
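# build the MInference patch for this model and apply it to the loaded HF model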
minference_patch = MInference(
attn_type,
model_name,
kv_cache_cpu=kv_cache_cpu,
attn_kwargs=attn_kwargs,
)
self.model = minference_patch.patch_model(model)

with open("./prompt_hardest.txt") as f:
    self.prompt_complex = f.read()

def test_general_minference(self):
def test_different_context_windows(seq_len: int):
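# repeat the base prompt until it covers seq_len tokens, then truncate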
input_ids = self.tokenizer(self.prompt_complex)["input_ids"]
n = len(input_ids)
b = seq_len // n + 1

new_input_ids = (input_ids * b)[:seq_len]
prompt = self.tokenizer.decode(new_input_ids)
data = self.tokenizer(prompt, return_tensors="pt")
input_ids = data["input_ids"].cuda()
attention_mask = data["attention_mask"].cuda()

with torch.no_grad():
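# inf_llm is exercised through generate(); other attention types only need a single prefill forward pass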
if self.attn_type != "inf_llm":
self.model(input_ids, attention_mask, use_cache=False)
else:
self.model.generate(
input_ids, generation_config=GenerationConfig(max_new_tokens=1)
)

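# cover both a 100K- and a 1M-token context window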
test_different_context_windows(100000)
test_different_context_windows(1000000)
