Merge pull request #36 from conneroisu/devie

devie
conneroisu authored Sep 11, 2024
2 parents 5099e94 + df6c89f commit 45916bc
Showing 64 changed files with 6,632 additions and 186 deletions.
488 changes: 440 additions & 48 deletions README.md

Large diffs are not rendered by default.

40 changes: 20 additions & 20 deletions audio.go
@@ -9,6 +9,26 @@ import (
"os"
)

// CreateTranscription calls the transcriptions endpoint with the given request.
//
// Returns transcribed text in the response_format specified in the request.
func (c *Client) CreateTranscription(
ctx context.Context,
request AudioRequest,
) (response AudioResponse, err error) {
return c.callAudioAPI(ctx, request, transcriptionsSuffix)
}

// CreateTranslation calls the translations endpoint with the given request.
//
// Returns the translated text in the response_format specified in the request.
func (c *Client) CreateTranslation(
ctx context.Context,
request AudioRequest,
) (response AudioResponse, err error) {
return c.callAudioAPI(ctx, request, translationsSuffix)
}

// AudioRequest represents a request structure for audio API.
type AudioRequest struct {
Model Model // Model is the model to use for the transcription.
@@ -80,26 +100,6 @@ func (r *audioTextResponse) toAudioResponse() AudioResponse {
}
}

// CreateTranscription calls the transcriptions endpoint with the given request.
//
// Returns transcribed text in the response_format specified in the request.
func (c *Client) CreateTranscription(
ctx context.Context,
request AudioRequest,
) (response AudioResponse, err error) {
return c.callAudioAPI(ctx, request, transcriptionsSuffix)
}

// CreateTranslation calls the translations endpoint with the given request.
//
// Returns the translated text in the response_format specified in the request.
func (c *Client) CreateTranslation(
ctx context.Context,
request AudioRequest,
) (response AudioResponse, err error) {
return c.callAudioAPI(ctx, request, translationsSuffix)
}

// callAudioAPI calls the audio API with the given request.
//
// Currently supports both the transcription and translation APIs.
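With the relocated methods above, transcription usage looks roughly like this — a minimal sketch, assuming `AudioRequest` exposes a `FilePath` field and that a Whisper model constant such as `groq.WhisperLargeV3` exists (neither is visible in this diff):

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/conneroisu/groq-go"
)

func main() {
	client, err := groq.NewClient(os.Getenv("GROQ_KEY"))
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	// FilePath and the model constant are assumed names, not confirmed
	// by this diff.
	resp, err := client.CreateTranscription(context.Background(), groq.AudioRequest{
		Model:    groq.WhisperLargeV3,
		FilePath: "speech.mp3",
	})
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	fmt.Println(resp.Text) // Text field on AudioResponse is also assumed.
}
```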
2 changes: 1 addition & 1 deletion audio_unit_test.go
@@ -12,7 +12,7 @@ import (
"strings"
"testing"

groq "github.com/conneroisu/groq-go"
"github.com/conneroisu/groq-go"
"github.com/conneroisu/groq-go/internal/test"
"github.com/stretchr/testify/assert"
)
152 changes: 107 additions & 45 deletions chat.go
@@ -3,7 +3,10 @@ package groq
import (
"context"
"encoding/json"
"fmt"
"net/http"
"reflect"
"strings"
)

const (
@@ -85,7 +88,7 @@ type ChatCompletionMessage struct {
ToolCallID string `json:"tool_call_id,omitempty"`
}

// MarshalJSON implements the json.Marshaler interface.
// MarshalJSON method implements the json.Marshaler interface.
func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
if m.Content != "" && m.MultiContent != nil {
return nil, &ErrContentFieldsMisused{field: "Content"}
@@ -114,7 +117,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
return json.Marshal(msg)
}

// UnmarshalJSON implements the json.Unmarshaler interface.
// UnmarshalJSON method implements the json.Unmarshaler interface.
func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) (err error) {
msg := struct {
Role Role `json:"role"`
@@ -169,53 +172,53 @@ type ChatCompletionResponseFormatType string

// ChatCompletionResponseFormat is the chat completion response format.
type ChatCompletionResponseFormat struct {
Type ChatCompletionResponseFormatType `json:"type,omitempty"` // Type is the type of the chat completion response format.
JSONSchema *ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` // JSONSchema is the json schema of the chat completion response format.
// Type is the type of the chat completion response format.
Type ChatCompletionResponseFormatType `json:"type,omitempty"`
// JSONSchema is the json schema of the chat completion response format.
JSONSchema *ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"`
}

// ChatCompletionResponseFormatJSONSchema is the chat completion response format
// json schema.
type ChatCompletionResponseFormatJSONSchema struct {
Name string `json:"name"` // Name is the name of the chat completion response format json schema.
Description string `json:"description,omitempty"` // Description is the description of the chat completion response format json schema.
Schema json.Marshaler `json:"schema"` // Schema is the schema of the chat completion response format json schema.
Strict bool `json:"strict"` // Strict is the strict of the chat completion response format json schema.
// Name is the name of the chat completion response format json schema.
//
// it is used to further identify the schema in the response.
Name string `json:"name"`
// Description is the description of the chat completion response format
// json schema.
Description string `json:"description,omitempty"`
// Schema is the schema of the chat completion response format json schema.
Schema schema `json:"schema"`
// Strict determines whether to enforce the schema upon the generated
// content.
Strict bool `json:"strict"`
}

// ChatCompletionRequest represents a request structure for the chat completion API.
type ChatCompletionRequest struct {
Model Model `json:"model"` // Model is the model of the chat completion request.
Messages []ChatCompletionMessage `json:"messages"` // Messages is the messages of the chat completion request.
MaxTokens int `json:"max_tokens,omitempty"` // MaxTokens is the max tokens of the chat completion request.
Temperature float32 `json:"temperature,omitempty"` // Temperature is the temperature of the chat completion request.
TopP float32 `json:"top_p,omitempty"` // TopP is the top p of the chat completion request.
N int `json:"n,omitempty"` // N is the n of the chat completion request.
Stream bool `json:"stream,omitempty"` // Stream is the stream of the chat completion request.
Stop []string `json:"stop,omitempty"` // Stop is the stop of the chat completion request.
PresencePenalty float32 `json:"presence_penalty,omitempty"` // PresencePenalty is the presence penalty of the chat completion request.
ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` // ResponseFormat is the response format of the chat completion request.
Seed *int `json:"seed,omitempty"` // Seed is the seed of the chat completion request.
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // FrequencyPenalty is the frequency penalty of the chat completion request.
// LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string.
// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}`
// refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias
LogitBias map[string]int `json:"logit_bias,omitempty"`
// LogProbs indicates whether to return log probabilities of the output tokens or not.
// If true, returns the log probabilities of each output token returned in the content of message.
// This option is currently not available on the gpt-4-vision-preview model.
LogProbs bool `json:"logprobs,omitempty"`
// TopLogProbs is an integer between 0 and 5 specifying the number of most likely tokens to return at each
// token position, each with an associated log probability.
// logprobs must be set to true if this parameter is used.
TopLogProbs int `json:"top_logprobs,omitempty"`
User string `json:"user,omitempty"`
Tools []Tool `json:"tools,omitempty"`
// This can be either a string or an ToolChoice object.
ToolChoice any `json:"tool_choice,omitempty"`
// Options for streaming response. Only set this when you set stream: true.
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
// Disable the default behavior of parallel tool calls by setting it: false.
ParallelToolCalls any `json:"parallel_tool_calls,omitempty"`
Model Model `json:"model"` // Model is the model of the chat completion request.
Messages []ChatCompletionMessage `json:"messages"` // Messages is the messages of the chat completion request. These act as the prompt for the model.
MaxTokens int `json:"max_tokens,omitempty"` // MaxTokens is the max tokens of the chat completion request.
Temperature float32 `json:"temperature,omitempty"` // Temperature is the temperature of the chat completion request.
TopP float32 `json:"top_p,omitempty"` // TopP is the top p of the chat completion request.
N int `json:"n,omitempty"` // N is the n of the chat completion request.
Stream bool `json:"stream,omitempty"` // Stream is the stream of the chat completion request.
Stop []string `json:"stop,omitempty"` // Stop is the stop of the chat completion request.
PresencePenalty float32 `json:"presence_penalty,omitempty"` // PresencePenalty is the presence penalty of the chat completion request.
ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` // ResponseFormat is the response format of the chat completion request.
Seed *int `json:"seed,omitempty"` // Seed is the seed of the chat completion request.
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // FrequencyPenalty is the frequency penalty of the chat completion request.
LogitBias map[string]int `json:"logit_bias,omitempty"` // LogitBias must map token IDs (as strings from the tokenizer), not words, to bias values. incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias
LogProbs bool `json:"logprobs,omitempty"` // LogProbs indicates whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message. This option is currently not available on the gpt-4-vision-preview model.
TopLogProbs int `json:"top_logprobs,omitempty"` // TopLogProbs is an integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used.
User string `json:"user,omitempty"` // User is the user of the chat completion request.
Tools []Tool `json:"tools,omitempty"` // Tools is the tools of the chat completion request.
ToolChoice any `json:"tool_choice,omitempty"` // This can be either a string or a ToolChoice object.
StreamOptions *StreamOptions `json:"stream_options,omitempty"` // Options for streaming response. Only set this when you set stream: true.
ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` // Disable the default behavior of parallel tool calls by setting it to false.
}

// StreamOptions represents the stream options.
@@ -298,7 +301,8 @@ func (r FinishReason) MarshalJSON() ([]byte, error) {

// ChatCompletionChoice represents the chat completion choice.
type ChatCompletionChoice struct {
Index int `json:"index"` // Index is the index of the choice.
// Message is the chat completion message of the choice.
Message ChatCompletionMessage `json:"message"`
// FinishReason is the finish reason of the choice.
//
@@ -308,8 +312,11 @@ type ChatCompletionChoice struct {
// function_call: The model decided to call a function
// content_filter: Omitted content due to a flag from our content filters
// null: API response still in progress or incomplete
FinishReason FinishReason `json:"finish_reason"` // FinishReason is the finish reason of the choice.
// LogProbs is the log probs of the choice.
//
// This is basically the probability of the model choosing the token.
LogProbs *LogProbs `json:"logprobs,omitempty"`
}

// ChatCompletionResponse represents a response structure for chat completion API.
@@ -330,7 +337,7 @@ func (r *ChatCompletionResponse) SetHeader(h http.Header) {
r.Header = h
}

// CreateChatCompletion is an API call to create a chat completion.
// CreateChatCompletion method is an API call to create a chat completion.
func (c *Client) CreateChatCompletion(
ctx context.Context,
request ChatCompletionRequest,
@@ -403,7 +410,7 @@ type ChatCompletionStream struct {
*streamReader[ChatCompletionStreamResponse]
}

// CreateChatCompletionStream is an API call to create a chat completion w/ streaming
// CreateChatCompletionStream method is an API call to create a chat completion w/ streaming
// support.
//
// If set, tokens will be sent as data-only server-sent events as they become
@@ -437,3 +444,58 @@ func (c *Client) CreateChatCompletionStream(
}
return
}

// CreateChatCompletionJSON method is an API call to create a chat completion w/ object output.
func (c *Client) CreateChatCompletionJSON(
ctx context.Context,
request ChatCompletionRequest,
output any,
) (err error) {
r := &reflector{}
schema := r.ReflectFromType(reflect.TypeOf(output))
request.ResponseFormat = &ChatCompletionResponseFormat{}
request.ResponseFormat.JSONSchema = &ChatCompletionResponseFormatJSONSchema{
Name: schema.Title,
Description: schema.Description,
Schema: *schema,
Strict: true,
}
if !endpointSupportsModel(chatCompletionsSuffix, request.Model) {
err = ErrChatCompletionInvalidModel{
Model: request.Model,
Endpoint: chatCompletionsSuffix,
}
return
}
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(chatCompletionsSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil {
return
}
var response ChatCompletionResponse
err = c.sendRequest(req, &response)
if err != nil {
return
}
content := response.Choices[0].Message.Content
split := strings.Split(content, "```")
if len(split) > 1 {
content = split[1]
}
err = json.Unmarshal(
[]byte(content),
&output,
)
if err != nil {
return fmt.Errorf(
"error unmarshalling response (%s) to output: %v",
response.Choices[0].Message.Content,
err,
)
}
return
}
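
The CreateChatCompletionStream method above returns a ChatCompletionStream wrapping streamReader[ChatCompletionStreamResponse]. A minimal consumption sketch, assuming the wrapper exposes Recv and Close and that stream chunks carry a Delta field (none of which is shown in this diff):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"os"

	"github.com/conneroisu/groq-go"
)

func main() {
	client, err := groq.NewClient(os.Getenv("GROQ_KEY"))
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	stream, err := client.CreateChatCompletionStream(context.Background(), groq.ChatCompletionRequest{
		Model:  groq.Llama3Groq70B8192ToolUsePreview,
		Stream: true,
		Messages: []groq.ChatCompletionMessage{
			{Role: groq.ChatMessageRoleUser, Content: "Write a haiku about Go."},
		},
	})
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	defer stream.Close() // assumed method on the stream wrapper
	for {
		chunk, err := stream.Recv() // assumed method on the stream wrapper
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			fmt.Println(err)
			os.Exit(1)
		}
		fmt.Print(chunk.Choices[0].Delta.Content) // Delta field assumed
	}
}
```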
12 changes: 12 additions & 0 deletions examples/json-chat/README.md
@@ -0,0 +1,12 @@
# json-chat

This is an example of using groq-go to get a chat completion as a JSON object using the CreateChatCompletionJSON method.

## Usage

Make sure you have a groq key set in the environment variable `GROQ_KEY`.

```bash
export GROQ_KEY=your-groq-key
go run .
```
57 changes: 57 additions & 0 deletions examples/json-chat/main.go
@@ -0,0 +1,57 @@
// Package main demonstrates an example application of groq-go.
// It shows how to use groq-go to get a chat completion as a JSON object
// using the llama3-groq-70b-8192-tool-use-preview model.
package main

import (
"context"
"encoding/json"
"fmt"
"os"

"github.com/conneroisu/groq-go"
)

func main() {
if err := run(context.Background()); err != nil {
fmt.Println(err)
os.Exit(1)
}
}

// Responses is the expected JSON shape of the chat completion output.
type Responses []struct {
Title string `json:"title" jsonschema:"title=Poem Title,description=Title of the poem, minLength=1, maxLength=20"`
Text string `json:"text" jsonschema:"title=Poem Text,description=Text of the poem, minLength=10, maxLength=200"`
}

func run(
ctx context.Context,
) error {
client, err := groq.NewClient(os.Getenv("GROQ_KEY"))
if err != nil {
return err
}
resp := &Responses{}
err = client.CreateChatCompletionJSON(ctx, groq.ChatCompletionRequest{
Model: groq.Llama3Groq70B8192ToolUsePreview,
Messages: []groq.ChatCompletionMessage{
{
Role: groq.ChatMessageRoleUser,
Content: "Create 5 short poems in json format with title and text.",
},
},
MaxTokens: 2000,
}, resp)
if err != nil {
return err
}

jsValue, err := json.MarshalIndent(resp, "", " ")
if err != nil {
return err
}
fmt.Println(string(jsValue))

return nil
}
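
The jsonschema struct tags on Responses are what CreateChatCompletionJSON's reflector turns into the json_schema constraint sent with the request, so they are the main lever for shaping the output. A sketch of another output type using the same tag keys (the type itself is hypothetical):

```go
// Recipe is a hypothetical schema-constrained output type; the
// jsonschema tag keys mirror the ones used by Responses above.
type Recipe struct {
	Name  string   `json:"name" jsonschema:"title=Recipe Name,description=Name of the dish,minLength=1,maxLength=40"`
	Steps []string `json:"steps" jsonschema:"title=Steps,description=Ordered preparation steps"`
}
```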
14 changes: 14 additions & 0 deletions examples/llava-blind/README.md
@@ -0,0 +1,14 @@
# llava-blind

This is an example of using groq-go to create a chat completion using the llava-v1.5-7b-4096-preview model.

## Usage

Make sure you have a groq key set in the environment variable `GROQ_KEY`.

Also make sure that you are in the same directory as the `main.go` file.

```bash
export GROQ_KEY=your-groq-key
go run .
```
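
The example's main.go is not shown in this diff. Based on the MultiContent field referenced in chat.go above, an image-grounded request might look roughly like the sketch below; the part and image-URL type names and the model constant are all assumptions, not taken from this diff:

```go
// visionDemo is a hypothetical sketch of a llava chat completion.
func visionDemo(ctx context.Context, client *groq.Client) (string, error) {
	resp, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{
		Model: groq.LlavaV157B4096Preview, // assumed constant name
		Messages: []groq.ChatCompletionMessage{{
			Role: groq.ChatMessageRoleUser,
			MultiContent: []groq.ChatMessagePart{ // part types assumed
				{Type: groq.ChatMessagePartTypeText, Text: "Describe this image."},
				{
					Type:     groq.ChatMessagePartTypeImageURL,
					ImageURL: &groq.ChatMessageImageURL{URL: "https://example.com/cat.png"},
				},
			},
		}},
	})
	if err != nil {
		return "", err
	}
	return resp.Choices[0].Message.Content, nil
}
```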
10 changes: 10 additions & 0 deletions examples/terminal-chat/README.md
@@ -0,0 +1,10 @@
# terminal-chat

This is a simple terminal chat application using the groq-go library.

## Usage

```bash
export GROQ_KEY=your-groq-key
go run .
```
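
The example's source is not shown here; a minimal sketch of the loop such a terminal chat might run, reusing the CreateChatCompletion call and the Choices access pattern from chat.go (the assistant role constant is assumed to mirror ChatMessageRoleUser):

```go
package main

import (
	"bufio"
	"context"
	"fmt"
	"os"

	"github.com/conneroisu/groq-go"
)

func main() {
	client, err := groq.NewClient(os.Getenv("GROQ_KEY"))
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	ctx := context.Background()
	var history []groq.ChatCompletionMessage
	scanner := bufio.NewScanner(os.Stdin)
	fmt.Print("> ")
	for scanner.Scan() {
		history = append(history, groq.ChatCompletionMessage{
			Role:    groq.ChatMessageRoleUser,
			Content: scanner.Text(),
		})
		resp, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{
			Model:    groq.Llama3Groq70B8192ToolUsePreview,
			Messages: history,
		})
		if err != nil {
			fmt.Println(err)
			os.Exit(1)
		}
		reply := resp.Choices[0].Message.Content
		fmt.Println(reply)
		// Keep the assistant turn in the history so the chat has context.
		history = append(history, groq.ChatCompletionMessage{
			Role:    groq.ChatMessageRoleAssistant, // assumed constant
			Content: reply,
		})
		fmt.Print("> ")
	}
}
```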
