Skip to content

Commit

Permalink
Merge pull request #80 from OVINC-CN/feat_tools
Browse files Browse the repository at this point in the history
feat(tools): remove tools usage
  • Loading branch information
OrenZhang authored Oct 20, 2024
2 parents 438feed + 3a6af7a commit 35a780f
Show file tree
Hide file tree
Showing 16 changed files with 24 additions and 446 deletions.
6 changes: 0 additions & 6 deletions apps/chat/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
from apps.chat.client.doubao import DoubaoClient
from apps.chat.client.gemini import GeminiClient
from apps.chat.client.hunyuan import HunYuanClient, HunYuanVisionClient
from apps.chat.client.kimi import KimiClient
from apps.chat.client.midjourney import MidjourneyClient
from apps.chat.client.openai import OpenAIClient, OpenAIVisionClient
from apps.chat.client.qianfan import QianfanClient

__all__ = (
"GeminiClient",
"OpenAIClient",
"OpenAIVisionClient",
"HunYuanClient",
"HunYuanVisionClient",
"QianfanClient",
"KimiClient",
"DoubaoClient",
"MidjourneyClient",
)
40 changes: 2 additions & 38 deletions apps/chat/client/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import abc
import datetime
import json
from typing import List

from django.conf import settings
from django.contrib.auth import get_user_model
from django.shortcuts import get_object_or_404

from apps.chat.constants import OpenAIRole
from apps.chat.models import AIModel, ChatLog, Message, ToolParams
from apps.chat.tools import TOOLS
from apps.chat.models import AIModel, ChatLog, Message

USER_MODEL = get_user_model()

Expand All @@ -21,9 +18,7 @@ class BaseClient:
"""

# pylint: disable=R0913,R0917
def __init__(
self, user: str, model: str, messages: List[Message], temperature: float, top_p: float, tools: List[dict]
):
def __init__(self, user: str, model: str, messages: List[Message], temperature: float, top_p: float):
self.user: USER_MODEL = get_object_or_404(USER_MODEL, username=user)
self.model: str = model
self.model_inst: AIModel = AIModel.objects.get(model=model, is_enabled=True)
Expand All @@ -35,7 +30,6 @@ def __init__(
]
self.temperature: float = temperature
self.top_p: float = top_p
self.tools: List[dict] = (tools or None) if settings.CHATGPT_TOOLS_ENABLED else None
self.finished_at: int = int()
self.log = ChatLog.objects.create(
user=self.user,
Expand All @@ -58,33 +52,3 @@ def record(self, *args, **kwargs) -> None:
"""

raise NotImplementedError()


class OpenAIToolMixin:
"""
OpenAI Tool Mixin
"""

async def use_tool(self, tool_params: ToolParams, *args, **kwargs) -> any:
self.messages.append(
{
"role": OpenAIRole.ASSISTANT.value,
"tool_calls": [
{
"id": tool_params.id,
"type": tool_params.type,
"function": {
"arguments": tool_params.arguments,
"name": tool_params.name,
},
}
],
"content": "",
}
)
result = await TOOLS[tool_params.name](**json.loads(tool_params.arguments)).run()

self.messages.append({"role": OpenAIRole.TOOL, "content": result, "tool_call_id": tool_params.id})

async for i in self.chat(*args, **kwargs):
yield i
74 changes: 0 additions & 74 deletions apps/chat/client/doubao.py

This file was deleted.

