diff --git a/.github/workflows/chron-models.yaml b/.github/workflows/chron-models.yaml index f890077..578ed61 100644 --- a/.github/workflows/chron-models.yaml +++ b/.github/workflows/chron-models.yaml @@ -1,4 +1,4 @@ -name: Commit Go Mod, Go Work, and Docs +name: Commit Go Generated Content on: workflow_dispatch: @@ -18,7 +18,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v4 with: - go-version: '1.23.1' + go-version: '1.23.2' # Step 3: Run go mod download - name: Run go mod download diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index a3bf780..b5a75a8 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -8,17 +8,13 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: '1.23.1' + go-version: '1.23.2' - name: Check out code uses: actions/checkout@v3 - name: Install dependencies run: | go mod download - name: Run Integration tests - env: - GROQ_KEY: ${{ secrets.GROQ_KEY }} - TOOLHOUSE_API_KEY: ${{ secrets.TOOLHOUSE_API_KEY }} - E2B_API_KEY: ${{ secrets.E2B_API_KEY }} run: | go test -race -tags=integration ./... - name: Run Unit tests @@ -35,8 +31,3 @@ jobs: env: COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: goveralls -coverprofile=covprofile -service=github - # or use shogo82148/actions-goveralls - # - name: Send coverage - # uses: shogo82148/actions-goveralls@v1 - # with: - # path-to-profile: covprofile diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 44709e3..a5e38ff 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -25,7 +25,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: - go-version: '1.23.1' + go-version: '1.23.2' cache: true - name: Install requirements id: install-lint-requirements diff --git a/.github/workflows/unit.yaml b/.github/workflows/unit.yaml new file mode 100644 index 0000000..c8a947c --- /dev/null +++ b/.github/workflows/unit.yaml @@ -0,0 +1,32 @@ +name: Unit Tests +on: + workflow_dispatch: {} +jobs: + test: + name: Test with Coverage + runs-on: ubuntu-latest + steps: + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: '1.23.2' + - name: Check out code + uses: actions/checkout@v3 + - name: Install dependencies + run: | + go mod download + - name: Run Integration tests + env: + GROQ_KEY: ${{ secrets.GROQ_KEY }} + TOOLHOUSE_API_KEY: ${{ secrets.TOOLHOUSE_API_KEY }} + E2B_API_KEY: ${{ secrets.E2B_API_KEY }} + run: | + go test -race -tags=integration ./... + - name: Run Unit tests + env: + GROQ_KEY: ${{ secrets.GROQ_KEY }} + TOOLHOUSE_API_KEY: ${{ secrets.TOOLHOUSE_API_KEY }} + E2B_API_KEY: ${{ secrets.E2B_API_KEY }} + UNIT: true + run: | + go test -race -covermode atomic -coverprofile=covprofile ./... diff --git a/README.md b/README.md index 572b111..9429ab8 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,13 @@ [![Coverage Status](https://coveralls.io/repos/github/conneroisu/groq-go/badge.svg?branch=main)](https://coveralls.io/github/conneroisu/groq-go?branch=main) [![PhormAI](https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml)](https://www.phorm.ai/query?projectId=0634251d-5a98-4c37-ac2f-385b588ce3d3) + + Powered by Groq for fast inference. + + ## Features - Supports all models from [Groq](https://wow.groq.com/) in a type-safe way. diff --git a/audio_test.go b/audio_test.go index 22c7510..60094c6 100644 --- a/audio_test.go +++ b/audio_test.go @@ -1,6 +1,3 @@ -//go:build !test -// +build !test - package groq import ( diff --git a/chat.go b/chat.go index 3e3b1b0..1a8ed46 100644 --- a/chat.go +++ b/chat.go @@ -13,6 +13,7 @@ import ( "time" "github.com/conneroisu/groq-go/pkg/builders" + "github.com/conneroisu/groq-go/pkg/tools" ) const ( @@ -29,7 +30,6 @@ const ( ChatCompletionResponseFormatTypeJSONObject ChatCompletionResponseFormatType = "json_object" // ChatCompletionResponseFormatTypeJSONObject is the json object chat completion response format type. ChatCompletionResponseFormatTypeJSONSchema ChatCompletionResponseFormatType = "json_schema" // ChatCompletionResponseFormatTypeJSONSchema is the json schema chat completion response format type. ChatCompletionResponseFormatTypeText ChatCompletionResponseFormatType = "text" // ChatCompletionResponseFormatTypeText is the text chat completion response format type. - ToolTypeFunction ToolType = "function" // ToolTypeFunction is the function tool type. FinishReasonStop FinishReason = "stop" // FinishReasonStop is the stop finish reason. FinishReasonLength FinishReason = "length" // FinishReasonLength is the length finish reason. FinishReasonFunctionCall FinishReason = "function_call" // FinishReasonFunctionCall is the function call finish reason. @@ -69,26 +69,13 @@ type ( } // ChatCompletionMessage represents the chat completion message. ChatCompletionMessage struct { - Name string `json:"name"` // Name is the name of the chat completion message. - Role Role `json:"role"` // Role is the role of the chat completion message. - Content string `json:"content"` // Content is the content of the chat completion message. - MultiContent []ChatMessagePart `json:"-"` // MultiContent is the multi content of the chat completion message. - FunctionCall *FunctionCall `json:"function_call,omitempty"` // FunctionCall setting for Role=assistant prompts this may be set to the function call generated by the model. - ToolCalls []ToolCall `json:"tool_calls,omitempty"` // ToolCalls setting for Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. - ToolCallID string `json:"tool_call_id,omitempty"` // ToolCallID is setting for Role=tool prompts this should be set to the ID given in the assistant's prior request to call a tool. - } - // ToolCall represents a tool call. - ToolCall struct { - // Index is not nil only in chat completion chunk object - Index *int `json:"index,omitempty"` // Index is the index of the tool call. - ID string `json:"id"` // ID is the id of the tool call. - Type ToolType `json:"type"` // Type is the type of the tool call. - Function FunctionCall `json:"function"` // Function is the function of the tool call. - } - // FunctionCall represents a function call. - FunctionCall struct { - Name string `json:"name,omitempty"` // Name is the name of the function call. - Arguments string `json:"arguments,omitempty"` // Arguments is the arguments of the function call in JSON format. + Name string `json:"name"` // Name is the name of the chat completion message. + Role Role `json:"role"` // Role is the role of the chat completion message. + Content string `json:"content"` // Content is the content of the chat completion message. + MultiContent []ChatMessagePart `json:"-"` // MultiContent is the multi content of the chat completion message. + FunctionCall *tools.FunctionCall `json:"function_call,omitempty"` // FunctionCall setting for Role=assistant prompts this may be set to the function call generated by the model. + ToolCalls []tools.ToolCall `json:"tool_calls,omitempty"` // ToolCalls setting for Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. + ToolCallID string `json:"tool_call_id,omitempty"` // ToolCallID is setting for Role=tool prompts this should be set to the ID given in the assistant's prior request to call a tool. } // ChatCompletionResponseFormatType is the chat completion response format type. // @@ -138,29 +125,20 @@ type ( LogProbs bool `json:"logprobs,omitempty"` // LogProbs indicates whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message. This option is currently not available on the gpt-4-vision-preview model. TopLogProbs int `json:"top_logprobs,omitempty"` // TopLogProbs is an integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. User string `json:"user,omitempty"` // User is the user of the chat completion request. - Tools []Tool `json:"tools,omitempty"` // Tools is the tools of the chat completion request. + Tools []tools.Tool `json:"tools,omitempty"` // Tools is the tools of the chat completion request. ToolChoice any `json:"tool_choice,omitempty"` // This can be either a string or an ToolChoice object. StreamOptions *StreamOptions `json:"stream_options,omitempty"` // Options for streaming response. Only set this when you set stream: true. ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` // Disable the default behavior of parallel tool calls by setting it: false. RetryDelay time.Duration `json:"-"` // RetryDelay is the delay between retries. } - // ToolType is the tool type. - // - // string - ToolType string - // Tool represents the tool. - Tool struct { - Type ToolType `json:"type"` // Type is the type of the tool. - Function FunctionDefinition `json:"function,omitempty"` // Function is the tool's functional definition. - } - // ToolChoice represents the tool choice. - ToolChoice struct { - Type ToolType `json:"type"` // Type is the type of the tool choice. - Function ToolFunction `json:"function,omitempty"` // Function is the function of the tool choice. - } - // ToolFunction represents the tool function. - ToolFunction struct { - Name string `json:"name"` // Name is the name of the tool function. + // LogProbs is the top-level structure containing the log probability information. + LogProbs struct { + Content []struct { + Token string `json:"token"` // Token is the token of the log prob. + LogProb float64 `json:"logprob"` // LogProb is the log prob of the log prob. + Bytes []byte `json:"bytes,omitempty"` // Omitting the field if it is null + TopLogProbs []TopLogProbs `json:"top_logprobs"` // TopLogProbs is a list of the most likely tokens and their log probability, at this token position. In rare cases, there may be fewer than the number of requested top_logprobs returned. + } `json:"content"` // Content is a list of message content tokens with log probability information. } // TopLogProbs represents the top log probs. TopLogProbs struct { @@ -168,17 +146,6 @@ type ( LogProb float64 `json:"logprob"` // LogProb is the log prob of the top log probs. Bytes []byte `json:"bytes,omitempty"` // Bytes is the bytes of the top log probs. } - // LogProb represents the probability information for a token. - LogProb struct { - Token string `json:"token"` // Token is the token of the log prob. - LogProb float64 `json:"logprob"` // LogProb is the log prob of the log prob. - Bytes []byte `json:"bytes,omitempty"` // Omitting the field if it is null - TopLogProbs []TopLogProbs `json:"top_logprobs"` // TopLogProbs is a list of the most likely tokens and their log probability, at this token position. In rare cases, there may be fewer than the number of requested top_logprobs returned. - } - // LogProbs is the top-level structure containing the log probability information. - LogProbs struct { - Content []LogProb `json:"content"` // Content is a list of message content tokens with log probability information. - } // FinishReason is the finish reason. // string FinishReason string @@ -214,10 +181,10 @@ type ( } // ChatCompletionStreamChoiceDelta represents a response structure for chat completion API. ChatCompletionStreamChoiceDelta struct { - Content string `json:"content,omitempty"` - Role string `json:"role,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Content string `json:"content,omitempty"` + Role string `json:"role,omitempty"` + FunctionCall *tools.FunctionCall `json:"function_call,omitempty"` + ToolCalls []tools.ToolCall `json:"tool_calls,omitempty"` } // ChatCompletionStreamChoice represents a response structure for chat completion API. ChatCompletionStreamChoice struct { @@ -225,10 +192,6 @@ type ( Delta ChatCompletionStreamChoiceDelta `json:"delta"` FinishReason FinishReason `json:"finish_reason"` } - // PromptFilterResult represents a response structure for chat completion API. - PromptFilterResult struct { - Index int `json:"index"` - } streamer interface { ChatCompletionStreamResponse } @@ -242,14 +205,16 @@ type ( } // ChatCompletionStreamResponse represents a response structure for chat completion API. ChatCompletionStreamResponse struct { - ID string `json:"id"` // ID is the identifier for the chat completion stream response. - Object string `json:"object"` // Object is the object type of the chat completion stream response. - Created int64 `json:"created"` // Created is the creation time of the chat completion stream response. - Model ChatModel `json:"model"` // Model is the model used for the chat completion stream response. - Choices []ChatCompletionStreamChoice `json:"choices"` // Choices is the choices for the chat completion stream response. - SystemFingerprint string `json:"system_fingerprint"` // SystemFingerprint is the system fingerprint for the chat completion stream response. - PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` // PromptAnnotations is the prompt annotations for the chat completion stream response. - PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` // PromptFilterResults is the prompt filter results for the chat completion stream response. + ID string `json:"id"` // ID is the identifier for the chat completion stream response. + Object string `json:"object"` // Object is the object type of the chat completion stream response. + Created int64 `json:"created"` // Created is the creation time of the chat completion stream response. + Model ChatModel `json:"model"` // Model is the model used for the chat completion stream response. + Choices []ChatCompletionStreamChoice `json:"choices"` // Choices is the choices for the chat completion stream response. + SystemFingerprint string `json:"system_fingerprint"` // SystemFingerprint is the system fingerprint for the chat completion stream response. + PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` // PromptAnnotations is the prompt annotations for the chat completion stream response. + PromptFilterResults []struct { + Index int `json:"index"` + } `json:"prompt_filter_results,omitempty"` // PromptFilterResults is the prompt filter results for the chat completion stream response. // Usage is an optional field that will only be present when you set stream_options: {"include_usage": true} in your request. // // When present, it contains a null value except for the last chunk which contains the token usage statistics @@ -262,24 +227,6 @@ type ( ChatCompletionStream struct { *streamReader[ChatCompletionStreamResponse] } - // FunctionDefinition represents the function definition. - FunctionDefinition struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters ParameterDefinition `json:"parameters"` - } - // ParameterDefinition represents the parameter definition. - ParameterDefinition struct { - Type string `json:"type"` - Properties map[string]PropertyDefinition `json:"properties"` - Required []string `json:"required"` - AdditionalProperties bool `json:"additionalProperties,omitempty"` - } - // PropertyDefinition represents the property definition. - PropertyDefinition struct { - Type string `json:"type"` - Description string `json:"description"` - } ) // MarshalJSON method implements the json.Marshaler interface. @@ -289,24 +236,24 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { } if len(m.MultiContent) > 0 { msg := struct { - Name string `json:"name,omitempty"` - Role Role `json:"role"` - Content string `json:"-"` - MultiContent []ChatMessagePart `json:"content,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Name string `json:"name,omitempty"` + Role Role `json:"role"` + Content string `json:"-"` + MultiContent []ChatMessagePart `json:"content,omitempty"` + FunctionCall *tools.FunctionCall `json:"function_call,omitempty"` + ToolCalls []tools.ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }(m) return json.Marshal(msg) } msg := struct { - Name string `json:"name,omitempty"` - Role Role `json:"role"` - Content string `json:"content"` - MultiContent []ChatMessagePart `json:"-"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Name string `json:"name,omitempty"` + Role Role `json:"role"` + Content string `json:"content"` + MultiContent []ChatMessagePart `json:"-"` + FunctionCall *tools.FunctionCall `json:"function_call,omitempty"` + ToolCalls []tools.ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }(m) return json.Marshal(msg) } @@ -318,9 +265,9 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) (err error) { Role Role `json:"role"` Content string `json:"content"` MultiContent []ChatMessagePart - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + FunctionCall *tools.FunctionCall `json:"function_call,omitempty"` + ToolCalls []tools.ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }{} err = json.Unmarshal(bs, &msg) if err == nil { @@ -331,10 +278,10 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) (err error) { Name string `json:"name,omitempty"` Role Role `json:"role"` Content string - MultiContent []ChatMessagePart `json:"content"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + MultiContent []ChatMessagePart `json:"content"` + FunctionCall *tools.FunctionCall `json:"function_call,omitempty"` + ToolCalls []tools.ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }{} err = json.Unmarshal(bs, &multiMsg) if err != nil { diff --git a/examples/composio-github-star/README.md b/examples/composio-github-star/README.md new file mode 100644 index 0000000..898ae8d --- /dev/null +++ b/examples/composio-github-star/README.md @@ -0,0 +1,21 @@ +# composio-github-star + +Adapted from the [quickstart](https://docs.composio.dev/introduction/intro/quickstart) guide. + +Install the `composio` CLI and login to your account (also add github to your account if you haven't already) + +```bash +pip install -U composio_core composio_openai + +composio login + +#Connect your Github so agents can use it +composio add github +``` + +Congratulations! You’ve just: + + 🔐 Authenticated your GitHub account with Composio + 🛠 Fetched GitHub tools for the llm + ⭐ Instructed the AI to star the conneroisu/groq-go repository + ✅ Successfully executed the action on GitHub diff --git a/examples/composio-github-star/main.go b/examples/composio-github-star/main.go new file mode 100644 index 0000000..19a230a --- /dev/null +++ b/examples/composio-github-star/main.go @@ -0,0 +1,74 @@ +package main + +import ( + "context" + "fmt" + "log/slog" + "os" + + "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/extensions/composio" + "github.com/conneroisu/groq-go/pkg/test" +) + +func main() { + if err := run(context.Background()); err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +func run( + ctx context.Context, +) error { + key, err := test.GetAPIKey("GROQ_KEY") + if err != nil { + return err + } + client, err := groq.NewClient(key) + if err != nil { + return err + } + key, err = test.GetAPIKey("COMPOSIO_API_KEY") + if err != nil { + return err + } + comp, err := composio.NewComposer( + key, + composio.WithLogger(slog.Default()), + ) + if err != nil { + return err + } + tools, err := comp.GetTools( + ctx, + composio.WithApp("GITHUB"), + composio.WithUseCase("star-repo"), + ) + if err != nil { + return err + } + chat, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: groq.ModelLlama3Groq70B8192ToolUsePreview, + Messages: []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: ` +You are a github star bot. You will be given a repo name and you will star it. +Star the repo conneroisu/groq-go on GitHub. +`, + }, + }, + MaxTokens: 2000, + Tools: tools, + }) + if err != nil { + return err + } + resp, err := comp.Run(ctx, chat) + if err != nil { + return err + } + fmt.Println(resp) + return nil +} diff --git a/examples/e2b-go-project/README.md b/examples/e2b-go-project/README.md new file mode 100644 index 0000000..ecb18f5 --- /dev/null +++ b/examples/e2b-go-project/README.md @@ -0,0 +1,33 @@ +# e2b-go-project + +This is an example of using groq-go to create a simple golang project using the e2b and groq api powered by the groq-go library. + +## Usage + +Make sure you have a groq key set in the environment variable `GROQ_KEY`. +Also, make sure that you have a e2b api key set in the environment variable `E2B_API_KEY`. + +```bash +export GROQ_KEY=your-groq-key +export E2B_API_KEY=your-e2b-api-key +go run . +``` + +### System Prompt + +```txt +Given the tools given to you, create a golang project with the following files: + + +main.go +utils.go + + +The main function should call the `utils.run() error` function. + +The project should, when run, print the following: + + +Hello, World! + +``` diff --git a/examples/e2b-go-project/main.go b/examples/e2b-go-project/main.go new file mode 100644 index 0000000..9f31da2 --- /dev/null +++ b/examples/e2b-go-project/main.go @@ -0,0 +1,67 @@ +package main + +import ( + "context" + "fmt" + "os" + + "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/extensions/e2b" +) + +func main() { + if err := run(context.Background()); err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +func run( + ctx context.Context, +) error { + key := os.Getenv("GROQ_KEY") + e2bKey := os.Getenv("E2B_API_KEY") + client, err := groq.NewClient(key) + if err != nil { + return err + } + sb, err := e2b.NewSandbox( + ctx, + e2bKey, + ) + if err != nil { + return err + } + chat, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: groq.ModelLlama3Groq70B8192ToolUsePreview, + Messages: []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: ` + +Given the tools given to you, create a golang project with the following files: + + +main.go +utils.go + + +The main function should call the "utils.run() error" function. + +The project should, when run, print the following: + + +Hello, World! + +`, + }, + }, + MaxTokens: 2000, + Tools: sb.GetTools(), + }) + if err != nil { + return err + } + fmt.Println(chat.Choices[0].Message.Content) + return nil +} diff --git a/examples/llava-blind/main.go b/examples/llava-blind/main.go index 16d41b6..0fea70a 100644 --- a/examples/llava-blind/main.go +++ b/examples/llava-blind/main.go @@ -1,8 +1,6 @@ // Package main demonstrates an example application of groq-go. package main -// url: https://cdnimg.webstaurantstore.com/images/products/large/87539/251494.jpg - import ( "context" "fmt" diff --git a/extensions/composio/.go-version b/extensions/composio/.go-version new file mode 100644 index 0000000..49e0a31 --- /dev/null +++ b/extensions/composio/.go-version @@ -0,0 +1 @@ +1.23.1 diff --git a/extensions/composio/README.md b/extensions/composio/README.md new file mode 100644 index 0000000..96f2332 --- /dev/null +++ b/extensions/composio/README.md @@ -0,0 +1,3 @@ +# composio + +Compose AI is a powerful tool for creating complex and high-quality compositions of ai tools. This package provides a client for the composio api easily accessible through the groq-go library. diff --git a/extensions/composio/auth.go b/extensions/composio/auth.go new file mode 100644 index 0000000..4496185 --- /dev/null +++ b/extensions/composio/auth.go @@ -0,0 +1,98 @@ +package composio + +import ( + "context" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/conneroisu/groq-go/pkg/builders" +) + +type ( + // Auther is an interface for composio auth. + Auther interface { + GetConnectedAccounts(ctx context.Context, opts ...AuthOption) ([]ConnectedAccount, error) + } + // ConnectedAccount represents a composio connected account. + // + // Gotten from similar url to: https://backend.composio.dev/api/v1/connectedAccounts?user_uuid=default&showActiveOnly=true + ConnectedAccount struct { + IntegrationID string `json:"integrationId"` + ConnectionParams struct { + Scope string `json:"scope"` + Scopes []string `json:"scopes"` + BaseURL string `json:"base_url"` + ClientID string `json:"client_id"` + TokenType string `json:"token_type"` + RedirectURL string `json:"redirectUrl"` + AccessToken string `json:"access_token"` + CallbackURL string `json:"callback_url"` + ClientSecret string `json:"client_secret"` + CodeVerifier string `json:"code_verifier"` + FinalRedirectURI string `json:"finalRedirectUri"` + } `json:"connectionParams"` + IsDisabled bool `json:"isDisabled"` + ID string `json:"id"` + MemberID string `json:"memberId"` + ClientUniqueUserID string `json:"clientUniqueUserId"` + Status string `json:"status"` + Enabled bool `json:"enabled"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + Member struct { + ID string `json:"id"` + ClientID string `json:"clientId"` + Email string `json:"email"` + Name string `json:"name"` + Role string `json:"role"` + Metadata any `json:"metadata"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + DeletedAt any `json:"deletedAt"` + } `json:"member"` + AppUniqueID string `json:"appUniqueId"` + AppName string `json:"appName"` + Logo string `json:"logo"` + IntegrationIsDisabled bool `json:"integrationIsDisabled"` + IntegrationDisabledReason string `json:"integrationDisabledReason"` + InvocationCount string `json:"invocationCount"` + } +) + +// GetConnectedAccounts returns the connected accounts for the composio client. +func (c *Composio) GetConnectedAccounts( + ctx context.Context, + opts ...AuthOption, +) ([]ConnectedAccount, error) { + uri := fmt.Sprintf("%s/v1/connectedAccounts", c.baseURL) + u, err := url.Parse(uri) + if err != nil { + return nil, err + } + urlValues := u.Query() + urlValues.Add("user_uuid", "default") + urlValues.Add("showActiveOnly", "true") + for _, opt := range opts { + opt(&urlValues) + } + u.RawQuery = urlValues.Encode() + uri = u.String() + c.logger.Debug("auth", "url", uri) + req, err := builders.NewRequest( + ctx, + c.header, + http.MethodGet, + uri, + builders.WithBody(nil), + ) + if err != nil { + return nil, err + } + var caItems struct { + Items []ConnectedAccount `json:"items"` + } + err = c.doRequest(req, &caItems) + return caItems.Items, err +} diff --git a/extensions/composio/auth_test.go b/extensions/composio/auth_test.go new file mode 100644 index 0000000..348616f --- /dev/null +++ b/extensions/composio/auth_test.go @@ -0,0 +1,29 @@ +package composio_test + +import ( + "context" + "testing" + + "github.com/conneroisu/groq-go/extensions/composio" + "github.com/conneroisu/groq-go/pkg/test" + "github.com/stretchr/testify/assert" +) + +// TestUnitGetConnectedAccounts is an Unit test using a real composio server and api key. +func TestUnitGetConnectedAccounts(t *testing.T) { + if !test.IsUnitTest() { + t.Skip() + } + a := assert.New(t) + ctx := context.Background() + key, err := test.GetAPIKey("COMPOSIO_API_KEY") + a.NoError(err) + client, err := composio.NewComposer( + key, + composio.WithLogger(test.DefaultLogger), + ) + a.NoError(err) + ts, err := client.GetConnectedAccounts(ctx) + a.NoError(err) + a.NotEmpty(ts) +} diff --git a/extensions/composio/composio.go b/extensions/composio/composio.go new file mode 100644 index 0000000..8170e4e --- /dev/null +++ b/extensions/composio/composio.go @@ -0,0 +1,91 @@ +package composio + +import ( + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + + "github.com/conneroisu/groq-go/pkg/builders" +) + +const ( + composioBaseURL = "https://backend.composio.dev/api" +) + +type ( + // Composer is an interface for composio client. + Composer interface { + Tooler + Runner + Auther + } + // Composio is a composio client. + Composio struct { + apiKey string + client *http.Client + logger *slog.Logger + header builders.Header + baseURL string + } + // Integration represents a composio integration. + Integration struct { + Name string `json:"name"` + ID int `json:"id"` + } +) + +// NewComposer creates a new composio client. +func NewComposer(apiKey string, opts ...Option) (Composer, error) { + c := &Composio{ + apiKey: apiKey, + header: builders.Header{SetCommonHeaders: func(r *http.Request) { + r.Header.Set("X-API-Key", apiKey) + }}, + baseURL: composioBaseURL, + client: http.DefaultClient, + logger: slog.Default(), + } + for _, opt := range opts { + opt(c) + } + return c, nil +} + +func (c *Composio) doRequest(req *http.Request, v interface{}) error { + req.Header.Set("Accept", "application/json") + contentType := req.Header.Get("Content-Type") + if contentType == "" { + req.Header.Set("Content-Type", "application/json") + } + res, err := c.client.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer res.Body.Close() + if res.StatusCode < http.StatusOK || + res.StatusCode >= http.StatusBadRequest { + bodyText, _ := io.ReadAll(res.Body) + return fmt.Errorf("request failed: %s\nbody: %s", res.Status, bodyText) + } + if v == nil { + return nil + } + switch o := v.(type) { + case *string: + b, err := io.ReadAll(res.Body) + if err != nil { + return err + } + *o = string(b) + return nil + default: + err = json.NewDecoder(res.Body).Decode(v) + if err != nil { + bodyText, _ := io.ReadAll(res.Body) + return fmt.Errorf("failed to decode response: %w\nbody: %s", err, bodyText) + } + return nil + } +} diff --git a/extensions/composio/composio_test.go b/extensions/composio/composio_test.go new file mode 100644 index 0000000..fed4439 --- /dev/null +++ b/extensions/composio/composio_test.go @@ -0,0 +1 @@ +package composio diff --git a/extensions/composio/doc.go b/extensions/composio/doc.go new file mode 100644 index 0000000..6ecc858 --- /dev/null +++ b/extensions/composio/doc.go @@ -0,0 +1,2 @@ +// Package composio provides a composio client for groq-go. +package composio diff --git a/extensions/composio/execute.go b/extensions/composio/execute.go new file mode 100644 index 0000000..3473e53 --- /dev/null +++ b/extensions/composio/execute.go @@ -0,0 +1,82 @@ +package composio + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/builders" +) + +type ( + // Runner is an interface for composio run. + Runner interface { + Run(ctx context.Context, response groq.ChatCompletionResponse) ( + []groq.ChatCompletionMessage, error) + } + request struct { + ConnectedAccountID string `json:"connectedAccountId"` + EntityID string `json:"entityId"` + AppName string `json:"appName"` + Input map[string]any `json:"input"` + Text string `json:"text,omitempty"` + AuthConfig map[string]any `json:"authConfig,omitempty"` + } +) + +// Run runs the composio client on a chat completion response. +func (c *Composio) Run( + ctx context.Context, + response groq.ChatCompletionResponse, +) ([]groq.ChatCompletionMessage, error) { + var respH []groq.ChatCompletionMessage + if response.Choices[0].FinishReason != groq.FinishReasonFunctionCall && + response.Choices[0].FinishReason != "tool_calls" { + return nil, fmt.Errorf("Not a function call") + } + connectedAccount, err := c.GetConnectedAccounts(ctx, WithShowActiveOnly(true)) + if err != nil { + return nil, err + } + c.logger.Debug("connected accounts", "accounts", connectedAccount) + for _, toolCall := range response.Choices[0].Message.ToolCalls { + var args map[string]any + if json.Valid([]byte(toolCall.Function.Arguments)) { + err = json.Unmarshal([]byte(toolCall.Function.Arguments), &args) + if err != nil { + return nil, err + } + c.logger.Debug("arguments", "args", args) + } + req, err := builders.NewRequest( + ctx, + c.header, + http.MethodPost, + fmt.Sprintf("%s/v2/actions/%s/execute", c.baseURL, toolCall.Function.Name), + builders.WithBody(&request{ + ConnectedAccountID: connectedAccount[0].ID, + EntityID: "default", + AppName: toolCall.Function.Name, + Input: args, + Text: "", + AuthConfig: map[string]any{}, + }), + ) + if err != nil { + return nil, err + } + var body string + err = c.doRequest(req, &body) + if err != nil { + return nil, fmt.Errorf("failed to do request: %w", err) + } + respH = append(respH, groq.ChatCompletionMessage{ + Content: string(body), + Name: toolCall.ID, + Role: groq.ChatMessageRoleFunction, + }) + } + return respH, nil +} diff --git a/extensions/composio/execute_test.go b/extensions/composio/execute_test.go new file mode 100644 index 0000000..98b7309 --- /dev/null +++ b/extensions/composio/execute_test.go @@ -0,0 +1,51 @@ +package composio + +import ( + "context" + "os" + "testing" + + "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/test" + "github.com/stretchr/testify/assert" +) + +func TestRun(t *testing.T) { + if !test.IsUnitTest() { + t.Skip() + } + a := assert.New(t) + ctx := context.Background() + key, err := test.GetAPIKey("COMPOSIO_API_KEY") + a.NoError(err) + client, err := NewComposer( + key, + WithLogger(test.DefaultLogger), + ) + a.NoError(err) + ts, err := client.GetTools( + ctx, WithApp("GITHUB"), WithUseCase("StarRepo")) + a.NoError(err) + a.NotEmpty(ts) + groqClient, err := groq.NewClient( + os.Getenv("GROQ_KEY"), + ) + a.NoError(err, "NewClient error") + response, err := groqClient.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: groq.ModelLlama3Groq8B8192ToolUsePreview, + Messages: []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: "Star the facebookresearch/spiritlm repository on GitHub", + }, + }, + MaxTokens: 2000, + Tools: ts, + }) + a.NoError(err) + a.NotEmpty(response.Choices[0].Message.ToolCalls) + resp2, err := client.Run(ctx, response) + a.NoError(err) + a.NotEmpty(resp2) + t.Logf("%+v\n", resp2) +} diff --git a/extensions/composio/options.go b/extensions/composio/options.go new file mode 100644 index 0000000..554c948 --- /dev/null +++ b/extensions/composio/options.go @@ -0,0 +1,71 @@ +package composio + +import ( + "fmt" + "log/slog" + "net/url" + "strings" +) + +type ( + // Option is an option for the composio client. + // + // WithLogger sets the logger for the composio client. + Option func(*Composio) + + // ToolsOption is an option for the tools request. + ToolsOption func(*url.Values) + + // AuthOption is an option for the auth request. + AuthOption func(*url.Values) +) + +// Composer Options + +// WithLogger sets the logger for the composio client. +func WithLogger(logger *slog.Logger) Option { + return func(c *Composio) { c.logger = logger } +} + +// WithBaseURL sets the base URL for the composio client. +func WithBaseURL(baseURL string) Option { + return func(c *Composio) { c.baseURL = baseURL } +} + +// Get Tool Options + +// WithTags sets the tags for the tools request. +func WithTags(tags ...string) ToolsOption { + return func(u *url.Values) { u.Add("tags", strings.Join(tags, ",")) } +} + +// WithApp sets the app for the tools request. +func WithApp(app string) ToolsOption { + return func(u *url.Values) { u.Add("appNames", app) } +} + +// WithEntityID sets the entity id for the tools request. +func WithEntityID(entityID string) ToolsOption { + return func(u *url.Values) { u.Add("user_uuid", entityID) } +} + +// WithUseCase sets the use case for the tools request. +func WithUseCase(useCase string) ToolsOption { + return func(u *url.Values) { u.Add("useCase", useCase) } +} + +// Auth Options + +// WithShowActiveOnly sets the show active only for the auth request. +func WithShowActiveOnly(showActiveOnly bool) AuthOption { + return func(u *url.Values) { + u.Set("showActiveOnly", fmt.Sprintf("%t", showActiveOnly)) + } +} + +// WithUserUUID sets the user uuid for the auth request. +func WithUserUUID(userUUID string) AuthOption { + return func(u *url.Values) { + u.Set("user_uuid", userUUID) + } +} diff --git a/extensions/composio/test.sh b/extensions/composio/test.sh new file mode 100644 index 0000000..a7afc08 --- /dev/null +++ b/extensions/composio/test.sh @@ -0,0 +1,4 @@ +echo "APIKEY: $COMPOSIO_API_KEY" +apikey=$COMPOSIO_API_KEY +curl --request GET --url https://backend.composio.dev/api/v2/actions?tags=Authentication --header 'X-API-Key: '$apikey \ +> composio.json diff --git a/extensions/composio/tools.go b/extensions/composio/tools.go new file mode 100644 index 0000000..eeba205 --- /dev/null +++ b/extensions/composio/tools.go @@ -0,0 +1,109 @@ +package composio + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "github.com/conneroisu/groq-go/pkg/builders" + "github.com/conneroisu/groq-go/pkg/tools" +) + +type ( + // Tooler is an interface for retreiving composio tools. + Tooler interface { + GetTools(ctx context.Context, opts ...ToolsOption) ( + []tools.Tool, error) + } + // Tool represents a composio tool as returned by the api. + Tool struct { + Name string `json:"name"` + Enum string `json:"enum"` + Tags []string `json:"tags"` + Logo string `json:"logo"` + AppID string `json:"appId"` + AppName string `json:"appName"` + DisplayName string `json:"displayName"` + Description string `json:"description"` + Parameters tools.FunctionParameters `json:"parameters"` + Response struct { + Properties struct { + Data struct { + Title string `json:"title"` + Type string `json:"type"` + } `json:"data"` + Successful struct { + Description string `json:"description"` + Title string `json:"title"` + Type string `json:"type"` + } `json:"successful"` + Error struct { + AnyOf []struct { + Type string `json:"type"` + } `json:"anyOf"` + Default any `json:"default"` + Description string `json:"description"` + Title string `json:"title"` + } `json:"error"` + } `json:"properties"` + Required []string `json:"required"` + Title string `json:"title"` + Type string `json:"type"` + } `json:"response"` + Deprecated bool `json:"deprecated"` + DisplayName0 string `json:"display_name"` + } +) + +// GetTools returns the tools for the composio client. +func (c *Composio) GetTools( + ctx context.Context, + opts ...ToolsOption, +) ([]tools.Tool, error) { + uri := fmt.Sprintf("%s/v1/actions", c.baseURL) + u, err := url.Parse(uri) + if err != nil { + return nil, err + } + q := u.Query() + for _, opt := range opts { + opt(&q) + } + u.RawQuery = q.Encode() + uri = u.String() + c.logger.Debug("tools", "uri", uri) + req, err := builders.NewRequest( + ctx, + c.header, + http.MethodGet, + uri, + builders.WithBody(nil), + ) + if err != nil { + return nil, err + } + var items struct { + Tools []Tool `json:"items"` + } + err = c.doRequest(req, &items) + if err != nil { + return nil, err + } + c.logger.Debug("tools", "toolslen", len(items.Tools)) + return groqTools(items.Tools), nil +} +func groqTools(localTools []Tool) []tools.Tool { + groqTools := make([]tools.Tool, 0, len(localTools)) + for _, tool := range localTools { + groqTools = append(groqTools, tools.Tool{ + Function: tools.FunctionDefinition{ + Name: tool.Name, + Description: tool.Description, + Parameters: tool.Parameters, + }, + Type: tools.ToolTypeFunction, + }) + } + return groqTools +} diff --git a/extensions/composio/tools_test.go b/extensions/composio/tools_test.go new file mode 100644 index 0000000..5e1cd04 --- /dev/null +++ b/extensions/composio/tools_test.go @@ -0,0 +1,28 @@ +package composio + +import ( + "context" + "testing" + + "github.com/conneroisu/groq-go/pkg/test" + "github.com/stretchr/testify/assert" +) + +// TestGetTools tests the ability of the composio client to get tools. +func TestGetTools(t *testing.T) { + if !test.IsUnitTest() { + t.Skip() + } + a := assert.New(t) + ctx := context.Background() + key, err := test.GetAPIKey("COMPOSIO_API_KEY") + a.NoError(err) + client, err := NewComposer( + key, + WithLogger(test.DefaultLogger), + ) + a.NoError(err) + ts, err := client.GetTools(ctx, WithApp("GITHUB")) + a.NoError(err) + a.NotEmpty(ts) +} diff --git a/extensions/e2b/doc.go b/extensions/e2b/doc.go new file mode 100644 index 0000000..8a2f191 --- /dev/null +++ b/extensions/e2b/doc.go @@ -0,0 +1,2 @@ +// Package e2b provides an e2b client for groq-go. +package e2b diff --git a/extensions/e2b/model.go b/extensions/e2b/model.go index b4d04c4..acecbed 100644 --- a/extensions/e2b/model.go +++ b/extensions/e2b/model.go @@ -7,17 +7,6 @@ import ( ) type ( - // Requester is an interface for an instance that sends rpc requests. - // - // Implementations should be conccurent safe. - Requester interface { - Write( - ctx context.Context, - method Method, - params []any, - respCh chan []byte, - ) - } // Receiver is an interface for a constantly receiving instance. // // Implementations should be conccurent safe. @@ -25,40 +14,51 @@ type ( Read(ctx context.Context) error io.Closer } - // Identifier is an interface for a constantly running process to identify new request ids. + // Identifier is an interface for a constantly running process to + // identify new request ids. Identifier interface { Identify(ctx context.Context) } // Sandboxer is an interface for a sandbox. Sandboxer interface { - Lifer - } - // Processor is an interface for a process. - Processor interface { - Start( - ctx context.Context, - cmd string, - timeout time.Duration, - ) - Subscriber - } - // Lifer is an interface for keeping sandboxes alive. - Lifer interface { // KeepAlive keeps the underlying interface alive. // // If the context is cancelled before requesting the timeout, // the error will be ctx.Err(). KeepAlive(ctx context.Context, timeout time.Duration) error + // NewProcess creates a new process. + NewProcess( + cmd string, + ) (*Processor, error) + + // Write writes a file to the filesystem. + Write( + ctx context.Context, + method Method, + params []any, + respCh chan<- []byte, + ) + // Read reads a file from the filesystem. + Read( + ctx context.Context, + path string, + ) (string, error) } - // Subscriber is an interface for an instance that can subscribe to an event. - Subscriber interface { - Subscribe( + // Processor is an interface for a process. + Processor interface { + Start( ctx context.Context, - event ProcessEvents, - eCh chan<- Event, + cmd string, + timeout time.Duration, ) + SubscribeStdout() (events chan Event, err error) + SubscribeStderr() (events chan Event, err error) } // Watcher is an interface for a instance that can watch a filesystem. Watcher interface { + Watch( + ctx context.Context, + path string, + ) (<-chan Event, error) } ) diff --git a/extensions/e2b/options.go b/extensions/e2b/options.go new file mode 100644 index 0000000..4a754b9 --- /dev/null +++ b/extensions/e2b/options.go @@ -0,0 +1,45 @@ +package e2b + +import ( + "log/slog" + "net/http" +) + +// E2B Options + +// WithBaseURL sets the base URL for the e2b sandbox. +func WithBaseURL(baseURL string) Option { + return func(s *Sandbox) { s.baseURL = baseURL } +} + +// WithClient sets the client for the e2b sandbox. +func WithClient(client *http.Client) Option { + return func(s *Sandbox) { s.client = client } +} + +// WithLogger sets the logger for the e2b sandbox. +func WithLogger(logger *slog.Logger) Option { + return func(s *Sandbox) { s.logger = logger } +} + +// WithTemplate sets the template for the e2b sandbox. +func WithTemplate(template SandboxTemplate) Option { + return func(s *Sandbox) { s.Template = template } +} + +// WithMetaData sets the meta data for the e2b sandbox. +func WithMetaData(metaData map[string]string) Option { + return func(s *Sandbox) { s.Metadata = metaData } +} + +// WithCwd sets the current working directory. +func WithCwd(cwd string) Option { + return func(s *Sandbox) { s.Cwd = cwd } +} + +// WithWsURL sets the websocket url for the e2b sandbox. +func WithWsURL(wsURL func(s *Sandbox) string) Option { + return func(s *Sandbox) { s.wsURL = wsURL } +} + +// Process Options diff --git a/extensions/e2b/sandbox.go b/extensions/e2b/sandbox.go index 4e32eb8..480204e 100644 --- a/extensions/e2b/sandbox.go +++ b/extensions/e2b/sandbox.go @@ -9,7 +9,6 @@ import ( "log/slog" "math/rand" "net/http" - "net/url" "sync" "time" @@ -33,30 +32,26 @@ type ( ClientID string `json:"clientID"` // ClientID of the sandbox. Cwd string `json:"cwd"` // Cwd is the sandbox's current working directory. - logger *slog.Logger `json:"-"` // logger is the sandbox's logger. - apiKey string `json:"-"` // apiKey is the sandbox's api key. - baseURL string `json:"-"` // baseAPIURL is the base api url of the sandbox. - httpScheme string `json:"-"` // httpScheme is the sandbox's http scheme. - client *http.Client `json:"-"` // client is the sandbox's http client. - header builders.Header `json:"-"` // header is the sandbox's request header builder. - ws *websocket.Conn `json:"-"` // ws is the sandbox's websocket connection. - Map sync.Map `json:"-"` // Map is the map of the sandbox. - idCh chan int `json:"-"` // idCh is the channel to generate ids for requests. - toolW ToolingWrapper `json:"-"` // toolW is the tooling wrapper for the sandbox. + logger *slog.Logger `json:"-"` // logger is the sandbox's logger. + apiKey string `json:"-"` // apiKey is the sandbox's api key. + baseURL string `json:"-"` // baseAPIURL is the base api url of the sandbox. + client *http.Client `json:"-"` // client is the sandbox's http client. + header builders.Header `json:"-"` // header is the sandbox's request header builder. + ws *websocket.Conn `json:"-"` // ws is the sandbox's websocket connection. + wsURL func(s *Sandbox) string `json:"-"` // wsURL is the sandbox's websocket url. + Map sync.Map `json:"-"` // Map is the map of the sandbox. + idCh chan int `json:"-"` // idCh is the channel to generate ids for requests. + toolW ToolingWrapper `json:"-"` // toolW is the tooling wrapper for the sandbox. } // Process is a process in the sandbox. Process struct { sb *Sandbox // sb is the sandbox the process belongs to. + ctx context.Context // ctx is the context for the process. id string // ID is process id. cmd string // cmd is process's command. Cwd string // cwd is process's current working directory. Env map[string]string // env is process's environment variables. } - // SubscribeParams is the params for subscribing to a process event. - SubscribeParams struct { - Event ProcessEvents // Event is the event to subscribe to. - Ch chan<- Event // Ch is the channel to write the event to. - } // Option is an option for the sandbox. Option func(*Sandbox) // Event is a file system event. @@ -128,13 +123,12 @@ const ( rpc = "2.0" charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - defaultBaseURL = "api.e2b.dev" + defaultBaseURL = "https://api.e2b.dev" defaultWSScheme = "wss" wsRoute = "/ws" fileRoute = "/file" sandboxesRoute = "/sandboxes" // (GET/POST /sandboxes) deleteSandboxRoute = "/sandboxes/" // (DELETE /sandboxes/:id) - defaultHTTPScheme = "https" filesystemWrite Method = "filesystem_write" filesystemRead Method = "filesystem_read" @@ -162,11 +156,13 @@ func NewSandbox( Metadata: map[string]string{ "sdk": "groq-go v1", }, - client: http.DefaultClient, - logger: slog.New(slog.NewJSONHandler(io.Discard, nil)), - httpScheme: defaultHTTPScheme, - idCh: make(chan int), - toolW: defaultToolWrapper, + client: http.DefaultClient, + logger: slog.New(slog.NewJSONHandler(io.Discard, nil)), + idCh: make(chan int), + toolW: defaultToolWrapper, + wsURL: func(s *Sandbox) string { + return fmt.Sprintf("wss://49982-%s-%s.e2b.dev/ws", s.ID, s.ClientID) + }, } for _, opt := range opts { opt(&sb) @@ -178,7 +174,7 @@ func NewSandbox( } req, err := builders.NewRequest( ctx, sb.header, http.MethodPost, - fmt.Sprintf("%s://%s%s", sb.httpScheme, sb.baseURL, sandboxesRoute), + fmt.Sprintf("%s%s", sb.baseURL, sandboxesRoute), builders.WithBody(&sb), ) if err != nil { @@ -188,7 +184,7 @@ func NewSandbox( if err != nil { return &sb, err } - sb.ws, _, err = websocket.DefaultDialer.Dial(sb.wsURL().String(), nil) + sb.ws, _, err = websocket.DefaultDialer.Dial(sb.wsURL(&sb), nil) if err != nil { return &sb, err } @@ -206,7 +202,7 @@ func NewSandbox( func (s *Sandbox) KeepAlive(ctx context.Context, timeout time.Duration) error { req, err := builders.NewRequest( ctx, s.header, http.MethodPost, - fmt.Sprintf("%s://%s/sandboxes/%s/refreshes", s.httpScheme, s.baseURL, s.ID), + fmt.Sprintf("%s/sandboxes/%s/refreshes", s.baseURL, s.ID), builders.WithBody(struct { Duration int `json:"duration"` }{Duration: int(timeout.Seconds())}), @@ -231,9 +227,8 @@ func (s *Sandbox) Reconnect(ctx context.Context) (err error) { if err := s.ws.Close(); err != nil { return err } - u := s.wsURL() - s.logger.Debug("reconnecting to sandbox", "url", u.String()) - s.ws, _, err = websocket.DefaultDialer.Dial(u.String(), nil) + urlu := fmt.Sprintf("wss://49982-%s-%s.e2b.dev/ws", s.ID, s.ClientID) + s.ws, _, err = websocket.DefaultDialer.Dial(urlu, nil) if err != nil { return err } @@ -250,7 +245,7 @@ func (s *Sandbox) Reconnect(ctx context.Context) (err error) { func (s *Sandbox) Stop(ctx context.Context) error { req, err := builders.NewRequest( ctx, s.header, http.MethodDelete, - fmt.Sprintf("%s://%s%s%s", s.httpScheme, s.baseURL, deleteSandboxRoute, s.ID), + fmt.Sprintf("%s%s%s", s.baseURL, deleteSandboxRoute, s.ID), builders.WithBody(interface{}(nil)), ) if err != nil { @@ -449,9 +444,16 @@ func (p *Process) Start(ctx context.Context) (err error) { p.Env = map[string]string{"PYTHONUNBUFFERED": "1"} } respCh := make(chan []byte) - if err = p.sb.writeRequest(ctx, processStart, []any{p.id, p.cmd, p.Env, p.Cwd}, respCh); err != nil { + err = p.sb.writeRequest( + ctx, + processStart, + []any{p.id, p.cmd, p.Env, p.Cwd}, + respCh, + ) + if err != nil { return err } + p.ctx = ctx select { case body := <-respCh: res, err := decodeResponse[string, APIError](body) @@ -482,76 +484,78 @@ func (p *Process) Done() <-chan struct{} { return rCh.(<-chan struct{}) } +// SubscribeStdout subscribes to the process's stdout. +func (p *Process) SubscribeStdout(events chan Event) (err error) { + err = p.subscribe(p.ctx, OnStdout, events) + return +} + +// SubscribeStderr subscribes to the process's stderr. +func (p *Process) SubscribeStderr(events chan Event) (err error) { + err = p.subscribe(p.ctx, OnStderr, events) + return +} + +// SubscribeExit subscribes to the process's exit. +func (p *Process) SubscribeExit(events chan Event) (err error) { + err = p.subscribe(p.ctx, OnExit, events) + return +} + // Subscribe subscribes to a process event. // -// It creates a go routine to read the process events. -func (p *Process) Subscribe( +// It creates a go routine to read the process events into the provided channel. +func (p *Process) subscribe( ctx context.Context, event ProcessEvents, eCh chan<- Event, ) error { - respCh := make(chan []byte) - err := p.sb.writeRequest(ctx, processSubscribe, []any{event, p.id}, respCh) - if err != nil { - return err - } - res, err := decodeResponse[string, APIError](<-respCh) - if err != nil { - return err - } - if res.Error.Code != 0 { - return fmt.Errorf("process subscribe failed(%d): %s", res.Error.Code, res.Error.Message) - } - eventByCh := make(chan []byte) - p.sb.Map.Store(res.Result, eventByCh) - for { - select { - case eventBd := <-eventByCh: - var event Event - err = json.Unmarshal(eventBd, &event) - if err != nil { - return err - } - if event.Error != "" { - return fmt.Errorf("failed to read event: %s", event.Error) - } - if event.Params.Subscription != res.Result { - return fmt.Errorf("subscription id mismatch") - } - eCh <- event - case <-ctx.Done(): - close(eventByCh) - p.sb.Map.Delete(res.Result) - finishCtx, cancel := context.WithCancel(context.Background()) - defer cancel() - p.sb.logger.Debug("unsubscribing from process", "event", event, "id", res.Result) - err = p.sb.writeRequest(finishCtx, processUnsubscribe, []any{res.Result}, respCh) - if err != nil { - return err - } - unsubRes, err := decodeResponse[bool, string](<-respCh) - if err != nil { - return err - } - if unsubRes.Error != "" || !unsubRes.Result { - return fmt.Errorf("failed to unsubscribe from process: %s", unsubRes.Error) + errCh := make(chan error) + go func(errCh chan error) { + respCh := make(chan []byte) + defer close(respCh) + err := p.sb.writeRequest(ctx, processSubscribe, []any{event, p.id}, respCh) + if err != nil || respCh == nil { + errCh <- err + } + res, err := decodeResponse[string, any](<-respCh) + errCh <- err + if err != nil { + return + } + p.sb.Map.Store(res.Result, respCh) + for { + select { + case eventBd := <-respCh: + p.sb.logger.Debug("eventByCh", "event", string(eventBd)) + var event Event + _ = json.Unmarshal(eventBd, &event) + if event.Error != "" { + p.sb.logger.Debug("failed to read event", "error", event.Error) + continue + } + if event.Params.Subscription != res.Result { + p.sb.logger.Debug("subscription id mismatch", "expected", res.Result, "got", event.Params.Subscription) + continue + } + eCh <- event + case <-ctx.Done(): + p.sb.Map.Delete(res.Result) + finishCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + p.sb.logger.Debug("unsubscribing from process", "event", event, "id", res.Result) + _ = p.sb.writeRequest(finishCtx, processUnsubscribe, []any{res.Result}, respCh) + unsubRes, _ := decodeResponse[bool, string](<-respCh) + if unsubRes.Error != "" || !unsubRes.Result { + p.sb.logger.Debug("failed to unsubscribe from process", "error", unsubRes.Error) + } + return + case <-p.Done(): + return } - return nil - // TODO: make this a timeout that comes from a function param. - case <-p.Done(): - return nil } - } -} -func (s *Sandbox) wsURL() *url.URL { - return &url.URL{ - Scheme: defaultWSScheme, - Host: fmt.Sprintf("49982-%s-%s.e2b.dev", - s.ID, - s.ClientID, - ), - Path: wsRoute, - } + }(errCh) + return <-errCh } func (s *Sandbox) sendRequest(req *http.Request, v interface{}) error { req.Header.Set("Accept", "application/json") @@ -583,37 +587,6 @@ func (s *Sandbox) sendRequest(req *http.Request, v interface{}) error { return json.NewDecoder(res.Body).Decode(v) } } - -// WithBaseURL sets the base URL for the e2b sandbox. -func WithBaseURL(baseURL string) Option { - return func(s *Sandbox) { s.baseURL = baseURL } -} - -// WithClient sets the client for the e2b sandbox. -func WithClient(client *http.Client) Option { - return func(s *Sandbox) { s.client = client } -} - -// WithLogger sets the logger for the e2b sandbox. -func WithLogger(logger *slog.Logger) Option { - return func(s *Sandbox) { s.logger = logger } -} - -// WithTemplate sets the template for the e2b sandbox. -func WithTemplate(template SandboxTemplate) Option { - return func(s *Sandbox) { s.Template = template } -} - -// WithMetaData sets the meta data for the e2b sandbox. -func WithMetaData(metaData map[string]string) Option { - return func(s *Sandbox) { s.Metadata = metaData } -} - -// WithCwd sets the current working directory. -func WithCwd(cwd string) Option { - return func(s *Sandbox) { s.Cwd = cwd } -} - func decodeResponse[T any, Q any](body []byte) (*Response[T, Q], error) { decResp := new(Response[T, Q]) err := json.Unmarshal(body, decResp) @@ -635,50 +608,63 @@ func (s *Sandbox) identify(ctx context.Context) { } } func (s *Sandbox) read(ctx context.Context) (err error) { + var body []byte defer func() { err = s.ws.Close() }() + msgCh := make(chan []byte, 10) for { select { - case <-ctx.Done(): - return ctx.Err() - default: - _, body, err := s.ws.ReadMessage() - if err != nil { - return err - } + case body = <-msgCh: var decResp decResp err = json.Unmarshal(body, &decResp) if err != nil { return err } - s.logger.Debug("read", - "id", decResp.ID, - "body", string(body), - "sandbox", s.ID, - ) if decResp.Params.Subscription != "" { toR, ok := s.Map.Load(decResp.Params.Subscription) if !ok { - s.logger.Debug("subscription not found", "id", decResp.Params.Subscription) + msgCh <- body + continue } toRCh, ok := toR.(chan []byte) if !ok { - s.logger.Debug("subscription not found", "id", decResp.Params.Subscription) + msgCh <- body + continue } + s.logger.Debug("read", + "subscription", decResp.Params.Subscription, + "body", body, + "sandbox", s.ID, + ) toRCh <- body - continue } - // response has an id - toR, ok := s.Map.Load(decResp.ID) - if !ok { - s.logger.Debug("response not found", "id", decResp.ID) + if decResp.ID != 0 { + toR, ok := s.Map.Load(decResp.ID) + if !ok { + msgCh <- body + continue + } + toRCh, ok := toR.(chan []byte) + if !ok { + msgCh <- body + continue + } + s.logger.Debug("read", + "id", decResp.ID, + "body", body, + "sandbox", s.ID, + ) + toRCh <- body } - toRCh, ok := toR.(chan []byte) - if !ok { - s.logger.Debug("responsech not found", "id", decResp.ID) + case <-ctx.Done(): + return ctx.Err() + default: + _, body, err := s.ws.ReadMessage() + if err != nil { + return err } - toRCh <- body + msgCh <- body } } } diff --git a/extensions/e2b/sandbox_test.go b/extensions/e2b/sandbox_test.go index 3850d7a..385ebf1 100644 --- a/extensions/e2b/sandbox_test.go +++ b/extensions/e2b/sandbox_test.go @@ -1,200 +1,174 @@ -package e2b_test +package e2b import ( "context" "encoding/json" - "log/slog" - "os" + "net/http" + "net/http/httptest" "strings" + "sync" "testing" - "time" - "github.com/conneroisu/groq-go/extensions/e2b" "github.com/conneroisu/groq-go/pkg/test" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" ) -var ( - defaultLogger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - AddSource: true, - Level: slog.LevelDebug, - ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { - if a.Key == "time" { - return slog.Attr{} - } - if a.Key == "level" { - return slog.Attr{} - } - if a.Key == slog.SourceKey { - str := a.Value.String() - split := strings.Split(str, "/") - if len(split) > 2 { - a.Value = slog.StringValue(strings.Join(split[len(split)-2:], "/")) - a.Value = slog.StringValue(strings.Replace(a.Value.String(), "}", "", -1)) - } - a.Key = a.Value.String() - a.Value = slog.IntValue(0) - } - if a.Key == "body" { - a.Value = slog.StringValue(strings.Replace(a.Value.String(), "/", "", -1)) - a.Value = slog.StringValue(strings.Replace(a.Value.String(), "\n", "", -1)) - a.Value = slog.StringValue(strings.Replace(a.Value.String(), "\"", "", -1)) - } - return a - }})) -) +var upgrader = websocket.Upgrader{} + +const subID = "test-sub-id" -func getapiKey(t *testing.T) string { - apiKey := os.Getenv("E2B_API_KEY") - if apiKey == "" { - t.Fail() +func echo(a *assert.Assertions) func(w http.ResponseWriter, r *http.Request) { + mu := sync.Mutex{} + return func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + defer mu.Unlock() + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer c.Close() + for { + mt, message, err := c.ReadMessage() + a.NoError(err) + test.DefaultLogger.Debug("server read message", "msg", message) + req := decode(message) + switch req.Method { + case filesystemList: + err = c.WriteMessage(mt, encode(Response[[]LsResult, string]{ + ID: req.ID, + Error: "", + Result: []LsResult{ + { + Name: "hello.txt", + IsDir: false, + }, + }, + })) + a.NoError(err) + case filesystemRead: + err = c.WriteMessage(mt, encode(Response[string, string]{ + ID: req.ID, + Error: "", + Result: "hello", + })) + a.NoError(err) + case filesystemWrite: + err = c.WriteMessage(mt, encode(Response[string, string]{ + ID: req.ID, + Error: "", + Result: "hello", + })) + a.NoError(err) + case processStart: + err = c.WriteMessage(mt, encode(Response[string, APIError]{ + ID: req.ID, + Error: APIError{}, + Result: req.Params[0].(string), + })) + a.NoError(err) + case processSubscribe: + err = c.WriteMessage(mt, encode(Response[string, APIError]{ + ID: req.ID, + Error: APIError{}, + Result: subID, + })) + a.NoError(err) + err = c.WriteMessage(mt, encode(Event{ + Params: EventParams{ + Subscription: subID, + Result: EventResult{ + Type: "Stdout", + Line: "hello", + Timestamp: 0, + IsDirectory: false, + Error: "", + }, + }, + })) + a.NoError(err) + case filesystemMakeDir: + err = c.WriteMessage(mt, encode(Response[string, APIError]{ + ID: req.ID, + Error: APIError{}, + Result: "", + })) + a.NoError(err) + } + } } - return apiKey } -func TestPostSandbox(t *testing.T) { - if !test.IsUnitTest() { - t.Skip() - } - a := assert.New(t) - ctx := context.Background() - sb, err := e2b.NewSandbox( - ctx, - getapiKey(t), - e2b.WithLogger(defaultLogger), - ) - a.NoError(err, "NewSandbox error") - lsr, err := sb.Ls(ctx, ".") - a.NoError(err) - for _, name := range []string{"boot", "code", "dev", "etc", "home"} { - a.Contains(lsr, e2b.LsResult{ - Name: name, - IsDir: true, - }) +func encode(v any) []byte { + res, err := json.Marshal(v) + if err != nil { + panic(err) } - err = sb.Mkdir(ctx, "heelo") - a.NoError(err) - lsr, err = sb.Ls(ctx, "/") - a.NoError(err) - a.Contains(lsr, e2b.LsResult{ - Name: "heelo", - IsDir: true, - }) + return res } - -// TestWriteRead tests the Write and Read methods of the Sandbox. -func TestWriteRead(t *testing.T) { - if !test.IsUnitTest() { - t.Skip() +func decode(bod []byte) Request { + var req Request + err := json.Unmarshal(bod, &req) + if err != nil { + panic(err) } - filePath := "test.txt" - content := "Hello, world!" - a := assert.New(t) - ctx := context.Background() - sb, err := e2b.NewSandbox( - ctx, - getapiKey(t), - e2b.WithLogger(defaultLogger), - ) - a.NoError(err, "NewSandbox error") - err = sb.Write(ctx, filePath, []byte(content)) - a.NoError(err, "Write error") - readContent, err := sb.Read(ctx, filePath) - a.NoError(err, "Read error") - a.Equal(content, string(readContent), "Read content does not match written content") - readBytesContent, err := sb.ReadBytes(ctx, filePath) - a.NoError(err, "ReadBytes error") - a.Equal(content, string(readBytesContent), "ReadBytes content does not match written content") - err = sb.Stop(ctx) - a.NoError(err, "Stop error") + return req } -func TestCreateProcess(t *testing.T) { - if !test.IsUnitTest() { - t.Skip() - } +func TestNewSandbox(t *testing.T) { a := assert.New(t) - ctx := context.Background() - sb, err := e2b.NewSandbox( + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + srv := test.NewTestServer() + ts := srv.E2bTestServer() + wsts := httptest.NewServer(http.HandlerFunc(echo(a))) + id := "test-sandbox-id" + srv.RegisterHandler("/sandboxes", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err := w.Write(encode(&Sandbox{ID: id})) + a.NoError(err) + }) + ts.Start() + u := "ws" + strings.TrimPrefix(wsts.URL, "http") + // Create a new sandbox. + sb, err := NewSandbox( ctx, - getapiKey(t), - e2b.WithLogger(defaultLogger), + test.GetTestToken(), + WithLogger(test.DefaultLogger), + WithBaseURL(ts.URL), + WithWsURL(func(_ *Sandbox) string { + return u + "/ws" + }), ) a.NoError(err, "NewSandbox error") - proc, err := sb.NewProcess("echo 'Hello World!'", - e2b.Process{ - Env: map[string]string{ - "FOO": "bar", - }, - }) - a.NoError(err, "could not create process") - err = proc.Start(ctx) - a.NoError(err) - proc, err = sb.NewProcess("sleep 2 && echo 'Hello World!'", e2b.Process{}) - a.NoError(err, "could not create process") - err = proc.Start(ctx) + a.NotNil(sb, "NewSandbox returned nil") + a.Equal(sb.ID, id) + + lsRes, err := sb.Ls(ctx, ".") a.NoError(err) - events := make(chan e2b.Event, 10) - ctx, cancel := context.WithTimeout(ctx, time.Second*6) - defer cancel() - err = proc.Subscribe(ctx, e2b.OnStdout, events) + a.NotEmpty(lsRes) + + err = sb.Mkdir(ctx, "hello") a.NoError(err) - event := <-events - jsonBytes, err := json.MarshalIndent(&event, "", " ") - if err != nil { - a.Error(err) - return - } - t.Logf("test got event: %s", string(jsonBytes)) -} -func TestFilesystemSubscribe(t *testing.T) { - if !test.IsUnitTest() { - t.Skip() - } - a := assert.New(t) - ctx := context.Background() - sb, err := e2b.NewSandbox( - ctx, - getapiKey(t), - e2b.WithLogger(defaultLogger), - e2b.WithCwd("/tmp"), - ) - a.NoError(err, "NewSandbox error") - // subscribe to a file - events := make(chan e2b.Event) - err = sb.Watch(ctx, "/tmp/", events) + err = sb.Write(ctx, "hello.txt", []byte("hello")) a.NoError(err) - go func() { - for event := range events { - jsonBytes, err := json.MarshalIndent(event, "", " ") - if err != nil { - a.Error(err) - return - } - t.Logf("test got event: %s", string(jsonBytes)) - } - }() - // create a file - err = sb.Write(ctx, "/tmp/file.txt", []byte("Hello World!")) + + readRes, err := sb.Read(ctx, "hello.txt") a.NoError(err) - err = sb.Write(ctx, "/tmp/file2.txt", []byte("Hello World!")) + a.Equal("hello", readRes) + + proc, err := sb.NewProcess("sleep 5 && echo 'hello world!'", Process{}) a.NoError(err) - time.Sleep(3 * time.Second) -} -func TestKeepAlive(t *testing.T) { - if !test.IsUnitTest() { - t.Skip() - } - a := assert.New(t) - ctx := context.Background() - sb, err := e2b.NewSandbox( - ctx, - getapiKey(t), - e2b.WithLogger(defaultLogger), - ) - a.NoError(err, "NewSandbox error") - err = sb.KeepAlive(ctx, time.Minute*2) + err = proc.Start(ctx) + a.NoError(err) + e := make(chan Event) + err = proc.SubscribeStdout(e) + a.NoError(err) + event := <-e + jsnBytes, err := json.MarshalIndent(&event, "", " ") a.NoError(err) + t.Logf("test got event: %s", string(jsnBytes)) } diff --git a/extensions/e2b/tools.go b/extensions/e2b/tools.go index 2a9dabd..9ff4c0c 100644 --- a/extensions/e2b/tools.go +++ b/extensions/e2b/tools.go @@ -7,14 +7,15 @@ import ( "fmt" "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/tools" ) type ( // SbFn is a function that can be used to run a tool. SbFn func(ctx context.Context, s *Sandbox, params *Params) (groq.ChatCompletionMessage, error) - // ToolingWrapper is a wrapper for groq.Tool that allows for custom functions working with a sandbox. + // ToolingWrapper is a wrapper for tools.Tool that allows for custom functions working with a sandbox. ToolingWrapper struct { - ToolMap map[*groq.Tool]SbFn + ToolMap map[*tools.Tool]SbFn } // Params are the parameters for any function call. Params struct { @@ -28,8 +29,8 @@ type ( ) // getTools returns the tools wrapped by the ToolWrapper. -func (t *ToolingWrapper) getTools() []groq.Tool { - tools := make([]groq.Tool, 0) +func (t *ToolingWrapper) getTools() []tools.Tool { + tools := make([]tools.Tool, 0) for tool := range t.ToolMap { tools = append(tools, *tool) } @@ -37,7 +38,7 @@ func (t *ToolingWrapper) getTools() []groq.Tool { } // GetTools returns the tools wrapped by the ToolWrapper. -func (s *Sandbox) GetTools() []groq.Tool { +func (s *Sandbox) GetTools() []tools.Tool { return s.toolW.getTools() } @@ -49,15 +50,19 @@ func (t *ToolingWrapper) GetToolFn(name string) (SbFn, error) { return fn, nil } } - return nil, fmt.Errorf("tool %s not found", name) + return nil, fmt.Errorf("Error running tool (does not exist) %s", name) } var ( defaultToolWrapper = ToolingWrapper{ ToolMap: toolMap, } - toolMap = map[*groq.Tool]SbFn{ - &mkdirTool: func(ctx context.Context, s *Sandbox, params *Params) (groq.ChatCompletionMessage, error) { + toolMap = map[*tools.Tool]SbFn{ + &mkdirTool: func( + ctx context.Context, + s *Sandbox, + params *Params, + ) (groq.ChatCompletionMessage, error) { err := s.Mkdir(ctx, params.Path) if err != nil { return groq.ChatCompletionMessage{}, err @@ -68,7 +73,11 @@ var ( Name: "mkdir", }, nil }, - &lsTool: func(ctx context.Context, s *Sandbox, params *Params) (groq.ChatCompletionMessage, error) { + &lsTool: func( + ctx context.Context, + s *Sandbox, + params *Params, + ) (groq.ChatCompletionMessage, error) { res, err := s.Ls(ctx, params.Path) if err != nil { return groq.ChatCompletionMessage{}, err @@ -83,7 +92,11 @@ var ( Name: "ls", }, nil }, - &readTool: func(ctx context.Context, s *Sandbox, params *Params) (groq.ChatCompletionMessage, error) { + &readTool: func( + ctx context.Context, + s *Sandbox, + params *Params, + ) (groq.ChatCompletionMessage, error) { content, err := s.Read(ctx, params.Path) if err != nil { return groq.ChatCompletionMessage{}, err @@ -94,7 +107,11 @@ var ( Name: "read", }, nil }, - &writeTool: func(ctx context.Context, s *Sandbox, params *Params) (groq.ChatCompletionMessage, error) { + &writeTool: func( + ctx context.Context, + s *Sandbox, + params *Params, + ) (groq.ChatCompletionMessage, error) { err := s.Write(ctx, params.Path, []byte(params.Data)) if err != nil { return groq.ChatCompletionMessage{}, err @@ -105,17 +122,21 @@ var ( Name: "write", }, nil }, - &startProcessTool: func(ctx context.Context, s *Sandbox, params *Params) (groq.ChatCompletionMessage, error) { + &startProcessTool: func( + ctx context.Context, + s *Sandbox, + params *Params, + ) (groq.ChatCompletionMessage, error) { proc, err := s.NewProcess(params.Cmd, Process{}) if err != nil { return groq.ChatCompletionMessage{}, err } - events := make(chan Event, 100) - err = proc.Subscribe(ctx, OnStdout, events) + e := make(chan Event, 10) + err = proc.SubscribeStdout(e) if err != nil { return groq.ChatCompletionMessage{}, err } - err = proc.Subscribe(ctx, OnStderr, events) + err = proc.SubscribeStderr(e) if err != nil { return groq.ChatCompletionMessage{}, err } @@ -129,7 +150,7 @@ var ( select { case <-ctx.Done(): return - case event := <-events: + case event := <-e: buf.Write([]byte(event.Params.Result.Line)) case <-proc.Done(): break @@ -144,14 +165,14 @@ var ( }, nil }, } - mkdirTool = groq.Tool{ - Type: groq.ToolTypeFunction, - Function: groq.FunctionDefinition{ + mkdirTool = tools.Tool{ + Type: tools.ToolTypeFunction, + Function: tools.FunctionDefinition{ Name: "mkdir", Description: "Make a directory in the sandbox file system at a given path", - Parameters: groq.ParameterDefinition{ + Parameters: tools.FunctionParameters{ Type: "object", - Properties: map[string]groq.PropertyDefinition{ + Properties: map[string]tools.PropertyDefinition{ "path": { Type: "string", Description: "The path of the directory to create", @@ -162,14 +183,14 @@ var ( }, }, } - lsTool = groq.Tool{ - Type: groq.ToolTypeFunction, - Function: groq.FunctionDefinition{ + lsTool = tools.Tool{ + Type: tools.ToolTypeFunction, + Function: tools.FunctionDefinition{ Name: "ls", Description: "List the files and directories in the sandbox file system at a given path", - Parameters: groq.ParameterDefinition{ + Parameters: tools.FunctionParameters{ Type: "object", - Properties: map[string]groq.PropertyDefinition{ + Properties: map[string]tools.PropertyDefinition{ "path": {Type: "string", Description: "The path of the directory to list", }, @@ -179,14 +200,14 @@ var ( }, }, } - readTool = groq.Tool{ - Type: groq.ToolTypeFunction, - Function: groq.FunctionDefinition{ + readTool = tools.Tool{ + Type: tools.ToolTypeFunction, + Function: tools.FunctionDefinition{ Name: "read", Description: "Read the contents of a file in the sandbox file system at a given path", - Parameters: groq.ParameterDefinition{ + Parameters: tools.FunctionParameters{ Type: "object", - Properties: map[string]groq.PropertyDefinition{ + Properties: map[string]tools.PropertyDefinition{ "path": {Type: "string", Description: "The path of the file to read", }, @@ -196,14 +217,14 @@ var ( }, }, } - writeTool = groq.Tool{ - Type: groq.ToolTypeFunction, - Function: groq.FunctionDefinition{ + writeTool = tools.Tool{ + Type: tools.ToolTypeFunction, + Function: tools.FunctionDefinition{ Name: "write", Description: "Write to a file in the sandbox file system at a given path", - Parameters: groq.ParameterDefinition{ + Parameters: tools.FunctionParameters{ Type: "object", - Properties: map[string]groq.PropertyDefinition{ + Properties: map[string]tools.PropertyDefinition{ "path": {Type: "string", Description: "The relative or absolute path of the file to write to", }, @@ -216,14 +237,14 @@ var ( }, }, } - startProcessTool = groq.Tool{ - Type: groq.ToolTypeFunction, - Function: groq.FunctionDefinition{ + startProcessTool = tools.Tool{ + Type: tools.ToolTypeFunction, + Function: tools.FunctionDefinition{ Name: "start_process", Description: "Start a process in the sandbox.", - Parameters: groq.ParameterDefinition{ + Parameters: tools.FunctionParameters{ Type: "object", - Properties: map[string]groq.PropertyDefinition{ + Properties: map[string]tools.PropertyDefinition{ "cmd": {Type: "string", Description: "The command to run to start the process", }, @@ -267,8 +288,8 @@ func (s *Sandbox) RunTooling( func (s *Sandbox) runTool( ctx context.Context, - tool groq.Tool, - call groq.ToolCall, + tool tools.Tool, + call tools.ToolCall, ) (groq.ChatCompletionMessage, error) { s.logger.Debug("running tool", "tool", tool.Function.Name, "call", call.Function.Name) var params *Params @@ -282,7 +303,7 @@ func (s *Sandbox) runTool( fn, err := s.toolW.GetToolFn(tool.Function.Name) if err != nil { return groq.ChatCompletionMessage{ - Content: fmt.Sprintf("Error running tool (does not exist) %s: %s", tool.Function.Name, err.Error()), + Content: err.Error(), Role: groq.ChatMessageRoleFunction, Name: tool.Function.Name, }, err diff --git a/extensions/e2b/tools_test.go b/extensions/e2b/tools_test.go index 740646b..8b69a23 100644 --- a/extensions/e2b/tools_test.go +++ b/extensions/e2b/tools_test.go @@ -2,9 +2,7 @@ package e2b import ( "context" - "log/slog" "os" - "strings" "testing" "github.com/conneroisu/groq-go" @@ -12,34 +10,6 @@ import ( "github.com/stretchr/testify/assert" ) -var ( - defaultLogger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - AddSource: true, - Level: slog.LevelDebug, - ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { - if a.Key == "time" { - return slog.Attr{} - } - if a.Key == "level" { - return slog.Attr{} - } - if a.Key == "source" { - str := a.Value.String() - split := strings.Split(str, "/") - if len(split) > 2 { - a.Value = slog.StringValue(strings.Join(split[len(split)-2:], "/")) - a.Value = slog.StringValue(strings.Replace(a.Value.String(), "}", "", -1)) - } - } - if a.Key == "body" { - a.Value = slog.StringValue(strings.Replace(a.Value.String(), "/", "", -1)) - a.Value = slog.StringValue(strings.Replace(a.Value.String(), "\n", "", -1)) - a.Value = slog.StringValue(strings.Replace(a.Value.String(), "\"", "", -1)) - } - return a - }})) -) - func getapiKey(t *testing.T, val string) string { apiKey := os.Getenv(val) if apiKey == "" { @@ -57,7 +27,7 @@ func TestSandboxTooling(t *testing.T) { sb, err := NewSandbox( ctx, getapiKey(t, "E2B_API_KEY"), - WithLogger(defaultLogger), + WithLogger(test.DefaultLogger), WithCwd("/code"), ) a.NoError(err, "NewSandbox error") diff --git a/extensions/e2b/unit_test.go b/extensions/e2b/unit_test.go new file mode 100644 index 0000000..311d9d2 --- /dev/null +++ b/extensions/e2b/unit_test.go @@ -0,0 +1,166 @@ +package e2b_test + +import ( + "context" + "encoding/json" + "os" + "testing" + "time" + + "github.com/conneroisu/groq-go/extensions/e2b" + "github.com/conneroisu/groq-go/pkg/test" + "github.com/stretchr/testify/assert" +) + +func getapiKey(t *testing.T) string { + apiKey := os.Getenv("E2B_API_KEY") + if apiKey == "" { + t.Fail() + } + return apiKey +} + +func TestPostSandbox(t *testing.T) { + if !test.IsUnitTest() { + t.Skip() + } + a := assert.New(t) + ctx := context.Background() + sb, err := e2b.NewSandbox( + ctx, + getapiKey(t), + e2b.WithLogger(test.DefaultLogger), + ) + a.NoError(err, "NewSandbox error") + lsr, err := sb.Ls(ctx, ".") + a.NoError(err) + for _, name := range []string{"boot", "code", "dev", "etc", "home"} { + a.Contains(lsr, e2b.LsResult{ + Name: name, + IsDir: true, + }) + } + err = sb.Mkdir(ctx, "heelo") + a.NoError(err) + lsr, err = sb.Ls(ctx, "/") + a.NoError(err) + a.Contains(lsr, e2b.LsResult{ + Name: "heelo", + IsDir: true, + }) +} + +// TestWriteRead tests the Write and Read methods of the Sandbox. +func TestWriteRead(t *testing.T) { + if !test.IsUnitTest() { + t.Skip() + } + filePath := "test.txt" + content := "Hello, world!" + a := assert.New(t) + ctx := context.Background() + sb, err := e2b.NewSandbox( + ctx, + getapiKey(t), + e2b.WithLogger(test.DefaultLogger), + ) + a.NoError(err, "NewSandbox error") + err = sb.Write(ctx, filePath, []byte(content)) + a.NoError(err, "Write error") + readContent, err := sb.Read(ctx, filePath) + a.NoError(err, "Read error") + a.Equal(content, string(readContent), "Read content does not match written content") + readBytesContent, err := sb.ReadBytes(ctx, filePath) + a.NoError(err, "ReadBytes error") + a.Equal(content, string(readBytesContent), "ReadBytes content does not match written content") + err = sb.Stop(ctx) + a.NoError(err, "Stop error") +} + +func TestCreateProcess(t *testing.T) { + if !test.IsUnitTest() { + t.Skip() + } + a := assert.New(t) + ctx := context.Background() + sb, err := e2b.NewSandbox( + ctx, + getapiKey(t), + e2b.WithLogger(test.DefaultLogger), + ) + a.NoError(err, "NewSandbox error") + proc, err := sb.NewProcess("echo 'Hello World!'", + e2b.Process{ + Env: map[string]string{ + "FOO": "bar", + }, + }) + a.NoError(err, "could not create process") + err = proc.Start(ctx) + a.NoError(err) + proc, err = sb.NewProcess("sleep 2 && echo 'Hello World!'", e2b.Process{}) + a.NoError(err, "could not create process") + err = proc.Start(ctx) + a.NoError(err) + stdOutEvents := make(chan e2b.Event) + err = proc.SubscribeStdout(stdOutEvents) + a.NoError(err) + event := <-stdOutEvents + jsonBytes, err := json.MarshalIndent(&event, "", " ") + if err != nil { + a.Error(err) + return + } + t.Logf("test got event: %s", string(jsonBytes)) +} + +func TestFilesystemSubscribe(t *testing.T) { + if !test.IsUnitTest() { + t.Skip() + } + a := assert.New(t) + ctx := context.Background() + sb, err := e2b.NewSandbox( + ctx, + getapiKey(t), + e2b.WithLogger(test.DefaultLogger), + e2b.WithCwd("/tmp"), + ) + a.NoError(err, "NewSandbox error") + // subscribe to a file + events := make(chan e2b.Event) + err = sb.Watch(ctx, "/tmp/", events) + a.NoError(err) + go func() { + for event := range events { + jsonBytes, err := json.MarshalIndent(event, "", " ") + if err != nil { + a.Error(err) + return + } + t.Logf("test got event: %s", string(jsonBytes)) + } + }() + // create a file + err = sb.Write(ctx, "/tmp/file.txt", []byte("Hello World!")) + a.NoError(err) + err = sb.Write(ctx, "/tmp/file2.txt", []byte("Hello World!")) + a.NoError(err) + time.Sleep(3 * time.Second) +} + +func TestKeepAlive(t *testing.T) { + if !test.IsUnitTest() { + t.Skip() + } + a := assert.New(t) + ctx := context.Background() + sb, err := e2b.NewSandbox( + ctx, + getapiKey(t), + e2b.WithLogger(test.DefaultLogger), + ) + a.NoError(err, "NewSandbox error") + err = sb.KeepAlive(ctx, time.Minute*2) + a.NoError(err) +} diff --git a/extensions/toolhouse/options.go b/extensions/toolhouse/options.go index 7ce9632..7755c74 100644 --- a/extensions/toolhouse/options.go +++ b/extensions/toolhouse/options.go @@ -7,28 +7,20 @@ import ( // WithBaseURL sets the base URL for the Toolhouse extension. func WithBaseURL(baseURL string) Options { - return func(e *Toolhouse) { - e.baseURL = baseURL - } + return func(e *Toolhouse) { e.baseURL = baseURL } } // WithClient sets the client for the Toolhouse extension. func WithClient(client *http.Client) Options { - return func(e *Toolhouse) { - e.client = client - } + return func(e *Toolhouse) { e.client = client } } // WithMetadata sets the metadata for the get tools request. func WithMetadata(metadata map[string]any) Options { - return func(r *Toolhouse) { - r.metadata = metadata - } + return func(r *Toolhouse) { r.metadata = metadata } } // WithLogger sets the logger for the Toolhouse extension. func WithLogger(logger *slog.Logger) Options { - return func(r *Toolhouse) { - r.logger = logger - } + return func(r *Toolhouse) { r.logger = logger } } diff --git a/extensions/toolhouse/run.go b/extensions/toolhouse/run.go index 15624f6..d3c6446 100644 --- a/extensions/toolhouse/run.go +++ b/extensions/toolhouse/run.go @@ -2,18 +2,17 @@ package toolhouse import ( "context" - "encoding/json" "fmt" - "io" "net/http" "github.com/conneroisu/groq-go" "github.com/conneroisu/groq-go/pkg/builders" + "github.com/conneroisu/groq-go/pkg/tools" ) type ( request struct { - Content groq.ToolCall `json:"content,omitempty"` + Content tools.ToolCall `json:"content,omitempty"` Provider string `json:"provider"` Metadata map[string]any `json:"metadata"` Bundle string `json:"bundle"` @@ -39,19 +38,20 @@ func (e *Toolhouse) Run( ctx context.Context, response groq.ChatCompletionResponse, ) ([]groq.ChatCompletionMessage, error) { + var respH []groq.ChatCompletionMessage + var toolCall tools.ToolCall e.logger.Debug("Running Toolhouse extension", "response", response) if response.Choices[0].FinishReason != groq.FinishReasonFunctionCall && response.Choices[0].FinishReason != "tool_calls" { return nil, fmt.Errorf("Not a function call") } - respH := []groq.ChatCompletionMessage{} - for _, tool := range response.Choices[0].Message.ToolCalls { + for _, toolCall = range response.Choices[0].Message.ToolCalls { req, err := builders.NewRequest( ctx, e.header, http.MethodPost, fmt.Sprintf("%s%s", e.baseURL, runToolEndpoint), builders.WithBody(request{ - Content: tool, + Content: toolCall, Provider: e.provider, Metadata: e.metadata, Bundle: e.bundle, @@ -60,18 +60,7 @@ func (e *Toolhouse) Run( if err != nil { return nil, err } - resp, err := e.client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("%v", resp) - } - bdy, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } + e.logger.Debug("running tool", "tool", toolCall.Function.Name, "call", toolCall.Function.Arguments) var runResp struct { Provider string `json:"provider"` Content struct { @@ -81,9 +70,9 @@ func (e *Toolhouse) Run( Content string `json:"content"` } `json:"content"` } - err = json.Unmarshal(bdy, &runResp) + err = e.sendRequest(req, &runResp) if err != nil { - return nil, fmt.Errorf("failed to unmarshal response body: %w: %s", err, string(bdy)) + return nil, err } respH = append(respH, groq.ChatCompletionMessage{ Content: runResp.Content.Content, diff --git a/extensions/toolhouse/toolhouse.go b/extensions/toolhouse/toolhouse.go index 7f509a4..d38070e 100644 --- a/extensions/toolhouse/toolhouse.go +++ b/extensions/toolhouse/toolhouse.go @@ -2,16 +2,15 @@ package toolhouse import ( + "encoding/json" "fmt" + "io" "log/slog" "net/http" - "github.com/conneroisu/groq-go" "github.com/conneroisu/groq-go/pkg/builders" ) -//go:generate gomarkdoc -o README.md -e . - const ( defaultBaseURL = "https://api.toolhouse.ai/v1" getToolsEndpoint = "/get_tools" @@ -28,7 +27,6 @@ type ( provider string metadata map[string]any bundle string - tools []groq.Tool logger *slog.Logger header builders.Header } @@ -61,3 +59,44 @@ func NewExtension(apiKey string, opts ...Options) (e *Toolhouse, err error) { } return e, nil } + +func (e *Toolhouse) sendRequest(req *http.Request, v interface{}) error { + req.Header.Set("Accept", "application/json") + contentType := req.Header.Get("Content-Type") + if contentType == "" { + req.Header.Set("Content-Type", "application/json") + } + res, err := e.client.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode < http.StatusOK || + res.StatusCode >= http.StatusBadRequest { + return fmt.Errorf("failed to send http request: %s", res.Status) + } + if v == nil { + return nil + } + switch o := v.(type) { + case *string: + b, err := io.ReadAll(res.Body) + if err != nil { + return err + } + *o = string(b) + return nil + default: + e.logger.Debug("decoding json response") + err = json.NewDecoder(res.Body).Decode(v) + if err != nil { + read, err := io.ReadAll(res.Body) + if err != nil { + return err + } + e.logger.Debug("failed to decode response", "response", string(read)) + return fmt.Errorf("failed to decode response: %s", string(read)) + } + return nil + } +} diff --git a/extensions/toolhouse/toolhouse_test.go b/extensions/toolhouse/toolhouse_test.go index e79bc4f..b4fd4b5 100644 --- a/extensions/toolhouse/toolhouse_test.go +++ b/extensions/toolhouse/toolhouse_test.go @@ -2,42 +2,19 @@ package toolhouse_test import ( "context" - "log/slog" "os" - "strings" "testing" "github.com/conneroisu/groq-go" "github.com/conneroisu/groq-go/extensions/toolhouse" + "github.com/conneroisu/groq-go/pkg/test" "github.com/stretchr/testify/assert" ) -var ( - defaultLogger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - AddSource: true, - Level: slog.LevelDebug, - ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { - if a.Key == "time" { - return slog.Attr{} - } - if a.Key == "level" { - return slog.Attr{} - } - if a.Key == "source" { - str := a.Value.String() - split := strings.Split(str, "/") - if len(split) > 2 { - a.Value = slog.StringValue(strings.Join(split[len(split)-2:], "/")) - } - } - return a - }})) -) - func TestNewExtension(t *testing.T) { a := assert.New(t) ctx := context.Background() - if os.Getenv("UNIT") == "" { + if !test.IsUnitTest() { t.Skip("Skipping Toolhouse extension test") } @@ -46,7 +23,7 @@ func TestNewExtension(t *testing.T) { "id": "conner", "timezone": 5, }), - toolhouse.WithLogger(defaultLogger), + toolhouse.WithLogger(test.DefaultLogger), ) a.NoError(err) client, err := groq.NewClient(os.Getenv("GROQ_KEY")) @@ -57,11 +34,12 @@ func TestNewExtension(t *testing.T) { Content: "Write a python function to print the first 10 prime numbers containing the number 3 then respond with the answer. DO NOT GUESS WHAT THE OUTPUT SHOULD BE. MAKE SURE TO CALL THE TOOL GIVEN.", }, } - print(history[len(history)-1].Content) + tooling, err := ext.GetTools(ctx) + a.NoError(err) re, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ Model: groq.ModelLlama3Groq70B8192ToolUsePreview, Messages: history, - Tools: ext.MustGetTools(ctx), + Tools: tooling, ToolChoice: "required", }) a.NoError(err) diff --git a/extensions/toolhouse/tools.go b/extensions/toolhouse/tools.go index e551884..cacaf3d 100644 --- a/extensions/toolhouse/tools.go +++ b/extensions/toolhouse/tools.go @@ -7,8 +7,8 @@ import ( "io" "net/http" - "github.com/conneroisu/groq-go" "github.com/conneroisu/groq-go/pkg/builders" + "github.com/conneroisu/groq-go/pkg/tools" ) // MustGetTools returns a list of tools that the extension can use. @@ -16,7 +16,7 @@ import ( // It panics if an error occurs. func (e *Toolhouse) MustGetTools( ctx context.Context, -) []groq.Tool { +) []tools.Tool { tools, err := e.GetTools(ctx) if err != nil { panic(err) @@ -27,10 +27,7 @@ func (e *Toolhouse) MustGetTools( // GetTools returns a list of tools that the extension can use. func (e *Toolhouse) GetTools( ctx context.Context, -) ([]groq.Tool, error) { - if len(e.tools) > 0 { - return e.tools, nil - } +) ([]tools.Tool, error) { e.logger.Debug("Getting tools from Toolhouse extension") url := e.baseURL + getToolsEndpoint req, err := builders.NewRequest( @@ -60,9 +57,10 @@ func (e *Toolhouse) GetTools( if err != nil { return nil, fmt.Errorf("failed to read response body: %w: %s", err, string(bdy)) } - err = json.Unmarshal(bdy, &e.tools) + var tooling []tools.Tool + err = json.Unmarshal(bdy, &tooling) if err != nil { return nil, err } - return e.tools, nil + return tooling, nil } diff --git a/go.mod b/go.mod index 566c94f..58356d9 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/conneroisu/groq-go -go 1.23.1 +go 1.23.2 require ( github.com/gorilla/websocket v1.5.3 diff --git a/go.work.sum b/go.work.sum index bc511bf..593da36 100644 --- a/go.work.sum +++ b/go.work.sum @@ -49,7 +49,6 @@ github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4er github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/go-grpc-middleware v1.3.0/go.mod h1:z0ButlSOZa5vEBq9m2m2hlwIgKw+rp3sdCBRoJY+30Y= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -59,10 +58,6 @@ github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANyt github.com/intel/goresctrl v0.3.0/go.mod h1:fdz3mD85cmP9sHD8JUlrNWAxvwM86CrbmVXltEKd7zk= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.4/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/logrusorgru/aurora/v4 v4.0.0/go.mod h1:lP0iIa2nrnT/qoFXcOZSrZQpJ1o6n2CUf/hyHi2Q4ZQ= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= @@ -87,7 +82,6 @@ github.com/opencontainers/runtime-tools v0.9.1-0.20221107090550-2e043c6bd626/go. github.com/opencontainers/selinux v1.11.0/go.mod h1:E5dMC3VPuVvVHDYmi78qvhJp8+M586T4DlDRYpFkyec= github.com/pelletier/go-toml v1.9.1/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/princjef/termdiff v0.1.0/go.mod h1:JJOfCA/eR6T1JfsoxQQ6jsG3LGoQDoKUIRQrKqAO+p4= github.com/prometheus/client_golang v1.14.0/go.mod h1:8vpkKitgIVNcqrRBWh1C4TIUQgYNtG/XQE4E/Zae36Y= github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w= @@ -95,7 +89,6 @@ github.com/prometheus/common v0.37.0/go.mod h1:phzohg0JFMnBEFGxTDbfu3QyL5GI8gTQJ github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= -github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/russross/blackfriday v1.6.0/go.mod h1:ti0ldHuxg49ri4ksnFxlkCfN+hvslNlmVHqNRXXJNAY= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= @@ -133,12 +126,14 @@ golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQz golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= golang.org/x/oauth2 v0.11.0/go.mod h1:LdF7O/8bLR/qWK9DrpXmbHLTouvRHK0SgJl0GmDBchk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= @@ -156,10 +151,9 @@ golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58 golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= +golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20230920204549-e6e6cdab5c13/go.mod h1:CCviP9RmpZ1mxVr8MUjCnSiY09IbAXZxhLE6EhHIdPU= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/ini.v1 v1.62.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/pkg/builders/requests.go b/pkg/builders/requests.go index f05ff20..71f7244 100644 --- a/pkg/builders/requests.go +++ b/pkg/builders/requests.go @@ -8,7 +8,7 @@ import ( "net/http" ) -var builder = NewRequestBuilder() +var builder RequestBuilder = &defaultRequestBuilder{} type ( // Header is an struct interface for setting common headers. diff --git a/pkg/test/encoding.go b/pkg/test/encoding.go new file mode 100644 index 0000000..a0bf80f --- /dev/null +++ b/pkg/test/encoding.go @@ -0,0 +1,27 @@ +package test + +import ( + "encoding/json" + "io" +) + +func encode(v any) []byte { + res, err := json.Marshal(v) + if err != nil { + panic(err) + } + return res +} +func decode[T any](r io.Reader) T { + // make a new instance of the type + v := new(T) + // read the JSON data from the request body + bod, err := io.ReadAll(r) + if err != nil { + panic(err) + } + if err := json.Unmarshal(bod, &v); err != nil { + panic(err) + } + return *v +} diff --git a/pkg/test/helpers.go b/pkg/test/helpers.go index c0428f5..290c07f 100644 --- a/pkg/test/helpers.go +++ b/pkg/test/helpers.go @@ -1,8 +1,11 @@ package test import ( + "fmt" + "log/slog" "net/http" "os" + "strings" "testing" ) @@ -60,3 +63,34 @@ func (t *TokenRoundTripper) RoundTrip( func IsUnitTest() bool { return os.Getenv("UNIT") != "" } + +// GetAPIKey returns the api key. +func GetAPIKey(key string) (string, error) { + apiKey := os.Getenv(key) + if apiKey == "" { + return "", fmt.Errorf("api key: %s is required", key) + } + return apiKey, nil +} + +// DefaultLogger is a default logger. +var DefaultLogger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if a.Key == "time" { + return slog.Attr{} + } + if a.Key == "level" { + return slog.Attr{} + } + if a.Key == slog.SourceKey { + str := a.Value.String() + split := strings.Split(str, "/") + if len(split) > 2 { + a.Value = slog.StringValue(strings.Join(split[len(split)-2:], "/")) + a.Value = slog.StringValue(strings.Replace(a.Value.String(), "}", "", -1)) + } + } + return a + }})) diff --git a/pkg/test/mod-composio.go b/pkg/test/mod-composio.go new file mode 100644 index 0000000..4a8fac5 --- /dev/null +++ b/pkg/test/mod-composio.go @@ -0,0 +1,43 @@ +package test + +import ( + "log" + "net/http" + "net/http/httptest" + "regexp" +) + +// ComposioTestServer Creates a mocked Composer server which can pretend to handle requests during testing. +func (ts *ServerTest) ComposioTestServer() *httptest.Server { + return httptest.NewUnstartedServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf( + "received a %s request at path %q\n", + r.Method, + r.URL.Path, + ) + + // check auth + if r.Header.Get("X-API-Key") != GetTestToken() { + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Handle /path/* routes. + // Note: the * is converted to a .* in register handler for proper regex handling + for route, handler := range ts.handlers { + // Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered + pattern, _ := regexp.Compile("^" + route + "$") + if pattern.MatchString(r.URL.Path) { + handler(w, r) + return + } + } + http.Error( + w, + "the resource path doesn't exist", + http.StatusNotFound, + ) + }), + ) +} diff --git a/pkg/test/mod-e2b.go b/pkg/test/mod-e2b.go new file mode 100644 index 0000000..048abc2 --- /dev/null +++ b/pkg/test/mod-e2b.go @@ -0,0 +1,44 @@ +package test + +import ( + "log" + "net/http" + "net/http/httptest" + "regexp" +) + +// E2bTestServer creates a test server for emulating the e2b api. +func (ts *ServerTest) E2bTestServer() *httptest.Server { + return httptest.NewUnstartedServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf( + "received a %s request at path %q\n", + r.Method, + r.URL.Path, + ) + + // check auth + if r.Header.Get("X-API-Key") != GetTestToken() && + r.Header.Get("api-key") != GetTestToken() { + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Handle /path/* routes. + // Note: the * is converted to a .* in register handler for proper regex handling + for route, handler := range ts.handlers { + // Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered + pattern, _ := regexp.Compile("^" + route + "$") + if pattern.MatchString(r.URL.Path) { + handler(w, r) + return + } + } + http.Error( + w, + "the resource path doesn't exist", + http.StatusNotFound, + ) + }), + ) +} diff --git a/pkg/test/mod-groq.go b/pkg/test/mod-groq.go new file mode 100644 index 0000000..c00a7bb --- /dev/null +++ b/pkg/test/mod-groq.go @@ -0,0 +1,44 @@ +package test + +import ( + "log" + "net/http" + "net/http/httptest" + "regexp" +) + +// GroqTestServer Creates a mocked Groq server which can pretend to handle requests during testing. +func (ts *ServerTest) GroqTestServer() *httptest.Server { + return httptest.NewUnstartedServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf( + "received a %s request at path %q\n", + r.Method, + r.URL.Path, + ) + + // check auth + if r.Header.Get("Authorization") != "Bearer "+GetTestToken() && + r.Header.Get("api-key") != GetTestToken() { + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Handle /path/* routes. + // Note: the * is converted to a .* in register handler for proper regex handling + for route, handler := range ts.handlers { + // Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered + pattern, _ := regexp.Compile("^" + route + "$") + if pattern.MatchString(r.URL.Path) { + handler(w, r) + return + } + } + http.Error( + w, + "the resource path doesn't exist", + http.StatusNotFound, + ) + }), + ) +} diff --git a/pkg/test/mod-toolhouse.go b/pkg/test/mod-toolhouse.go new file mode 100644 index 0000000..046c6bf --- /dev/null +++ b/pkg/test/mod-toolhouse.go @@ -0,0 +1,44 @@ +package test + +import ( + "log" + "net/http" + "net/http/httptest" + "regexp" +) + +// ToolhouseTestServer creates a test server for emulating the toolhouse api. +func (ts *ServerTest) ToolhouseTestServer() *httptest.Server { + return httptest.NewUnstartedServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf( + "received a %s request at path %q\n", + r.Method, + r.URL.Path, + ) + + // check auth + if r.Header.Get("Authorization") != "Bearer "+GetTestToken() && + r.Header.Get("api-key") != GetTestToken() { + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Handle /path/* routes. + // Note: the * is converted to a .* in register handler for proper regex handling + for route, handler := range ts.handlers { + // Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered + pattern, _ := regexp.Compile("^" + route + "$") + if pattern.MatchString(r.URL.Path) { + handler(w, r) + return + } + } + http.Error( + w, + "the resource path doesn't exist", + http.StatusNotFound, + ) + }), + ) +} diff --git a/pkg/test/server.go b/pkg/test/server.go index c3a7843..8359422 100644 --- a/pkg/test/server.go +++ b/pkg/test/server.go @@ -1,10 +1,8 @@ package test import ( - "log" + "log/slog" "net/http" - "net/http/httptest" - "regexp" "strings" ) @@ -19,57 +17,25 @@ func GetTestToken() string { // ServerTest is a test server for the groq api. type ServerTest struct { - handlers map[string]handler + handlers map[string]Handler + logger *slog.Logger } -// handler is a function that handles a request. -type handler func(w http.ResponseWriter, r *http.Request) +// Handler is a function that handles a request. +type Handler func(w http.ResponseWriter, r *http.Request) // NewTestServer creates a new test server. func NewTestServer() *ServerTest { - return &ServerTest{handlers: make(map[string]handler)} + return &ServerTest{ + handlers: make(map[string]Handler), + logger: DefaultLogger, + } } // RegisterHandler registers a handler for a path. -func (ts *ServerTest) RegisterHandler(path string, handler handler) { +func (ts *ServerTest) RegisterHandler(path string, handler Handler) { // to make the registered paths friendlier to a regex match in the route handler // in GroqTestServer path = strings.ReplaceAll(path, "*", ".*") ts.handlers[path] = handler } - -// GroqTestServer Creates a mocked Groq server which can pretend to handle requests during testing. -func (ts *ServerTest) GroqTestServer() *httptest.Server { - return httptest.NewUnstartedServer( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Printf( - "received a %s request at path %q\n", - r.Method, - r.URL.Path, - ) - - // check auth - if r.Header.Get("Authorization") != "Bearer "+GetTestToken() && - r.Header.Get("api-key") != GetTestToken() { - w.WriteHeader(http.StatusUnauthorized) - return - } - - // Handle /path/* routes. - // Note: the * is converted to a .* in register handler for proper regex handling - for route, handler := range ts.handlers { - // Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered - pattern, _ := regexp.Compile("^" + route + "$") - if pattern.MatchString(r.URL.Path) { - handler(w, r) - return - } - } - http.Error( - w, - "the resource path doesn't exist", - http.StatusNotFound, - ) - }), - ) -} diff --git a/pkg/tools/doc.go b/pkg/tools/doc.go new file mode 100644 index 0000000..6ca658e --- /dev/null +++ b/pkg/tools/doc.go @@ -0,0 +1,2 @@ +// Package tools contains the interfaces for groq-go tooling usable by llms. +package tools diff --git a/pkg/tools/tools.go b/pkg/tools/tools.go new file mode 100644 index 0000000..127a1a7 --- /dev/null +++ b/pkg/tools/tools.go @@ -0,0 +1,57 @@ +package tools + +const ( + ToolTypeFunction ToolType = "function" // ToolTypeFunction is the function tool type. +) + +type ( + // Tool represents the tool. + Tool struct { + Type ToolType `json:"type"` // Type is the type of the tool. + Function FunctionDefinition `json:"function,omitempty"` // Function is the tool's functional definition. + } + // ToolType is the tool type. + // + // string + ToolType string + // ToolChoice represents the tool choice. + ToolChoice struct { + Type ToolType `json:"type"` // Type is the type of the tool choice. + Function ToolFunction `json:"function,omitempty"` // Function is the function of the tool choice. + } + // ToolFunction represents the tool function. + ToolFunction struct { + Name string `json:"name"` // Name is the name of the tool function. + } + // FunctionDefinition represents the function definition. + FunctionDefinition struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters FunctionParameters `json:"parameters"` + } + // FunctionParameters represents the function parameters of a tool. + FunctionParameters struct { + Type string `json:"type"` + Properties map[string]PropertyDefinition `json:"properties"` + Required []string `json:"required"` + AdditionalProperties bool `json:"additionalProperties,omitempty"` + } + // PropertyDefinition represents the property definition. + PropertyDefinition struct { + Type string `json:"type"` + Description string `json:"description"` + } + // ToolCall represents a tool call. + ToolCall struct { + // Index is not nil only in chat completion chunk object + Index *int `json:"index,omitempty"` // Index is the index of the tool call. + ID string `json:"id"` // ID is the id of the tool call. + Type string `json:"type"` // Type is the type of the tool call. + Function FunctionCall `json:"function"` // Function is the function of the tool call. + } + // FunctionCall represents a function call. + FunctionCall struct { + Name string `json:"name,omitempty"` // Name is the name of the function call. + Arguments string `json:"arguments,omitempty"` // Arguments is the arguments of the function call in JSON format. + } +) diff --git a/scripts/generate-e2b-kernels/kernels.go.tmpl b/scripts/generate-e2b-kernels/kernels.go.tmpl deleted file mode 100644 index e69de29..0000000 diff --git a/scripts/generate-e2b-kernels/kernels_test.go.tmpl b/scripts/generate-e2b-kernels/kernels_test.go.tmpl deleted file mode 100644 index e69de29..0000000 diff --git a/scripts/generate-e2b-kernels/main.go b/scripts/generate-e2b-kernels/main.go deleted file mode 100644 index 06ab7d0..0000000 --- a/scripts/generate-e2b-kernels/main.go +++ /dev/null @@ -1 +0,0 @@ -package main diff --git a/scripts/generate-models/go.mod b/scripts/generate-models/go.mod index 7911f0b..004f126 100644 --- a/scripts/generate-models/go.mod +++ b/scripts/generate-models/go.mod @@ -1,12 +1,7 @@ module github.com/conneroisu/groq-go/cmd/models -go 1.23.1 +go 1.23.2 require github.com/samber/lo v1.47.0 -require ( - golang.org/x/mod v0.21.0 // indirect - golang.org/x/sync v0.8.0 // indirect - golang.org/x/text v0.18.0 // indirect - golang.org/x/tools v0.25.0 // indirect -) +require golang.org/x/text v0.18.0 // indirect diff --git a/scripts/generate-models/go.sum b/scripts/generate-models/go.sum index 6123181..95a07ff 100644 --- a/scripts/generate-models/go.sum +++ b/scripts/generate-models/go.sum @@ -1,10 +1,4 @@ github.com/samber/lo v1.47.0 h1:z7RynLwP5nbyRscyvcD043DWYoOcYRv3mV8lBeqOCLc= github.com/samber/lo v1.47.0/go.mod h1:RmDH9Ct32Qy3gduHQuKJ3gW1fMHAnE/fAzQuf6He5cU= -golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= -golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= -golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE= -golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg=