Skip to content

Commit

Permalink
optimize composio and e2b extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
conneroisu committed Nov 11, 2024
1 parent ba4965c commit a7ca26c
Show file tree
Hide file tree
Showing 9 changed files with 310 additions and 304 deletions.
50 changes: 50 additions & 0 deletions agents.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package groq

import (
"context"
"fmt"
"log/slog"
)

type (
// Agent is an agent.
Agent struct {
client *Client
logger *slog.Logger
runners []ToolRunner
}
// ToolRunner is an interface for a tool manager.
ToolRunner interface {
Run(
ctx context.Context,
response ChatCompletionResponse,
) ([]ChatCompletionMessage, error)
}
)

// Run runs the agent on a chat completion response.
func (a *Agent) Run(
ctx context.Context,
response ChatCompletionResponse,
) ([]ChatCompletionMessage, error) {
for _, runner := range a.runners {
messages, err := runner.Run(ctx, response)
if err == nil {
return messages, nil
}
}
return nil, fmt.Errorf("no runners found for response")
}

// NewAgent creates a new agent.
func NewAgent(
client *Client,
logger *slog.Logger,
runners ...ToolRunner,
) *Agent {
return &Agent{
client: client,
logger: logger,
runners: runners,
}
}
47 changes: 22 additions & 25 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,31 +42,6 @@ const (
ChatMessagePartTypeImageURL ChatMessagePartType = "image_url"
)

type (
// FinishReason is the finish reason.
//
// string
FinishReason string
)

const (
// ReasonStop is the stop finish reason for a chat completion.
ReasonStop FinishReason = "stop"
// ReasonLength is the length finish reason for a chat completion.
ReasonLength FinishReason = "length"
// ReasonFunctionCall is the function call finish reason for a chat
// completion.
ReasonFunctionCall FinishReason = "function_call"
// ReasonToolCalls is the tool calls finish reason for a chat
// completion.
ReasonToolCalls FinishReason = "tool_calls"
// ReasonContentFilter is the content filter finish reason for a chat
// completion.
ReasonContentFilter FinishReason = "content_filter"
// ReasonNull is the null finish reason for a chat completion.
ReasonNull FinishReason = "null"
)

type (
// ImageURLDetail is the detail of the image at the URL.
//
Expand Down Expand Up @@ -348,6 +323,28 @@ type (
ChatCompletionStream struct {
*streams.StreamReader[*ChatCompletionStreamResponse]
}
// FinishReason is the finish reason.
//
// string
FinishReason string
)

const (
// ReasonStop is the stop finish reason for a chat completion.
ReasonStop FinishReason = "stop"
// ReasonLength is the length finish reason for a chat completion.
ReasonLength FinishReason = "length"
// ReasonFunctionCall is the function call finish reason for a chat
// completion.
ReasonFunctionCall FinishReason = "function_call"
// ReasonToolCalls is the tool calls finish reason for a chat
// completion.
ReasonToolCalls FinishReason = "tool_calls"
// ReasonContentFilter is the content filter finish reason for a chat
// completion.
ReasonContentFilter FinishReason = "content_filter"
// ReasonNull is the null finish reason for a chat completion.
ReasonNull FinishReason = "null"
)

// MarshalJSON method implements the json.Marshaler interface.
Expand Down
63 changes: 63 additions & 0 deletions extensions/composio/composio.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package composio

import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"

"github.com/conneroisu/groq-go"
"github.com/conneroisu/groq-go/pkg/builders"
)

Expand Down Expand Up @@ -83,3 +85,64 @@ func (c *Composio) doRequest(req *http.Request, v interface{}) error {
return nil
}
}

type (
request struct {
ConnectedAccountID string `json:"connectedAccountId"`
EntityID string `json:"entityId"`
AppName string `json:"appName"`
Input map[string]any `json:"input"`
Text string `json:"text,omitempty"`
AuthConfig map[string]any `json:"authConfig,omitempty"`
}
)

// Run runs the composio client on a chat completion response.
func (c *Composio) Run(
ctx context.Context,
user ConnectedAccount,
response groq.ChatCompletionResponse,
) ([]groq.ChatCompletionMessage, error) {
var respH []groq.ChatCompletionMessage
if response.Choices[0].FinishReason != groq.ReasonFunctionCall &&
response.Choices[0].FinishReason != "tool_calls" {
return nil, fmt.Errorf("not a function call")
}
for _, toolCall := range response.Choices[0].Message.ToolCalls {
var args map[string]any
if json.Valid([]byte(toolCall.Function.Arguments)) {
err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args)
if err != nil {
return nil, err
}
c.logger.Debug("arguments", "args", args)
}
req, err := builders.NewRequest(
ctx,
c.header,
http.MethodPost,
fmt.Sprintf("%s/v2/actions/%s/execute", c.baseURL, toolCall.Function.Name),
builders.WithBody(&request{
ConnectedAccountID: user.ID,
EntityID: "default",
AppName: toolCall.Function.Name,
Input: args,
AuthConfig: map[string]any{},
}),
)
if err != nil {
return nil, err
}
var body string
err = c.doRequest(req, &body)
if err != nil {
return nil, err
}
respH = append(respH, groq.ChatCompletionMessage{
Content: string(body),
Name: toolCall.ID,
Role: groq.ChatMessageRoleFunction,
})
}
return respH, nil
}
116 changes: 115 additions & 1 deletion extensions/composio/composio_test.go
Original file line number Diff line number Diff line change
@@ -1 +1,115 @@
package composio
package composio_test

