#4003: debugging whisper
arakhmati committed Jan 20, 2024
1 parent a974de4 commit 0202082
Showing 8 changed files with 146 additions and 229 deletions.
@@ -56,15 +56,12 @@ def split_query_key_value_and_split_heads(config, fused_qkv):

query_states = torch.reshape(query_states, shape=(batch_size, seq_length, num_heads, head_size))
query_states = torch.permute(query_states, (0, 2, 1, 3))
- query_states = query_states.contiguous()

key_states = torch.reshape(key_states, shape=(batch_size, seq_length, num_heads, head_size))
key_states = torch.permute(key_states, (0, 2, 1, 3))
- key_states = key_states.contiguous()

value_states = torch.reshape(value_states, shape=(batch_size, seq_length, num_heads, head_size))
value_states = torch.permute(value_states, (0, 2, 1, 3))
- value_states = value_states.contiguous()

return query_states, key_states, value_states
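
For reference, the reshape/permute pattern in this hunk maps a (batch, seq, num_heads * head_size) projection to the per-head layout (batch, num_heads, seq, head_size). A minimal stand-alone sketch with illustrative sizes (the split of fused_qkv into the three projections happens earlier in the function and is assumed here):

import torch

batch_size, seq_length, num_heads, head_size = 1, 32, 8, 64  # illustrative sizes
states = torch.randn(batch_size, seq_length, num_heads * head_size)

# (batch, seq, num_heads * head_size) -> (batch, seq, num_heads, head_size)
states = torch.reshape(states, shape=(batch_size, seq_length, num_heads, head_size))
# -> (batch, num_heads, seq, head_size), ready for per-head attention
states = torch.permute(states, (0, 2, 1, 3))
assert states.shape == (batch_size, num_heads, seq_length, head_size)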

@@ -84,7 +81,6 @@ def whisper_attention(config, hidden_states, attention_mask, key_value_states, *
query_states = hidden_states @ parameters.q_proj.weight + parameters.q_proj.bias
query_states = torch.reshape(query_states, shape=(bsz, tgt_len, config.encoder_attention_heads, head_size))
query_states = torch.permute(query_states, (0, 2, 1, 3))
- query_states = query_states.contiguous()
key_states, value_states = calculate_key_values(config, key_value_states, parameters=parameters)
else:
query_states, key_states, value_states = calculate_query_key_values(
@@ -93,11 +89,11 @@ def whisper_attention(config, hidden_states, attention_mask, key_value_states, *
query_states *= scaling

proj_shape = (bsz * config.encoder_attention_heads, -1, head_size)
- query_states = torch.reshape(query_states, shape=proj_shape).contiguous()
- key_states = torch.reshape(key_states, shape=proj_shape).contiguous()
- value_states = torch.reshape(value_states, shape=proj_shape).contiguous()
+ query_states = torch.reshape(query_states, shape=proj_shape)
+ key_states = torch.reshape(key_states, shape=proj_shape)
+ value_states = torch.reshape(value_states, shape=proj_shape)

- attn_weights = query_states @ torch.permute(key_states, (0, 2, 1)).contiguous()
+ attn_weights = query_states @ torch.permute(key_states, (0, 2, 1))
if attention_mask is not None:
bsz, _, tgt_len, src_len = attention_mask.size()
attn_weights = (
@@ -118,6 +114,7 @@ def whisper_attention(config, hidden_states, attention_mask, key_value_states, *

def encoder_layer(config, hidden_states, *, parameters):
residual = hidden_states

hidden_states = F.layer_norm(
hidden_states,
(config.d_model,),
@@ -383,8 +380,8 @@ def preprocess_decoder_inputs(input_ids, attention_mask, *, parameters):
input_ids = torch.reshape(input_ids, (-1, input_shape[-1]))
inputs_embeds = F.embedding(input_ids, parameters.embed_tokens.weight)
attention_mask = prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds)
- positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]]

+ positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]]
decoder_hidden_states = inputs_embeds + positions

return decoder_hidden_states, attention_mask
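
prepare_decoder_attention_mask is called above but is not part of this hunk. As a rough sketch of what such a causal decoder mask conventionally looks like (this is an assumption about the helper, not the repository's implementation; the function name below is hypothetical):

import torch

def make_causal_mask_sketch(input_shape, dtype=torch.float32):
    # Hypothetical stand-in: builds a (batch, 1, tgt_len, tgt_len) additive mask
    # where future positions hold a large negative value and the rest are zero.
    batch_size, tgt_len = input_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype)
    mask = torch.triu(mask, diagonal=1)  # zero out the diagonal and past positions
    return mask[None, None, :, :].expand(batch_size, 1, tgt_len, tgt_len)
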
@@ -433,10 +430,10 @@ def custom_preprocessor(torch_model, name):
parameters = {"key_value": {}, "q_proj": {}, "out_proj": {}}
preprocessed_weight = torch.cat([torch_model.k_proj.weight, torch_model.v_proj.weight], dim=0)
preprocessed_bias = torch.cat([torch.zeros_like(torch_model.v_proj.bias), torch_model.v_proj.bias], dim=0)
parameters["key_value"]["weight"] = preprocessed_weight.type(torch.bfloat16).T
parameters["key_value"]["bias"] = preprocessed_bias.type(dtype=torch.bfloat16)
parameters["q_proj"]["weight"] = torch_model.q_proj.weight.type(torch.bfloat16).T
parameters["q_proj"]["bias"] = torch_model.q_proj.bias.type(dtype=torch.bfloat16)
parameters["key_value"]["weight"] = preprocessed_weight.T.contiguous()
parameters["key_value"]["bias"] = preprocessed_bias
parameters["q_proj"]["weight"] = torch_model.q_proj.weight.T.contiguous()
parameters["q_proj"]["bias"] = torch_model.q_proj.bias
else:
parameters = {"query_key_value": {}, "out_proj": {}}
preprocessed_weight = torch.cat(
@@ -445,11 +442,11 @@ def custom_preprocessor(torch_model, name):
preprocessed_bias = torch.cat(
[torch_model.q_proj.bias, torch.zeros_like(torch_model.v_proj.bias), torch_model.v_proj.bias], dim=0
)
parameters["query_key_value"]["weight"] = preprocessed_weight.type(torch.bfloat16).T
parameters["query_key_value"]["bias"] = preprocessed_bias.type(dtype=torch.bfloat16)
parameters["query_key_value"]["weight"] = preprocessed_weight.T.contiguous()
parameters["query_key_value"]["bias"] = preprocessed_bias

parameters["out_proj"]["weight"] = torch_model.out_proj.weight.type(torch.bfloat16).T
parameters["out_proj"]["bias"] = torch_model.out_proj.bias.type(dtype=torch.bfloat16)
parameters["out_proj"]["weight"] = torch_model.out_proj.weight.T.contiguous()
parameters["out_proj"]["bias"] = torch_model.out_proj.bias
return parameters
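
The fused key/value projection built above concatenates the k_proj and v_proj weights (with a zero bias slot for k, since Whisper's k_proj carries no bias) and stores the result transposed and contiguous. A small sanity check of that layout, using stand-alone nn.Linear modules with illustrative sizes (the check itself is not part of the commit):

import torch

embed_dim = 512  # illustrative
k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=False)  # Whisper's k_proj has no bias
v_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True)

fused_weight = torch.cat([k_proj.weight, v_proj.weight], dim=0).T.contiguous()
fused_bias = torch.cat([torch.zeros(embed_dim), v_proj.bias], dim=0)

hidden_states = torch.randn(1, 10, embed_dim)
fused = hidden_states @ fused_weight + fused_bias
key_states, value_states = torch.split(fused, embed_dim, dim=-1)

assert torch.allclose(key_states, k_proj(hidden_states), atol=1e-5)
assert torch.allclose(value_states, v_proj(hidden_states), atol=1e-5)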


@@ -466,7 +463,7 @@ def custom_preprocessor(torch_model, name):
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt")
- input_features = inputs.input_features.type(torch.bfloat16)
+ input_features = inputs.input_features
decoder_input_ids = torch.ones(1, 1).type(torch.int32) * model.config.decoder_start_token_id

model_graph = draw_graph(
@@ -301,8 +301,7 @@ def conv(input, weight, bias, stride=1, padding=1, dilation=1, groups=1):
)
)
input_embeds = input_embeds.permute(0, 2, 1)
- input_embeds = ttnn.from_torch(input_embeds, dtype=ttnn.bfloat16)
- input_embeds = ttnn.to_device(input_embeds, device)
+ input_embeds = ttnn.from_torch(input_embeds, dtype=ttnn.bfloat16, device=device)

return input_embeds
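
The two-step from_torch/to_device conversion is collapsed here into a single ttnn.from_torch call that takes the device directly. A minimal sketch of the pattern; device is assumed to be an already-open ttnn device handle supplied by the caller, as in the code above:

import ttnn

def to_ttnn_sketch(torch_tensor, device):
    # Replaces the previous two-step pattern:
    #   t = ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat16)
    #   t = ttnn.to_device(t, device)
    return ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat16, device=device)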

@@ -318,10 +317,8 @@ def preprocess_decoder_inputs(config, input_ids, attention_mask, *, parameters,
positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]]
decoder_hidden_states = inputs_embeds + positions

- decoder_hidden_states = ttnn.from_torch(decoder_hidden_states, dtype=ttnn.bfloat16)
- decoder_hidden_states = ttnn.to_device(decoder_hidden_states, device)
- attention_mask = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16)
- attention_mask = ttnn.to_device(attention_mask, device)
+ decoder_hidden_states = ttnn.from_torch(decoder_hidden_states, dtype=ttnn.bfloat16, device=device)
+ attention_mask = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, device=device)

return decoder_hidden_states, attention_mask
