Adding support for Anthropic, Cohere, TogetherAI, Aleph Alpha, Huggingface Inference Endpoints, etc. (#324)

* Adding support for Anthropic, Cohere, TogetherAI, Aleph Alpha, Huggingface Inference Endpoints, etc.

* updates

* updates

* update readme

* Revert binary file

---------

Co-authored-by: Graham Neubig <neubig@gmail.com>
krrishdholakia and neubig authored Sep 6, 2023
1 parent 469a116 commit e368960
Showing 4 changed files with 42 additions and 9 deletions.
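For context, the change replaces direct `openai.ChatCompletion` calls with LiteLLM's provider-agnostic `completion`/`acompletion` interface, so switching providers only requires a different model string and API key. A minimal sketch of that interface, assuming `"claude-2"` is the model string LiteLLM expects for Anthropic and that the corresponding key is already exported:

```python
# Minimal sketch of the LiteLLM interface this commit adopts (not part of the diff).
# Assumes the relevant provider key (e.g. OPENAI_API_KEY or ANTHROPIC_API_KEY)
# is already set in the environment.
from litellm import completion

messages = [{"role": "user", "content": "Summarize LiteLLM in one sentence."}]

# The same call shape works for OpenAI models...
openai_response = completion(model="gpt-3.5-turbo", messages=messages)

# ...and, assuming "claude-2" is the expected model string, for Anthropic.
anthropic_response = completion(model="claude-2", messages=messages)

# LiteLLM responses mirror the OpenAI response layout.
print(openai_response["choices"][0]["message"]["content"])
```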
7 changes: 5 additions & 2 deletions README.md
@@ -30,8 +30,9 @@ You can also run through the command line.
pip install prompt2model
```

Our current `Prompt2Model` implementation uses
the OpenAI API. Accordingly, you need to:
`Prompt2Model` supports various platforms, such as OpenAI, Anthropic, and Huggingface, using [LiteLLM](https://github.com/BerriAI/litellm).

To use OpenAI, please follow these steps:

- Sign up on the OpenAI website and obtain an
OpenAI API key.
@@ -45,6 +46,8 @@ the following command in your terminal:
export OPENAI_API_KEY=<your key>
```

[List of all supported providers](https://docs.litellm.ai/docs/providers)

You can then run

```bash
33 changes: 31 additions & 2 deletions prompt2model/dataset_generator/openai_gpt.py
@@ -98,11 +98,11 @@ def __init__(
learning during generation. This allows us to achieve high-quality,
high-diversity examples later on by using a higher temperature.
"""
self.api_key: str | None = api_key if api_key else os.environ["OPENAI_API_KEY"]
self.api_key: str | None = api_key if api_key else self.validate_environment()
if self.api_key is None or self.api_key == "":
raise ValueError(
"API key must be provided or set the environment variable "
"with `export OPENAI_API_KEY=<your key>`."
"e.g. `export OPENAI_API_KEY=<your key>`."
)
if max_api_calls and max_api_calls <= 0:
raise ValueError("max_api_calls must be > 0")
@@ -130,6 +130,35 @@ def __init__(
self.filter_duplicated_examples = filter_duplicated_examples
self.cache_root = Path(cache_root)

def validate_environment(self):
"""Check if any of the required API keys are present in the environment.
Returns:
str or None: The API key value if found in the environment, else None.
"""
api_key = None
if "OPENAI_API_KEY" in os.environ:
api_key = os.getenv("OPENAI_API_KEY")
elif "ANTHROPIC_API_KEY" in os.environ:
api_key = os.getenv("ANTHROPIC_API_KEY")
elif "REPLICATE_API_KEY" in os.environ:
api_key = os.getenv("REPLICATE_API_KEY")
elif "AZURE_API_KEY" in os.environ:
api_key = os.getenv("AZURE_API_KEY")
elif "COHERE_API_KEY" in os.getenv("COHERE_API_KEY"):
api_key = os.getenv("COHERE_API_KEY")
elif "TOGETHERAI_API_KEY" in os.environ:
api_key = os.getenv("TOGETHERAI_API_KEY")
elif "BASETEN_API_KEY" in os.environ:
api_key = os.getenv("BASETEN_API_KEY")
elif "AI21_API_KEY" in os.environ:
api_key = os.getenv("AI21_API_KEY")
elif "OPENROUTER_API_KEY" in os.environ:
api_key = os.getenv("OPENROUTER_API_KEY")
elif "ALEPHALPHA_API_KEY" in os.environ:
api_key = os.getenv("ALEPHALPHA_API_KEY")
return api_key

def construct_prompt(
self,
instruction: str,
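The new `validate_environment` helper returns the first provider key it finds, so `OPENAI_API_KEY` takes precedence over the other variables when several are set. A compact, hypothetical equivalent (not part of the commit) that expresses the same precedence as a loop:

```python
from __future__ import annotations

import os

# Checked in the same order as the elif chain above; the first match wins.
_PROVIDER_KEY_NAMES = [
    "OPENAI_API_KEY",
    "ANTHROPIC_API_KEY",
    "REPLICATE_API_KEY",
    "AZURE_API_KEY",
    "COHERE_API_KEY",
    "TOGETHERAI_API_KEY",
    "BASETEN_API_KEY",
    "AI21_API_KEY",
    "OPENROUTER_API_KEY",
    "ALEPHALPHA_API_KEY",
]


def first_available_api_key() -> str | None:
    """Return the first provider API key found in the environment, or None."""
    for name in _PROVIDER_KEY_NAMES:
        if name in os.environ:
            return os.environ[name]
    return None
```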
10 changes: 5 additions & 5 deletions prompt2model/utils/openai_tools.py
@@ -5,14 +5,14 @@
import asyncio
import json
import logging
import os
import time

import aiolimiter
import openai
import openai.error
import tiktoken
from aiohttp import ClientSession
from litellm import acompletion, completion
from tqdm.asyncio import tqdm_asyncio

OPENAI_ERRORS = (
@@ -46,8 +46,8 @@ def __init__(self, api_key: str | None, model_name: str = "gpt-3.5-turbo"):
the environment variable with `export OPENAI_API_KEY=<your key>`.
model_name: Name of the OpenAI model to use (by default, gpt-3.5-turbo).
"""
openai.api_key = api_key if api_key else os.environ["OPENAI_API_KEY"]
if openai.api_key is None or openai.api_key == "":
self.api_key = api_key
if self.api_key is None or self.api_key == "":
raise ValueError(
"API key must be provided or set the environment variable "
"with `export OPENAI_API_KEY=<your key>`."
@@ -78,7 +78,7 @@ def generate_one_openai_chat_completion(
Returns:
A response object.
"""
response = openai.ChatCompletion.create(
response = completion( # completion gets the key from os.getenv
model=self.model_name,
messages=[
{"role": "user", "content": f"{prompt}"},
@@ -124,7 +124,7 @@ async def _throttled_openai_chat_completion_acreate(
async with limiter:
for _ in range(3):
try:
return await openai.ChatCompletion.acreate(
return await acompletion(
model=model,
messages=messages,
temperature=temperature,
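Since `acompletion` is awaited with the same arguments the synchronous `completion` call takes, the throttled helper above can be exercised end to end; a minimal sketch, assuming a provider key is exported and that the response follows the OpenAI-style `choices` layout LiteLLM mirrors:

```python
import asyncio

from litellm import acompletion


async def main() -> None:
    # Mirrors the call made inside _throttled_openai_chat_completion_acreate.
    response = await acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello in five words."}],
        temperature=0.7,
    )
    print(response["choices"][0]["message"]["content"])


if __name__ == "__main__":
    asyncio.run(main())
```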
1 change: 1 addition & 0 deletions pyproject.toml
@@ -45,6 +45,7 @@ dependencies = [
"psutil",
"protobuf==3.20.0",
"nest-asyncio",
"litellm"
]

dynamic = ["version"]