6 changes: 2 additions & 4 deletions apps/chat/client/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@ class GeminiClient(BaseClient):
"""

# pylint: disable=R0913,R0917
def __init__(
self, user: str, model: str, messages: List[Message], temperature: float, top_p: float, tools: List[dict]
):
super().__init__(user=user, model=model, messages=messages, temperature=temperature, top_p=top_p, tools=tools)
def __init__(self, user: str, model: str, messages: List[Message], temperature: float, top_p: float):
super().__init__(user=user, model=model, messages=messages, temperature=temperature, top_p=top_p)
genai.configure(api_key=settings.GEMINI_API_KEY)
self.genai_model = genai.GenerativeModel(self.model)

Expand Down
93 changes: 0 additions & 93 deletions apps/chat/client/kimi.py

This file was deleted.

50 changes: 19 additions & 31 deletions apps/chat/client/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,15 @@
from channels.db import database_sync_to_async
from django.conf import settings
from django.utils import timezone
from django.utils.translation import gettext
from httpx import Client
from openai import OpenAI
from openai.types import ImagesResponse
from openai.types.chat import ChatCompletionChunk
from ovinc_client.core.logger import logger
from rest_framework import status

from apps.chat.client.base import BaseClient, OpenAIToolMixin
from apps.chat.client.base import BaseClient
from apps.chat.exceptions import GenerateFailed, LoadImageFailed
from apps.chat.models import ToolParams
from apps.chat.tools import TOOLS
from apps.cos.client import COSClient


Expand All @@ -39,7 +36,7 @@ def build_client(self, api_version: str) -> OpenAI:
)


class OpenAIClient(OpenAIMixin, OpenAIToolMixin, BaseClient):
class OpenAIClient(OpenAIMixin, BaseClient):
"""
OpenAI Client
"""
Expand All @@ -53,50 +50,41 @@ async def chat(self, *args, **kwargs) -> any:
temperature=self.temperature,
top_p=self.top_p,
stream=True,
tools=self.tools,
tool_choice="auto" if self.tools else None,
)
except Exception as err: # pylint: disable=W0718
logger.exception("[GenerateContentFailed] %s", err)
yield str(GenerateFailed())
response = []
content = ""
tool_params = ToolParams()
prompt_tokens = 0
completion_tokens = 0
# pylint: disable=E1133
for chunk in response:
self.record(response=chunk)
content += chunk.choices[0].delta.content or ""
yield chunk.choices[0].delta.content or ""
# check tool use
if chunk.choices[0].delta.tool_calls:
tool_params.arguments += chunk.choices[0].delta.tool_calls[0].function.arguments
tool_params.name = chunk.choices[0].delta.tool_calls[0].function.name or tool_params.name
tool_params.type = chunk.choices[0].delta.tool_calls[0].type or tool_params.type
tool_params.id = chunk.choices[0].delta.tool_calls[0].id or tool_params.id
# call tool
if tool_params.name:
_message = gettext("[The result is using tool %s]") % str(TOOLS[tool_params.name].name_alias)
yield _message
yield " \n \n"
async for i in self.use_tool(tool_params, *args, **kwargs):
yield i
if chunk.choices:
content += chunk.choices[0].delta.content or ""
yield chunk.choices[0].delta.content or ""
elif chunk.usage:
prompt_tokens = chunk.usage.prompt_tokens
completion_tokens = chunk.usage.completion_tokens
self.finished_at = int(timezone.now().timestamp() * 1000)
await self.post_chat(content, use_tool=bool(tool_params.name))
await self.post_chat(content, prompt_tokens, completion_tokens)

# pylint: disable=W0221,R1710
def record(self, response: ChatCompletionChunk, **kwargs) -> None:
self.log.chat_id = response.id

async def post_chat(self, content: str, use_tool: bool) -> None:
async def post_chat(self, content: str, prompt_tokens: int, completion_tokens: int) -> None:
if not self.log:
return
# calculate tokens
encoding = tiktoken.encoding_for_model(self.model)
self.log.prompt_tokens = len(encoding.encode("".join([message["content"] for message in self.messages])))
self.log.completion_tokens = len(encoding.encode(content))
if use_tool:
self.log.prompt_tokens *= 2
self.log.completion_tokens *= 2
if prompt_tokens and completion_tokens:
self.log.prompt_tokens = prompt_tokens
self.log.completion_tokens = completion_tokens
else:
encoding = tiktoken.encoding_for_model(self.model)
self.log.prompt_tokens = len(encoding.encode("".join([message["content"] for message in self.messages])))
self.log.completion_tokens = len(encoding.encode(content))
# calculate price
self.log.prompt_token_unit_price = self.model_inst.prompt_price
self.log.completion_token_unit_price = self.model_inst.completion_price
Expand Down
Loading

0 comments on commit 35a780f

Please sign in to comment.