import (
"context"
"encoding/json"
"net/http"
"os"
"testing"

"github.com/conneroisu/groq-go"
"github.com/conneroisu/groq-go/extensions/composio"
"github.com/conneroisu/groq-go/pkg/models"
"github.com/conneroisu/groq-go/pkg/test"
"github.com/conneroisu/groq-go/pkg/tools"
"github.com/stretchr/testify/assert"
)

func TestRun(t *testing.T) {
a := assert.New(t)
ctx := context.Background()
ts := test.NewTestServer()
ts.RegisterHandler("/v1/connectedAccounts", func(w http.ResponseWriter, _ *http.Request) {
var items struct {
Items []composio.ConnectedAccount `json:"items"`
}
items.Items = append(items.Items, composio.ConnectedAccount{
IntegrationID: "INTEGRATION_ID",
ID: "ID",
MemberID: "MEMBER_ID",
ClientUniqueUserID: "CLIENT_UNIQUE_USER_ID",
Status: "STATUS",
AppUniqueID: "APP_UNIQUE_ID",
AppName: "APP_NAME",
InvocationCount: "INVOCATION_COUNT",
})
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
jsonBytes, err := json.Marshal(items)
a.NoError(err)
_, err = w.Write(jsonBytes)
a.NoError(err)
})
ts.RegisterHandler("/v2/actions/TOOL/execute", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")

Check warning on line 44 in extensions/composio/composio_test.go

View check run for this annotation

Codeac.io / Codeac Code Quality

CodeDuplication

This block of 26 lines is too similar to extensions/composio/auth_test.go:14
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(`response1`))
a.NoError(err)
})
testS := ts.ComposioTestServer()
testS.Start()
client, err := composio.NewComposer(
test.GetTestToken(),
composio.WithLogger(test.DefaultLogger),
composio.WithBaseURL(testS.URL),
)
a.NoError(err)
ca, err := client.GetConnectedAccounts(ctx, composio.WithShowActiveOnly(true))
a.NoError(err)

Check warning on line 58 in extensions/composio/composio_test.go

View check run for this annotation

Codeac.io / Codeac Code Quality

CodeDuplication

This block of 12 lines is too similar to extensions/composio/tools_test.go:32
resp, err := client.Run(ctx, ca[0], groq.ChatCompletionResponse{
Choices: []groq.ChatCompletionChoice{{
Message: groq.ChatCompletionMessage{
Role: groq.ChatMessageRoleUser,
Content: "Hello!",
ToolCalls: []tools.ToolCall{{
Function: tools.FunctionCall{
Name: "TOOL",
Arguments: `{ "foo": "bar", }`,
}}}},
FinishReason: groq.ReasonFunctionCall,
}}})
a.NoError(err)
assert.Equal(t, "response1", resp[0].Content)
}

func TestUnitRun(t *testing.T) {
if !test.IsIntegrationTest() {
t.Skip()
}
a := assert.New(t)
ctx := context.Background()
key, err := test.GetAPIKey("COMPOSIO_API_KEY")
a.NoError(err)
client, err := composio.NewComposer(
key,
composio.WithLogger(test.DefaultLogger),
)
a.NoError(err)
ts, err := client.GetTools(
ctx, composio.WithApp("GITHUB"), composio.WithUseCase("StarRepo"))
a.NoError(err)

Check warning on line 90 in extensions/composio/composio_test.go

View check run for this annotation

Codeac.io / Codeac Code Quality

CodeDuplication

This block of 15 lines is too similar to extensions/composio/tools_test.go:50
a.NotEmpty(ts)
groqClient, err := groq.NewClient(
os.Getenv("GROQ_KEY"),
)
a.NoError(err, "NewClient error")
response, err := groqClient.CreateChatCompletion(ctx, groq.ChatCompletionRequest{
Model: models.ModelLlama3Groq8B8192ToolUsePreview,
Messages: []groq.ChatCompletionMessage{
{
Role: groq.ChatMessageRoleUser,
Content: "Star the facebookresearch/spiritlm repository on GitHub",
},
},
MaxTokens: 2000,
Tools: ts,
})
a.NoError(err)
a.NotEmpty(response.Choices[0].Message.ToolCalls)
users, err := client.GetConnectedAccounts(ctx)
a.NoError(err)
resp2, err := client.Run(ctx, users[0], response)
a.NoError(err)
a.NotEmpty(resp2)
t.Logf("%+v\n", resp2)
}
79 changes: 0 additions & 79 deletions extensions/composio/run.go

This file was deleted.

Loading

0 comments on commit a7ca26c

Please sign in to comment.