Skip to content

Commit

Permalink
Merge pull request #81 from codelion/fix-litellm-wrapper-for-claude
Browse files: browse the repository at this point in the history
Fix litellm wrapper for claude
  • Loading branch information
codelion authored Oct 27, 2024
2 parents c74902d + 0ebae20 commit a8d56ff
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
27 changes: 20 additions & 7 deletions optillm/entropy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,30 @@ def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tup
varentropy = torch.sum(probs * (log_probs / LN_2 + entropy.unsqueeze(-1))**2, dim=axis)
return entropy, varentropy

def calculate_attention_metrics(attention_scores: torch.Tensor) -> Dict[str, torch.Tensor]:
attention_probs = F.softmax(attention_scores, dim=-1)
def calculate_attention_metrics(attention_weights: torch.Tensor) -> Dict[str, torch.Tensor]:
attention_probs = attention_weights

# Calculate entropy
attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)
attn_varentropy = torch.var(attn_entropy, dim=-1)

attn_varentropy = torch.where(torch.isnan(attn_varentropy), torch.zeros_like(attn_varentropy), attn_varentropy)
# Calculate variance of entropy with unbiased=False to avoid df issues
# Also add a check for singleton dimensions
if attn_entropy.size(-1) > 1:
attn_varentropy = torch.var(attn_entropy, dim=-1, unbiased=False)
else:
attn_varentropy = torch.zeros_like(attn_entropy)

attn_varentropy = torch.where(torch.isnan(attn_varentropy),
torch.zeros_like(attn_varentropy),
attn_varentropy)

# Rest remains the same
mean_attention = torch.mean(attention_probs, dim=1)
agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2))

interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3))


attention_scores_proxy = torch.log(torch.clamp(attention_probs, 1e-10, 1.0))
interaction_strength = torch.mean(torch.abs(attention_scores_proxy), dim=(1, 2, 3))

return {
"attn_entropy": torch.mean(attn_entropy),
"attn_varentropy": torch.mean(attn_varentropy),
Expand Down
5 changes: 4 additions & 1 deletion optillm/litellm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ class Chat:
class Completions:
@staticmethod
def create(model: str, messages: List[Dict[str, str]], **kwargs):
    """Send one chat-completion request through LiteLLM and return its response.

    Args:
        model: Model identifier passed to ``completion``. When it starts
            with ``"gemini"``, the module-level ``SAFETY_SETTINGS`` are
            sent along as ``safety_settings``.
        messages: Chat messages forwarded unchanged to ``completion``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``completion``. A caller-supplied ``safety_settings`` takes
            precedence over the module default.

    Returns:
        Whatever ``completion`` returns, unmodified.
    """
    # ``safety_settings`` is only attached for Gemini models; every other
    # model is called without it.
    if model.startswith("gemini"):
        # setdefault (rather than an explicit keyword next to **kwargs)
        # avoids "got multiple values for keyword argument" when the caller
        # already supplied safety_settings. kwargs is a fresh dict per call,
        # so mutating it here does not leak to the caller.
        kwargs.setdefault("safety_settings", SAFETY_SETTINGS)

    # Exactly one request per call; the response is returned as-is
    # (presumably already shaped like an OpenAI response - confirm upstream).
    return completion(model=model, messages=messages, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="optillm",
version="0.0.6",
version="0.0.7",
packages=find_packages(),
py_modules=['optillm'],
package_data={
Expand Down

0 comments on commit a8d56ff

Please sign in to comment.