diff --git a/.go-version b/.go-version index 49e0a31..14bee92 100644 --- a/.go-version +++ b/.go-version @@ -1 +1 @@ -1.23.1 +1.23.2 diff --git a/LICENSE b/LICENSE index 6ce3351..ef10cc4 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2024 connero +Copyright (c) 2024 groq-go authors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 4bb9cc2..61842ce 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,8 @@ - JSON Schema Generation from structs. - Supports [Toolhouse](https://app.toolhouse.ai/) function calling. [Extention](https://github.com/conneroisu/groq-go/tree/main/extensions/toolhouse) - Supports [E2b](https://e2b.dev/) function calling. [Extention](https://github.com/conneroisu/groq-go/tree/main/extensions/e2b) +- Supports [Composio](https://composio.dev/) function calling. [Extention](https://github.com/conneroisu/groq-go/tree/main/extensions/composio) +- Supports [Jigsaw Stack](https://jigsawstack.com/) function calling. [Extention](https://github.com/conneroisu/groq-go/tree/main/extensions/jigsawstack) ## Installation @@ -91,26 +93,16 @@ With specially designed hardware, the Groq API is a super fast way to query open API Documentation: https://console.groq.com/docs/quickstart -Code generated by groq\-modeler DO NOT EDIT. - -Created at: 2024\-10\-26 10:01:35 - -groq\-modeler Version 1.1.2 - ## Index - [Constants](<#constants>) -- [Variables](<#variables>) - [func AudioMultipartForm\(request AudioRequest, b builders.FormBuilder\) error](<#AudioMultipartForm>) -- [func CreateFileField\(request AudioRequest, b builders.FormBuilder\) \(err error\)](<#CreateFileField>) -- [type APIError](<#APIError>) - - [func \(e \*APIError\) Error\(\) string](<#APIError.Error>) - - [func \(e \*APIError\) UnmarshalJSON\(data \[\]byte\) \(err error\)](<#APIError.UnmarshalJSON>) -- [type AudioModel](<#AudioModel>) +- [type Agent](<#Agent>) + - [func NewAgent\(client \*Client, logger \*slog.Logger\) \*Agent](<#NewAgent>) +- [type Agenter](<#Agenter>) - [type AudioRequest](<#AudioRequest>) - [type AudioResponse](<#AudioResponse>) - [func \(r \*AudioResponse\) SetHeader\(header http.Header\)](<#AudioResponse.SetHeader>) -- [type AudioResponseFormat](<#AudioResponseFormat>) - [type ChatCompletionChoice](<#ChatCompletionChoice>) - [type ChatCompletionMessage](<#ChatCompletionMessage>) - [func \(m ChatCompletionMessage\) MarshalJSON\(\) \(\[\]byte, error\)](<#ChatCompletionMessage.MarshalJSON>) @@ -120,7 +112,6 @@ groq\-modeler Version 1.1.2 - [func \(r \*ChatCompletionResponse\) SetHeader\(h http.Header\)](<#ChatCompletionResponse.SetHeader>) - [type ChatCompletionResponseFormat](<#ChatCompletionResponseFormat>) - [type ChatCompletionResponseFormatJSONSchema](<#ChatCompletionResponseFormatJSONSchema>) -- [type ChatCompletionResponseFormatType](<#ChatCompletionResponseFormatType>) - [type ChatCompletionStream](<#ChatCompletionStream>) - [type ChatCompletionStreamChoice](<#ChatCompletionStreamChoice>) - [type ChatCompletionStreamChoiceDelta](<#ChatCompletionStreamChoiceDelta>) @@ -128,7 +119,6 @@ groq\-modeler Version 1.1.2 - [type ChatMessageImageURL](<#ChatMessageImageURL>) - [type ChatMessagePart](<#ChatMessagePart>) - [type ChatMessagePartType](<#ChatMessagePartType>) -- [type ChatModel](<#ChatModel>) - [type Client](<#Client>) - [func NewClient\(groqAPIKey string, opts ...Opts\) \(\*Client, error\)](<#NewClient>) - [func \(c \*Client\) CreateChatCompletion\(ctx context.Context, request ChatCompletionRequest\) \(response ChatCompletionResponse, err error\)](<#Client.CreateChatCompletion>) @@ -136,25 +126,14 @@ groq\-modeler Version 1.1.2 - [func \(c \*Client\) CreateChatCompletionStream\(ctx context.Context, request ChatCompletionRequest\) \(stream \*ChatCompletionStream, err error\)](<#Client.CreateChatCompletionStream>) - [func \(c \*Client\) CreateTranscription\(ctx context.Context, request AudioRequest\) \(AudioResponse, error\)](<#Client.CreateTranscription>) - [func \(c \*Client\) CreateTranslation\(ctx context.Context, request AudioRequest\) \(AudioResponse, error\)](<#Client.CreateTranslation>) - - [func \(c \*Client\) Moderate\(ctx context.Context, request ModerationRequest\) \(response Moderation, err error\)](<#Client.Moderate>) - - [func \(c \*Client\) MustCreateChatCompletion\(ctx context.Context, request ChatCompletionRequest\) \(response ChatCompletionResponse\)](<#Client.MustCreateChatCompletion>) -- [type DefaultErrorAccumulator](<#DefaultErrorAccumulator>) - - [func \(e \*DefaultErrorAccumulator\) Bytes\(\) \(errBytes \[\]byte\)](<#DefaultErrorAccumulator.Bytes>) - - [func \(e \*DefaultErrorAccumulator\) Write\(p \[\]byte\) error](<#DefaultErrorAccumulator.Write>) + - [func \(c \*Client\) Moderate\(ctx context.Context, messages \[\]ChatCompletionMessage, model models.ModerationModel\) \(response Moderation, err error\)](<#Client.Moderate>) - [type Endpoint](<#Endpoint>) -- [type ErrContentFieldsMisused](<#ErrContentFieldsMisused>) - - [func \(e ErrContentFieldsMisused\) Error\(\) string](<#ErrContentFieldsMisused.Error>) -- [type ErrTooManyEmptyStreamMessages](<#ErrTooManyEmptyStreamMessages>) - - [func \(e ErrTooManyEmptyStreamMessages\) Error\(\) string](<#ErrTooManyEmptyStreamMessages.Error>) - [type FinishReason](<#FinishReason>) - [func \(r FinishReason\) MarshalJSON\(\) \(\[\]byte, error\)](<#FinishReason.MarshalJSON>) - [type Format](<#Format>) -- [type HarmfulCategory](<#HarmfulCategory>) - [type ImageURLDetail](<#ImageURLDetail>) - [type LogProbs](<#LogProbs>) - [type Moderation](<#Moderation>) -- [type ModerationModel](<#ModerationModel>) -- [type ModerationRequest](<#ModerationRequest>) - [type Opts](<#Opts>) - [func WithBaseURL\(baseURL string\) Opts](<#WithBaseURL>) - [func WithClient\(client \*http.Client\) Opts](<#WithClient>) @@ -167,6 +146,9 @@ groq\-modeler Version 1.1.2 - [type Role](<#Role>) - [type Segments](<#Segments>) - [type StreamOptions](<#StreamOptions>) +- [type ToolGetter](<#ToolGetter>) +- [type ToolManager](<#ToolManager>) +- [type ToolRunner](<#ToolRunner>) - [type TopLogProbs](<#TopLogProbs>) - [type TranscriptionTimestampGranularity](<#TranscriptionTimestampGranularity>) - [type Usage](<#Usage>) @@ -175,253 +157,37 @@ groq\-modeler Version 1.1.2 ## Constants - - -```go -const ( - AudioResponseFormatJSON AudioResponseFormat = "json" // AudioResponseFormatJSON is the JSON format of some audio. - AudioResponseFormatText AudioResponseFormat = "text" // AudioResponseFormatText is the text format of some audio. - AudioResponseFormatSRT AudioResponseFormat = "srt" // AudioResponseFormatSRT is the SRT format of some audio. - AudioResponseFormatVerboseJSON AudioResponseFormat = "verbose_json" // AudioResponseFormatVerboseJSON is the verbose JSON format of some audio. - AudioResponseFormatVTT AudioResponseFormat = "vtt" // AudioResponseFormatVTT is the VTT format of some audio. - - TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word" // TranscriptionTimestampGranularityWord is the word timestamp granularity. - TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" // TranscriptionTimestampGranularitySegment is the segment timestamp granularity. -) -``` - ```go const ( - ChatMessageRoleSystem Role = "system" // ChatMessageRoleSystem is the system chat message role. - ChatMessageRoleUser Role = "user" // ChatMessageRoleUser is the user chat message role. - ChatMessageRoleAssistant Role = "assistant" // ChatMessageRoleAssistant is the assistant chat message role. - ChatMessageRoleFunction Role = "function" // ChatMessageRoleFunction is the function chat message role. - ChatMessageRoleTool Role = "tool" // ChatMessageRoleTool is the tool chat message role. - ImageURLDetailHigh ImageURLDetail = "high" // ImageURLDetailHigh is the high image url detail. - ImageURLDetailLow ImageURLDetail = "low" // ImageURLDetailLow is the low image url detail. - ImageURLDetailAuto ImageURLDetail = "auto" // ImageURLDetailAuto is the auto image url detail. - ChatMessagePartTypeText ChatMessagePartType = "text" // ChatMessagePartTypeText is the text chat message part type. - ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" // ChatMessagePartTypeImageURL is the image url chat message part type. - ChatCompletionResponseFormatTypeJSONObject ChatCompletionResponseFormatType = "json_object" // ChatCompletionResponseFormatTypeJSONObject is the json object chat completion response format type. - ChatCompletionResponseFormatTypeJSONSchema ChatCompletionResponseFormatType = "json_schema" // ChatCompletionResponseFormatTypeJSONSchema is the json schema chat completion response format type. - ChatCompletionResponseFormatTypeText ChatCompletionResponseFormatType = "text" // ChatCompletionResponseFormatTypeText is the text chat completion response format type. - FinishReasonStop FinishReason = "stop" // FinishReasonStop is the stop finish reason. - FinishReasonLength FinishReason = "length" // FinishReasonLength is the length finish reason. - FinishReasonFunctionCall FinishReason = "function_call" // FinishReasonFunctionCall is the function call finish reason. - FinishReasonToolCalls FinishReason = "tool_calls" // FinishReasonToolCalls is the tool calls finish reason. - FinishReasonContentFilter FinishReason = "content_filter" // FinishReasonContentFilter is the content filter finish reason. - FinishReasonNull FinishReason = "null" // FinishReasonNull is the null finish reason. -) -``` - -## Variables - - - -```go -var ( - // ModelGemma29BIt is an AI text chat model. - // - // It is created/provided by Google. - // - // It has 8192 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelGemma29BIt ChatModel = "gemma2-9b-it" - // ModelGemma7BIt is an AI text chat model. - // - // It is created/provided by Google. - // - // It has 8192 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelGemma7BIt ChatModel = "gemma-7b-it" - // ModelLlama3170BVersatile is an AI text chat model. - // - // It is created/provided by Meta. - // - // It has 32768 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelLlama3170BVersatile ChatModel = "llama-3.1-70b-versatile" - // ModelLlama318BInstant is an AI text chat model. - // - // It is created/provided by Meta. - // - // It has 131072 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelLlama318BInstant ChatModel = "llama-3.1-8b-instant" - // ModelLlama3211BTextPreview is an AI text chat model. - // - // It is created/provided by Meta. - // - // It has 8192 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelLlama3211BTextPreview ChatModel = "llama-3.2-11b-text-preview" - // ModelLlama3211BVisionPreview is an AI text chat model. - // - // It is created/provided by Meta. - // - // It has 8192 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelLlama3211BVisionPreview ChatModel = "llama-3.2-11b-vision-preview" - // ModelLlama321BPreview is an AI text chat model. - // - // It is created/provided by Meta. - // - // It has 8192 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelLlama321BPreview ChatModel = "llama-3.2-1b-preview" - // ModelLlama323BPreview is an AI text chat model. - // - // It is created/provided by Meta. - // - // It has 8192 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelLlama323BPreview ChatModel = "llama-3.2-3b-preview" - // ModelLlama3290BTextPreview is an AI text chat model. - // - // It is created/provided by Meta. - // - // It has 8192 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelLlama3290BTextPreview ChatModel = "llama-3.2-90b-text-preview" - // ModelLlama3290BVisionPreview is an AI text chat model. - // - // It is created/provided by Meta. - // - // It has 8192 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelLlama3290BVisionPreview ChatModel = "llama-3.2-90b-vision-preview" - // ModelLlama370B8192 is an AI text chat model. - // - // It is created/provided by Meta. - // - // It has 8192 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelLlama370B8192 ChatModel = "llama3-70b-8192" - // ModelLlama38B8192 is an AI text chat model. - // - // It is created/provided by Meta. - // - // It has 8192 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelLlama38B8192 ChatModel = "llama3-8b-8192" - // ModelLlama3Groq70B8192ToolUsePreview is an AI text chat model. - // - // It is created/provided by Groq. - // - // It has 8192 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelLlama3Groq70B8192ToolUsePreview ChatModel = "llama3-groq-70b-8192-tool-use-preview" - // ModelLlama3Groq8B8192ToolUsePreview is an AI text chat model. - // - // It is created/provided by Groq. - // - // It has 8192 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelLlama3Groq8B8192ToolUsePreview ChatModel = "llama3-groq-8b-8192-tool-use-preview" - // ModelLlavaV157B4096Preview is an AI text chat model. - // - // It is created/provided by Other. - // - // It has 4096 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelLlavaV157B4096Preview ChatModel = "llava-v1.5-7b-4096-preview" - // ModelMixtral8X7B32768 is an AI text chat model. - // - // It is created/provided by Mistral AI. - // - // It has 32768 context window. - // - // It can be used with the following client methods: - // - CreateChatCompletion - // - CreateChatCompletionStream - // - CreateChatCompletionJSON - ModelMixtral8X7B32768 ChatModel = "mixtral-8x7b-32768" - // ModelWhisperLargeV3 is an AI audio transcription model. - // - // It is created/provided by OpenAI. - // - // It has 448 context window. - // - // It can be used with the following client methods: - // - CreateTranscription - // - CreateTranslation - ModelWhisperLargeV3 AudioModel = "whisper-large-v3" - // ModelLlamaGuard38B is an AI moderation model. - // - // It is created/provided by Meta. - // - // It has 8192 context window. - // - // It can be used with the following client methods: - // - Moderate - ModelLlamaGuard38B ModerationModel = "llama-guard-3-8b" + // ChatMessageRoleSystem is the system chat message role. + ChatMessageRoleSystem Role = "system" + // ChatMessageRoleUser is the user chat message role. + ChatMessageRoleUser Role = "user" + // ChatMessageRoleAssistant is the assistant chat message role. + ChatMessageRoleAssistant Role = "assistant" + // ChatMessageRoleFunction is the function chat message role. + ChatMessageRoleFunction Role = "function" + // ChatMessageRoleTool is the tool chat message role. + ChatMessageRoleTool Role = "tool" + + // ImageURLDetailHigh is the high image url detail. + ImageURLDetailHigh ImageURLDetail = "high" + // ImageURLDetailLow is the low image url detail. + ImageURLDetailLow ImageURLDetail = "low" + // ImageURLDetailAuto is the auto image url detail. + ImageURLDetailAuto ImageURLDetail = "auto" + + // ChatMessagePartTypeText is the text chat message part type. + ChatMessagePartTypeText ChatMessagePartType = "text" + // ChatMessagePartTypeImageURL is the image url chat message part type. + ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" ) ``` -## func [AudioMultipartForm]() +## func [AudioMultipartForm]() ```go func AudioMultipartForm(request AudioRequest, b builders.FormBuilder) error @@ -429,94 +195,90 @@ func AudioMultipartForm(request AudioRequest, b builders.FormBuilder) error AudioMultipartForm creates a form with audio file contents and the name of the model to use for audio processing. - -## func [CreateFileField]() - -```go -func CreateFileField(request AudioRequest, b builders.FormBuilder) (err error) -``` - -CreateFileField creates the "file" form field from either an existing file or by using the reader. - - -## type [APIError]() + +## type [Agent]() -APIError provides error information returned by the Groq API. +Agent is an agent. ```go -type APIError struct { - Code any `json:"code,omitempty"` // Code is the code of the error. - Message string `json:"message"` // Message is the message of the error. - Param *string `json:"param,omitempty"` // Param is the param of the error. - Type string `json:"type"` // Type is the type of the error. - HTTPStatusCode int `json:"-"` // HTTPStatusCode is the status code of the error. +type Agent struct { + // contains filtered or unexported fields } ``` - -### func \(\*APIError\) [Error]() - -```go -func (e *APIError) Error() string -``` - -Error method implements the error interface on APIError. - - -### func \(\*APIError\) [UnmarshalJSON]() + +### func [NewAgent]() ```go -func (e *APIError) UnmarshalJSON(data []byte) (err error) +func NewAgent(client *Client, logger *slog.Logger) *Agent ``` -UnmarshalJSON implements the json.Unmarshaler interface. +NewAgent creates a new agent. - -## type [AudioModel]() + +## type [Agenter]() -AudioModel is the type for audio models present on the groq api. +Agenter is an interface for an agent. ```go -type AudioModel model +type Agenter interface { + ToolManager +} ``` -## type [AudioRequest]() +## type [AudioRequest]() AudioRequest represents a request structure for audio API. ```go type AudioRequest struct { - Model AudioModel // Model is the model to use for the transcription. - FilePath string // FilePath is either an existing file in your filesystem or a filename representing the contents of Reader. - Reader io.Reader // Reader is an optional io.Reader when you do not want to use an existing file. - Prompt string // Prompt is the prompt for the transcription. - Temperature float32 // Temperature is the temperature for the transcription. - Language string // Language is the language for the transcription. Only for transcription. - Format AudioResponseFormat // Format is the format for the response. + // Model is the model to use for the transcription. + Model models.AudioModel + // FilePath is either an existing file in your filesystem or a + // filename representing the contents of Reader. + FilePath string + // Reader is an optional io.Reader when you do not want to use + // an existing file. + Reader io.Reader + // Prompt is the prompt for the transcription. + Prompt string + // Temperature is the temperature for the transcription. + Temperature float32 + // Language is the language for the transcription. Only for + // transcription. + Language string + // Format is the format for the response. + Format Format } ``` -## type [AudioResponse]() +## type [AudioResponse]() AudioResponse represents a response structure for audio API. ```go type AudioResponse struct { - Task string `json:"task"` // Task is the task of the response. - Language string `json:"language"` // Language is the language of the response. - Duration float64 `json:"duration"` // Duration is the duration of the response. - Segments Segments `json:"segments"` // Segments is the segments of the response. - Words Words `json:"words"` // Words is the words of the response. - Text string `json:"text"` // Text is the text of the response. + // Task is the task of the response. + Task string `json:"task"` + // Language is the language of the response. + Language string `json:"language"` + // Duration is the duration of the response. + Duration float64 `json:"duration"` + // Segments is the segments of the response. + Segments Segments `json:"segments"` + // Words is the words of the response. + Words Words `json:"words"` + // Text is the text of the response. + Text string `json:"text"` Header http.Header // Header is the header of the response. } ``` -### func \(\*AudioResponse\) [SetHeader]() +### func \(\*AudioResponse\) [SetHeader]() ```go func (r *AudioResponse) SetHeader(header http.Header) @@ -524,21 +286,8 @@ func (r *AudioResponse) SetHeader(header http.Header) SetHeader sets the header of the response. - -## type [AudioResponseFormat]() - -AudioResponseFormat is the response format for the audio API. - -Response formatted using AudioResponseFormatJSON by default. - -string - -```go -type AudioResponseFormat string -``` - -## type [ChatCompletionChoice]() +## type [ChatCompletionChoice]() ChatCompletionChoice represents the chat completion choice. @@ -546,42 +295,49 @@ ChatCompletionChoice represents the chat completion choice. type ChatCompletionChoice struct { Index int `json:"index"` // Index is the index of the choice. // Message is the chat completion message of the choice. - Message ChatCompletionMessage `json:"message"` // Message is the chat completion message of the choice. + Message ChatCompletionMessage `json:"message"` // FinishReason is the finish reason of the choice. - // - // stop: API returned complete message, - // or a message terminated by one of the stop sequences provided via the stop parameter - // length: Incomplete model output due to max_tokens parameter or token limit - // function_call: The model decided to call a function - // content_filter: Omitted content due to a flag from our content filters - // null: API response still in progress or incomplete - FinishReason FinishReason `json:"finish_reason"` // FinishReason is the finish reason of the choice. + FinishReason FinishReason `json:"finish_reason"` // LogProbs is the log probs of the choice. // - // This is basically the probability of the model choosing the token. - LogProbs *LogProbs `json:"logprobs,omitempty"` // LogProbs is the log probs of the choice. + // This is basically the probability of the model choosing the + // token. + LogProbs *LogProbs `json:"logprobs,omitempty"` } ``` -## type [ChatCompletionMessage]() +## type [ChatCompletionMessage]() ChatCompletionMessage represents the chat completion message. ```go type ChatCompletionMessage struct { - Name string `json:"name"` // Name is the name of the chat completion message. - Role Role `json:"role"` // Role is the role of the chat completion message. - Content string `json:"content"` // Content is the content of the chat completion message. - MultiContent []ChatMessagePart `json:"-"` // MultiContent is the multi content of the chat completion message. - FunctionCall *tools.FunctionCall `json:"function_call,omitempty"` // FunctionCall setting for Role=assistant prompts this may be set to the function call generated by the model. - ToolCalls []tools.ToolCall `json:"tool_calls,omitempty"` // ToolCalls setting for Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. - ToolCallID string `json:"tool_call_id,omitempty"` // ToolCallID is setting for Role=tool prompts this should be set to the ID given in the assistant's prior request to call a tool. + // Name is the name of the chat completion message. + Name string `json:"name"` + // Role is the role of the chat completion message. + Role Role `json:"role"` + // Content is the content of the chat completion message. + Content string `json:"content"` + // MultiContent is the multi content of the chat completion + // message. + MultiContent []ChatMessagePart `json:"-"` + // FunctionCall setting for Role=assistant prompts this may be + // set to the function call generated by the model. + FunctionCall *tools.FunctionCall `json:"function_call,omitempty"` + // ToolCalls setting for Role=assistant prompts this may be set + // to the tool calls generated by the model, such as function + // calls. + ToolCalls []tools.ToolCall `json:"tool_calls,omitempty"` + // ToolCallID is setting for Role=tool prompts this should be + // set to the ID given in the assistant's prior request to call + // a tool. + ToolCallID string `json:"tool_call_id,omitempty"` } ``` -### func \(ChatCompletionMessage\) [MarshalJSON]() +### func \(ChatCompletionMessage\) [MarshalJSON]() ```go func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) @@ -590,7 +346,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) MarshalJSON method implements the json.Marshaler interface. -### func \(\*ChatCompletionMessage\) [UnmarshalJSON]() +### func \(\*ChatCompletionMessage\) [UnmarshalJSON]() ```go func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) (err error) @@ -599,56 +355,100 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) (err error) UnmarshalJSON method implements the json.Unmarshaler interface. -## type [ChatCompletionRequest]() +## type [ChatCompletionRequest]() ChatCompletionRequest represents a request structure for the chat completion API. ```go type ChatCompletionRequest struct { - Model ChatModel `json:"model"` // Model is the model of the chat completion request. - Messages []ChatCompletionMessage `json:"messages"` // Messages is the messages of the chat completion request. These act as the prompt for the model. - MaxTokens int `json:"max_tokens,omitempty"` // MaxTokens is the max tokens of the chat completion request. - Temperature float32 `json:"temperature,omitempty"` // Temperature is the temperature of the chat completion request. - TopP float32 `json:"top_p,omitempty"` // TopP is the top p of the chat completion request. - N int `json:"n,omitempty"` // N is the n of the chat completion request. - Stream bool `json:"stream,omitempty"` // Stream is the stream of the chat completion request. - Stop []string `json:"stop,omitempty"` // Stop is the stop of the chat completion request. - PresencePenalty float32 `json:"presence_penalty,omitempty"` // PresencePenalty is the presence penalty of the chat completion request. - ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` // ResponseFormat is the response format of the chat completion request. - Seed *int `json:"seed,omitempty"` // Seed is the seed of the chat completion request. - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // FrequencyPenalty is the frequency penalty of the chat completion request. - LogitBias map[string]int `json:"logit_bias,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. incorrect: `"logit_bias":{ "You": 6}`, correct: `"logit_bias":{"1639": 6}` refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias - LogProbs bool `json:"logprobs,omitempty"` // LogProbs indicates whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message. This option is currently not available on the gpt-4-vision-preview model. - TopLogProbs int `json:"top_logprobs,omitempty"` // TopLogProbs is an integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. - User string `json:"user,omitempty"` // User is the user of the chat completion request. - Tools []tools.Tool `json:"tools,omitempty"` // Tools is the tools of the chat completion request. - ToolChoice any `json:"tool_choice,omitempty"` // This can be either a string or an ToolChoice object. - StreamOptions *StreamOptions `json:"stream_options,omitempty"` // Options for streaming response. Only set this when you set stream: true. - ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` // Disable the default behavior of parallel tool calls by setting it: false. - RetryDelay time.Duration `json:"-"` // RetryDelay is the delay between retries. + // Model is the model of the chat completion request. + Model models.ChatModel `json:"model"` + // Messages is the messages of the chat completion request. + // + // These act as the prompt for the model. + Messages []ChatCompletionMessage `json:"messages"` + // MaxTokens is the max tokens of the chat completion request. + MaxTokens int `json:"max_tokens,omitempty"` + // Temperature is the temperature of the chat completion + // request. + Temperature float32 `json:"temperature,omitempty"` + // TopP is the top p of the chat completion request. + TopP float32 `json:"top_p,omitempty"` + // N is the n of the chat completion request. + N int `json:"n,omitempty"` + // Stream is the stream of the chat completion request. + Stream bool `json:"stream,omitempty"` + // Stop is the stop of the chat completion request. + Stop []string `json:"stop,omitempty"` + // PresencePenalty is the presence penalty of the chat + // completion request. + PresencePenalty float32 `json:"presence_penalty,omitempty"` + // ResponseFormat is the response format of the chat completion + // request. + ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` + // Seed is the seed of the chat completion request. + Seed *int `json:"seed,omitempty"` + // FrequencyPenalty is the frequency penalty of the chat + // completion request. + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + // LogitBias is must be a token id string (specified by their + // token ID in the tokenizer), not a word string. incorrect: `"logit_bias":{ "You": 6}`, correct: `"logit_bias":{"1639": 6}` refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias + LogitBias map[string]int `json:"logit_bias,omitempty"` + // LogProbs indicates whether to return log probabilities of the + // output tokens or not. If true, returns the log probabilities + // of each output token returned in the content of message. + // + // This option is currently not available on the + // gpt-4-vision-preview model. + LogProbs bool `json:"logprobs,omitempty"` + // TopLogProbs is an integer between 0 and 5 specifying the + // number of most likely tokens to return at each token + // position, each with an associated log probability. Logprobs + // must be set to true if this parameter is used. + TopLogProbs int `json:"top_logprobs,omitempty"` + // User is the user of the chat completion request. + User string `json:"user,omitempty"` + // Tools is the tools of the chat completion request. + Tools []tools.Tool `json:"tools,omitempty"` + // This can be either a string or an ToolChoice object. + ToolChoice any `json:"tool_choice,omitempty"` + // Options for streaming response. Only set this when you set stream: true. + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` + // RetryDelay is the delay between retries. + RetryDelay time.Duration `json:"-"` } ``` -## type [ChatCompletionResponse]() +## type [ChatCompletionResponse]() ChatCompletionResponse represents a response structure for chat completion API. ```go type ChatCompletionResponse struct { - ID string `json:"id"` // ID is the id of the response. - Object string `json:"object"` // Object is the object of the response. - Created int64 `json:"created"` // Created is the created time of the response. - Model ChatModel `json:"model"` // Model is the model of the response. - Choices []ChatCompletionChoice `json:"choices"` // Choices is the choices of the response. - Usage Usage `json:"usage"` // Usage is the usage of the response. - SystemFingerprint string `json:"system_fingerprint"` // SystemFingerprint is the system fingerprint of the response. - http.Header // Header is the header of the response. + // ID is the id of the response. + ID string `json:"id"` + // Object is the object of the response. + Object string `json:"object"` + // Created is the created time of the response. + Created int64 `json:"created"` + // Model is the model of the response. + Model models.ChatModel `json:"model"` + // Choices is the choices of the response. + Choices []ChatCompletionChoice `json:"choices"` + // Usage is the usage of the response. + Usage Usage `json:"usage"` + // SystemFingerprint is the system fingerprint of the response. + SystemFingerprint string `json:"system_fingerprint"` + // Header is the header of the response. + http.Header } ``` -### func \(\*ChatCompletionResponse\) [SetHeader]() +### func \(\*ChatCompletionResponse\) [SetHeader]() ```go func (r *ChatCompletionResponse) SetHeader(h http.Header) @@ -657,21 +457,22 @@ func (r *ChatCompletionResponse) SetHeader(h http.Header) SetHeader sets the header of the response. -## type [ChatCompletionResponseFormat]() +## type [ChatCompletionResponseFormat]() ChatCompletionResponseFormat is the chat completion response format. ```go type ChatCompletionResponseFormat struct { // Type is the type of the chat completion response format. - Type ChatCompletionResponseFormatType `json:"type,omitempty"` - // JSONSchema is the json schema of the chat completion response format. + Type Format `json:"type,omitempty"` + // JSONSchema is the json schema of the chat completion response + // format. JSONSchema *ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` } ``` -## type [ChatCompletionResponseFormatJSONSchema]() +## type [ChatCompletionResponseFormatJSONSchema]() ChatCompletionResponseFormatJSONSchema is the chat completion response format json schema. @@ -682,37 +483,23 @@ type ChatCompletionResponseFormatJSONSchema struct { // // it is used to further identify the schema in the response. Name string `json:"name"` - // Description is the description of the chat completion response - // format - // json schema. + // Description is the description of the chat completion + // response format json schema. Description string `json:"description,omitempty"` - // description of the chat completion response format json schema. - // Schema is the schema of the chat completion response format json schema. + // Schema is the schema of the chat completion response format + // json schema. Schema schema.Schema `json:"schema"` - // Strict determines whether to enforce the schema upon the generated - // content. + // Strict determines whether to enforce the schema upon the + // generated content. Strict bool `json:"strict"` } ``` - -## type [ChatCompletionResponseFormatType]() - -ChatCompletionResponseFormatType is the chat completion response format type. - -string - -```go -type ChatCompletionResponseFormatType string -``` - -## type [ChatCompletionStream]() +## type [ChatCompletionStream]() ChatCompletionStream is a stream of ChatCompletionStreamResponse. -Note: Perhaps it is more elegant to abstract Stream using generics. - ```go type ChatCompletionStream struct { // contains filtered or unexported fields @@ -720,84 +507,113 @@ type ChatCompletionStream struct { ``` -## type [ChatCompletionStreamChoice]() +## type [ChatCompletionStreamChoice]() ChatCompletionStreamChoice represents a response structure for chat completion API. ```go type ChatCompletionStreamChoice struct { - Index int `json:"index"` - Delta ChatCompletionStreamChoiceDelta `json:"delta"` - FinishReason FinishReason `json:"finish_reason"` + // Index is the index of the choice. + Index int `json:"index"` + // Delta is the delta of the choice. + Delta ChatCompletionStreamChoiceDelta `json:"delta"` + // FinishReason is the finish reason of the choice. + FinishReason FinishReason `json:"finish_reason"` } ``` -## type [ChatCompletionStreamChoiceDelta]() +## type [ChatCompletionStreamChoiceDelta]() ChatCompletionStreamChoiceDelta represents a response structure for chat completion API. ```go type ChatCompletionStreamChoiceDelta struct { - Content string `json:"content,omitempty"` - Role string `json:"role,omitempty"` + // Content is the content of the response. + Content string `json:"content,omitempty"` + // Role is the role of the creator of the completion. + Role string `json:"role,omitempty"` + // FunctionCall is the function call of the response. FunctionCall *tools.FunctionCall `json:"function_call,omitempty"` - ToolCalls []tools.ToolCall `json:"tool_calls,omitempty"` + // ToolCalls are the tool calls of the response. + ToolCalls []tools.ToolCall `json:"tool_calls,omitempty"` } ``` -## type [ChatCompletionStreamResponse]() +## type [ChatCompletionStreamResponse]() ChatCompletionStreamResponse represents a response structure for chat completion API. ```go type ChatCompletionStreamResponse struct { - ID string `json:"id"` // ID is the identifier for the chat completion stream response. - Object string `json:"object"` // Object is the object type of the chat completion stream response. - Created int64 `json:"created"` // Created is the creation time of the chat completion stream response. - Model ChatModel `json:"model"` // Model is the model used for the chat completion stream response. - Choices []ChatCompletionStreamChoice `json:"choices"` // Choices is the choices for the chat completion stream response. - SystemFingerprint string `json:"system_fingerprint"` // SystemFingerprint is the system fingerprint for the chat completion stream response. - PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` // PromptAnnotations is the prompt annotations for the chat completion stream response. + // ID is the identifier for the chat completion stream response. + ID string `json:"id"` + // Object is the object type of the chat completion stream + // response. + Object string `json:"object"` + // Created is the creation time of the chat completion stream + // response. + Created int64 `json:"created"` + // Model is the model used for the chat completion stream + // response. + Model models.ChatModel `json:"model"` + // Choices is the choices for the chat completion stream + // response. + Choices []ChatCompletionStreamChoice `json:"choices"` + // SystemFingerprint is the system fingerprint for the chat + // completion stream response. + SystemFingerprint string `json:"system_fingerprint"` + // PromptAnnotations is the prompt annotations for the chat + // completion stream response. + PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` + // PromptFilterResults is the prompt filter results for the chat + // completion stream response. PromptFilterResults []struct { Index int `json:"index"` - } `json:"prompt_filter_results,omitempty"` // PromptFilterResults is the prompt filter results for the chat completion stream response. - // Usage is an optional field that will only be present when you set stream_options: {"include_usage": true} in your request. + } `json:"prompt_filter_results,omitempty"` + // Usage is an optional field that will only be present when you + // set stream_options: {"include_usage": true} in your request. // - // When present, it contains a null value except for the last chunk which contains the token usage statistics - // for the entire request. + // When present, it contains a null value except for the last + // chunk which contains the token usage statistics for the + // entire request. Usage *Usage `json:"usage,omitempty"` } ``` -## type [ChatMessageImageURL]() +## type [ChatMessageImageURL]() ChatMessageImageURL represents the chat message image url. ```go type ChatMessageImageURL struct { - URL string `json:"url,omitempty"` // URL is the url of the image. - Detail ImageURLDetail `json:"detail,omitempty"` // Detail is the detail of the image url. + // URL is the url of the image. + URL string `json:"url,omitempty"` + // Detail is the detail of the image url. + Detail ImageURLDetail `json:"detail,omitempty"` } ``` -## type [ChatMessagePart]() +## type [ChatMessagePart]() ChatMessagePart represents the chat message part of a chat completion message. ```go type ChatMessagePart struct { - Text string `json:"text,omitempty"` // Text is the text of the chat message part. - Type ChatMessagePartType `json:"type,omitempty"` // Type is the type of the chat message part. - ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` // ImageURL is the image url of the chat message part. + // Text is the text of the chat message part. + Text string `json:"text,omitempty"` + // Type is the type of the chat message part. + Type ChatMessagePartType `json:"type,omitempty"` + // ImageURL is the image url of the chat message part. + ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` } ``` -## type [ChatMessagePartType]() +## type [ChatMessagePartType]() ChatMessagePartType is the chat message part type. @@ -807,17 +623,8 @@ string type ChatMessagePartType string ``` - -## type [ChatModel]() - -ChatModel is the type for chat models present on the groq api. - -```go -type ChatModel model -``` - -## type [Client]() +## type [Client]() Client is a Groq api client. @@ -828,7 +635,7 @@ type Client struct { ``` -### func [NewClient]() +### func [NewClient]() ```go func NewClient(groqAPIKey string, opts ...Opts) (*Client, error) @@ -837,7 +644,7 @@ func NewClient(groqAPIKey string, opts ...Opts) (*Client, error) NewClient creates a new Groq client. -### func \(\*Client\) [CreateChatCompletion]() +### func \(\*Client\) [CreateChatCompletion]() ```go func (c *Client) CreateChatCompletion(ctx context.Context, request ChatCompletionRequest) (response ChatCompletionResponse, err error) @@ -846,7 +653,7 @@ func (c *Client) CreateChatCompletion(ctx context.Context, request ChatCompletio CreateChatCompletion method is an API call to create a chat completion. -### func \(\*Client\) [CreateChatCompletionJSON]() +### func \(\*Client\) [CreateChatCompletionJSON]() ```go func (c *Client) CreateChatCompletionJSON(ctx context.Context, request ChatCompletionRequest, output any) (err error) @@ -855,7 +662,7 @@ func (c *Client) CreateChatCompletionJSON(ctx context.Context, request ChatCompl CreateChatCompletionJSON method is an API call to create a chat completion w/ object output. -### func \(\*Client\) [CreateChatCompletionStream]() +### func \(\*Client\) [CreateChatCompletionStream]() ```go func (c *Client) CreateChatCompletionStream(ctx context.Context, request ChatCompletionRequest) (stream *ChatCompletionStream, err error) @@ -866,7 +673,7 @@ CreateChatCompletionStream method is an API call to create a chat completion w/ If set, tokens will be sent as data\-only server\-sent events as they become available, with the stream terminated by a data: \[DONE\] message. -### func \(\*Client\) [CreateTranscription]() +### func \(\*Client\) [CreateTranscription]() ```go func (c *Client) CreateTranscription(ctx context.Context, request AudioRequest) (AudioResponse, error) @@ -877,7 +684,7 @@ CreateTranscription calls the transcriptions endpoint with the given request. Returns transcribed text in the response\_format specified in the request. -### func \(\*Client\) [CreateTranslation]() +### func \(\*Client\) [CreateTranslation]() ```go func (c *Client) CreateTranslation(ctx context.Context, request AudioRequest) (AudioResponse, error) @@ -888,112 +695,58 @@ CreateTranslation calls the translations endpoint with the given request. Returns the translated text in the response\_format specified in the request. -### func \(\*Client\) [Moderate]() +### func \(\*Client\) [Moderate]() ```go -func (c *Client) Moderate(ctx context.Context, request ModerationRequest) (response Moderation, err error) +func (c *Client) Moderate(ctx context.Context, messages []ChatCompletionMessage, model models.ModerationModel) (response Moderation, err error) ``` Moderate performs a moderation api call over a string. Input can be an array or slice but a string will reduce the complexity. - -### func \(\*Client\) [MustCreateChatCompletion]() - -```go -func (c *Client) MustCreateChatCompletion(ctx context.Context, request ChatCompletionRequest) (response ChatCompletionResponse) -``` - -MustCreateChatCompletion method is an API call to create a chat completion. - -It panics if an error occurs. - - -## type [DefaultErrorAccumulator]() - -DefaultErrorAccumulator is a default implementation of ErrorAccumulator - -```go -type DefaultErrorAccumulator struct { - Buffer errorBuffer -} -``` - - -### func \(\*DefaultErrorAccumulator\) [Bytes]() - -```go -func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) -``` - -Bytes method returns the bytes of the error accumulator. - - -### func \(\*DefaultErrorAccumulator\) [Write]() - -```go -func (e *DefaultErrorAccumulator) Write(p []byte) error -``` - -Write method writes bytes to the error accumulator. - -## type [Endpoint]() +## type [Endpoint]() -Endpoint is the endpoint for the groq api. string +Endpoint is an endpoint for the groq api. ```go type Endpoint string ``` - -## type [ErrContentFieldsMisused]() - -ErrContentFieldsMisused is an error that occurs when both Content and MultiContent properties are set. - -```go -type ErrContentFieldsMisused struct { - // contains filtered or unexported fields -} -``` - - -### func \(ErrContentFieldsMisused\) [Error]() - -```go -func (e ErrContentFieldsMisused) Error() string -``` - -Error implements the error interface. - - -## type [ErrTooManyEmptyStreamMessages]() - -ErrTooManyEmptyStreamMessages is returned when the stream has sent too many empty messages. + +## type [FinishReason]() -```go -type ErrTooManyEmptyStreamMessages struct{} -``` +FinishReason is the finish reason. - -### func \(ErrTooManyEmptyStreamMessages\) [Error]() +string ```go -func (e ErrTooManyEmptyStreamMessages) Error() string +type FinishReason string ``` -Error returns the error message. - - -## type [FinishReason]() - -FinishReason is the finish reason. string + ```go -type FinishReason string +const ( + // ReasonStop is the stop finish reason for a chat completion. + ReasonStop FinishReason = "stop" + // ReasonLength is the length finish reason for a chat completion. + ReasonLength FinishReason = "length" + // ReasonFunctionCall is the function call finish reason for a chat + // completion. + ReasonFunctionCall FinishReason = "function_call" + // ReasonToolCalls is the tool calls finish reason for a chat + // completion. + ReasonToolCalls FinishReason = "tool_calls" + // ReasonContentFilter is the content filter finish reason for a chat + // completion. + ReasonContentFilter FinishReason = "content_filter" + // ReasonNull is the null finish reason for a chat completion. + ReasonNull FinishReason = "null" +) ``` -### func \(FinishReason\) [MarshalJSON]() +### func \(FinishReason\) [MarshalJSON]() ```go func (r FinishReason) MarshalJSON() ([]byte, error) @@ -1002,7 +755,7 @@ func (r FinishReason) MarshalJSON() ([]byte, error) MarshalJSON implements the json.Marshaler interface. -## type [Format]() +## type [Format]() Format is the format of a response. string @@ -1020,101 +773,29 @@ const ( // FormatJSON is the JSON format. There is no support for streaming with // JSON format selected. FormatJSON Format = "json" -) -``` - - -## type [HarmfulCategory]() - -HarmfulCategory is a category of harmful content that can be used to categorize a chat history using Llama Guard 3. - -\[Llama Guard 3\]\(https://llama.meta.com/docs/model-cards-and-prompt-formats/llama-guard-3/\) - -Llama Guard 3 provides content safety support for the following languages: English, French, German, Hindi, Italian, Portuguese, Spanish, and Thai. - -string - -```go -type HarmfulCategory string -``` - - - -```go -const ( - // CategoryViolentCrimes (S1) is the violent crimes category. - // - // Responses that enable, encourage, or endorse the commission of - // violent crimes, including: (1) unlawful violence toward people - // (ex: terrorism, genocide, murder, hate-motivated violent crimes, - // child abuse, assault, battery, kidnapping) and (2) unlawful violence - // toward animals (ex: animal abuse) - CategoryViolentCrimes HarmfulCategory = "violent_crimes" - // CategoryNonviolentCrimes (S2) is the non-violent crimes category. - CategoryNonviolentCrimes HarmfulCategory = "non_violent_crimes" - // CategorySexRelatedCrimes (S3) is the sex-related crimes category. - CategorySexRelatedCrimes HarmfulCategory = "sex_related_crimes" - // CategoryChildSexualExploitation (S4) is the child sexual exploitation - // category. - // - // Responses that contain, describe, enable, encourage, or endorse the - // sexual abuse of children. - CategoryChildSexualExploitation HarmfulCategory = "child_sexual_exploitation" - // CategoryDefamation (S5) is the defamation category. - // - // Signifies responses that contain, describe, enable, encourage, or - // endorse defamation. - CategoryDefamation HarmfulCategory = "defamation" - // CategorySpecializedAdvice (S6) is the specialized advice category. - // - // Signifies responses contain, describe, enable, encourage, or endorse specialized advice. - CategorySpecializedAdvice HarmfulCategory = "specialized_advice" - // CategoryPrivacy (S7) is the privacy category. - // - // Responses contain, describe, enable, encourage, or endorse privacy. - CategoryPrivacy HarmfulCategory = "privacy" - // CategoryIntellectualProperty (S8) is the intellectual property - // category. Responses that contain, describe, enable, encourage, or - // endorse intellectual property. - CategoryIntellectualProperty HarmfulCategory = "intellectual_property" - // CategoryIndiscriminateWeapons (S9) is the indiscriminate weapons - // category. - // - // Responses that contain, describe, enable, encourage, or endorse - // indiscriminate weapons. - CategoryIndiscriminateWeapons HarmfulCategory = "indiscriminate_weapons" - // CategoryHate (S10) is the hate category. - // - // Responses contain, describe, enable, encourage, or endorse hate. - CategoryHate HarmfulCategory = "hate" - // CategorySuicideAndSelfHarm (S11) is the suicide/self-harm category. - // - // Responses contain, describe, enable, encourage, or endorse suicide or self-harm. - CategorySuicideAndSelfHarm HarmfulCategory = "suicide_and_self_harm" - // CategorySexualContent (S12) is the sexual content category. - // - // Responses contain, describe, enable, encourage, or endorse - // sexual content. - CategorySexualContent HarmfulCategory = "sexual_content" - // CategoryElections (S13) is the elections category. - // - // Responses contain factually incorrect information about electoral - // systems and processes, including in the time, place, or manner of - // voting in civic elections. - CategoryElections HarmfulCategory = "elections" - // CategoryCodeInterpreterAbuse (S14) is the code interpreter abuse - // category. - // - // Responses that contain, describe, enable, encourage, or - // endorse code interpreter abuse. - CategoryCodeInterpreterAbuse HarmfulCategory = "code_interpreter_abuse" + // FormatSRT is the SRT format. This is a text format that is only + // supported for the transcription API. + // SRT format selected. + FormatSRT Format = "srt" + // FormatVTT is the VTT format. This is a text format that is only + // supported for the transcription API. + FormatVTT Format = "vtt" + // FormatVerboseJSON is the verbose JSON format. This is a JSON format + // that is only supported for the transcription API. + FormatVerboseJSON Format = "verbose_json" + // FormatJSONObject is the json object chat + // completion response format type. + FormatJSONObject Format = "json_object" + // FormatJSONSchema is the json schema chat + // completion response format type. + FormatJSONSchema Format = "json_schema" ) ``` -## type [ImageURLDetail]() +## type [ImageURLDetail]() -ImageURLDetail is the image url detail. +ImageURLDetail is the detail of the image at the URL. string @@ -1123,57 +804,46 @@ type ImageURLDetail string ``` -## type [LogProbs]() +## type [LogProbs]() LogProbs is the top\-level structure containing the log probability information. ```go type LogProbs struct { + // Content is a list of message content tokens with log + // probability information. Content []struct { - Token string `json:"token"` // Token is the token of the log prob. - LogProb float64 `json:"logprob"` // LogProb is the log prob of the log prob. - Bytes []byte `json:"bytes,omitempty"` // Omitting the field if it is null - TopLogProbs []TopLogProbs `json:"top_logprobs"` // TopLogProbs is a list of the most likely tokens and their log probability, at this token position. In rare cases, there may be fewer than the number of requested top_logprobs returned. - } `json:"content"` // Content is a list of message content tokens with log probability information. + // Token is the token of the log prob. + Token string `json:"token"` + // LogProb is the log prob of the log prob. + LogProb float64 `json:"logprob"` + // Omitting the field if it is null + Bytes []byte `json:"bytes,omitempty"` + // TopLogProbs is a list of the most likely tokens and + // their log probability, at this token position. In + // rare cases, there may be fewer than the number of + // requested top_logprobs returned. + TopLogProbs []TopLogProbs `json:"top_logprobs"` + } `json:"content"` } ``` -## type [Moderation]() +## type [Moderation]() -Moderation represents one of possible moderation results. +Moderation represents the response of a moderation request. ```go type Moderation struct { - Categories []HarmfulCategory `json:"categories"` // Categories is the categories of the result. - Flagged bool `json:"flagged"` // Flagged is the flagged of the result. -} -``` - - -## type [ModerationModel]() - -ModerationModel is the type for moderation models present on the groq api. - -```go -type ModerationModel model -``` - - -## type [ModerationRequest]() - -ModerationRequest represents a request structure for moderation API. - -```go -type ModerationRequest struct { - // Input string `json:"input,omitempty"` // Input is the input text to be moderated. - Messages []ChatCompletionMessage `json:"messages"` // Messages is the messages of the chat completion request. These act as the prompt for the model. - Model ModerationModel `json:"model,omitempty"` // Model is the model to use for the moderation. + // Categories is the categories of the result. + Categories []moderation.HarmfulCategory `json:"categories"` + // Flagged is the flagged status of the result. + Flagged bool `json:"flagged"` } ``` -## type [Opts]() +## type [Opts]() Opts is a function that sets options for a Groq client. @@ -1182,7 +852,7 @@ type Opts func(*Client) ``` -### func [WithBaseURL]() +### func [WithBaseURL]() ```go func WithBaseURL(baseURL string) Opts @@ -1191,7 +861,7 @@ func WithBaseURL(baseURL string) Opts WithBaseURL sets the base URL for the Groq client. -### func [WithClient]() +### func [WithClient]() ```go func WithClient(client *http.Client) Opts @@ -1200,7 +870,7 @@ func WithClient(client *http.Client) Opts WithClient sets the client for the Groq client. -### func [WithLogger]() +### func [WithLogger]() ```go func WithLogger(logger *slog.Logger) Opts @@ -1209,7 +879,7 @@ func WithLogger(logger *slog.Logger) Opts WithLogger sets the logger for the Groq client. -## type [PromptAnnotation]() +## type [PromptAnnotation]() PromptAnnotation represents the prompt annotation. @@ -1220,23 +890,33 @@ type PromptAnnotation struct { ``` -## type [RateLimitHeaders]() +## type [RateLimitHeaders]() RateLimitHeaders struct represents Groq rate limits headers. ```go type RateLimitHeaders struct { - LimitRequests int `json:"x-ratelimit-limit-requests"` // LimitRequests is the limit requests of the rate limit headers. - LimitTokens int `json:"x-ratelimit-limit-tokens"` // LimitTokens is the limit tokens of the rate limit headers. - RemainingRequests int `json:"x-ratelimit-remaining-requests"` // RemainingRequests is the remaining requests of the rate limit headers. - RemainingTokens int `json:"x-ratelimit-remaining-tokens"` // RemainingTokens is the remaining tokens of the rate limit headers. - ResetRequests ResetTime `json:"x-ratelimit-reset-requests"` // ResetRequests is the reset requests of the rate limit headers. - ResetTokens ResetTime `json:"x-ratelimit-reset-tokens"` // ResetTokens is the reset tokens of the rate limit headers. + // LimitRequests is the limit requests of the rate limit + // headers. + LimitRequests int `json:"x-ratelimit-limit-requests"` + // LimitTokens is the limit tokens of the rate limit headers. + LimitTokens int `json:"x-ratelimit-limit-tokens"` + // RemainingRequests is the remaining requests of the rate + // limit headers. + RemainingRequests int `json:"x-ratelimit-remaining-requests"` + // RemainingTokens is the remaining tokens of the rate limit + // headers. + RemainingTokens int `json:"x-ratelimit-remaining-tokens"` + // ResetRequests is the reset requests of the rate limit + // headers. + ResetRequests ResetTime `json:"x-ratelimit-reset-requests"` + // ResetTokens is the reset tokens of the rate limit headers. + ResetTokens ResetTime `json:"x-ratelimit-reset-tokens"` } ``` -## type [ResetTime]() +## type [ResetTime]() ResetTime is a time.Time wrapper for the rate limit reset time. string @@ -1245,7 +925,7 @@ type ResetTime string ``` -### func \(ResetTime\) [String]() +### func \(ResetTime\) [String]() ```go func (r ResetTime) String() string @@ -1254,7 +934,7 @@ func (r ResetTime) String() string String returns the string representation of the ResetTime. -### func \(ResetTime\) [Time]() +### func \(ResetTime\) [Time]() ```go func (r ResetTime) Time() time.Time @@ -1263,7 +943,7 @@ func (r ResetTime) Time() time.Time Time returns the time.Time representation of the ResetTime. -## type [Role]() +## type [Role]() Role is the role of the chat completion message. @@ -1274,56 +954,116 @@ type Role string ``` -## type [Segments]() +## type [Segments]() Segments is the segments of the response. ```go type Segments []struct { - ID int `json:"id"` // ID is the ID of the segment. - Seek int `json:"seek"` // Seek is the seek of the segment. - Start float64 `json:"start"` // Start is the start of the segment. - End float64 `json:"end"` // End is the end of the segment. - Text string `json:"text"` // Text is the text of the segment. - Tokens []int `json:"tokens"` // Tokens is the tokens of the segment. - Temperature float64 `json:"temperature"` // Temperature is the temperature of the segment. - AvgLogprob float64 `json:"avg_logprob"` // AvgLogprob is the avg log prob of the segment. - CompressionRatio float64 `json:"compression_ratio"` // CompressionRatio is the compression ratio of the segment. - NoSpeechProb float64 `json:"no_speech_prob"` // NoSpeechProb is the no speech prob of the segment. - Transient bool `json:"transient"` // Transient is the transient of the segment. + // ID is the ID of the segment. + ID int `json:"id"` + // Seek is the seek of the segment. + Seek int `json:"seek"` + // Start is the start of the segment. + Start float64 `json:"start"` + // End is the end of the segment. + End float64 `json:"end"` + // Text is the text of the segment. + Text string `json:"text"` + // Tokens is the tokens of the segment. + Tokens []int `json:"tokens"` + // Temperature is the temperature of the segment. + Temperature float64 `json:"temperature"` + // AvgLogprob is the avg log prob of the segment. + AvgLogprob float64 `json:"avg_logprob"` + // CompressionRatio is the compression ratio of the segment. + CompressionRatio float64 `json:"compression_ratio"` + // NoSpeechProb is the no speech prob of the segment. + NoSpeechProb float64 `json:"no_speech_prob"` + // Transient is the transient of the segment. + Transient bool `json:"transient"` } ``` -## type [StreamOptions]() +## type [StreamOptions]() StreamOptions represents the stream options. ```go type StreamOptions struct { - // If set, an additional chunk will be streamed before the data: [DONE] message. - // The usage field on this chunk shows the token usage statistics for the entire request, - // and the choices field will always be an empty array. - // All other chunks will also include a usage field, but with a null value. + // IncludeUsage is the include usage option of the stream + // options. + // + // If set, an additional chunk will be streamed before the data: + // [DONE] message. + // The usage field on this chunk shows the token usage + // statistics for the entire request, and the choices field will + // always be an empty array. + // + // All other chunks will also include a usage field, but with a + // null value. IncludeUsage bool `json:"include_usage,omitempty"` } ``` + +## type [ToolGetter]() + +ToolGetter is an interface for a tool getter. + +```go +type ToolGetter interface { + Get( + ctx context.Context, + ) ([]tools.Tool, error) +} +``` + + +## type [ToolManager]() + +ToolManager is an interface for a tool manager. + +```go +type ToolManager interface { + ToolGetter + ToolRunner +} +``` + + +## type [ToolRunner]() + +ToolRunner is an interface for a tool runner. + +```go +type ToolRunner interface { + Run( + ctx context.Context, + response ChatCompletionResponse, + ) ([]ChatCompletionMessage, error) +} +``` + -## type [TopLogProbs]() +## type [TopLogProbs]() TopLogProbs represents the top log probs. ```go type TopLogProbs struct { - Token string `json:"token"` // Token is the token of the top log probs. - LogProb float64 `json:"logprob"` // LogProb is the log prob of the top log probs. - Bytes []byte `json:"bytes,omitempty"` // Bytes is the bytes of the top log probs. + // Token is the token of the top log probs. + Token string `json:"token"` + // LogProb is the log prob of the top log probs. + LogProb float64 `json:"logprob"` + // Bytes is the bytes of the top log probs. + Bytes []byte `json:"bytes,omitempty"` } ``` -## type [TranscriptionTimestampGranularity]() +## type [TranscriptionTimestampGranularity]() TranscriptionTimestampGranularity is the timestamp granularity for the transcription. @@ -1333,8 +1073,21 @@ string type TranscriptionTimestampGranularity string ``` + + +```go +const ( + // TranscriptionTimestampGranularityWord is the word timestamp + // granularity. + TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word" + // TranscriptionTimestampGranularitySegment is the segment timestamp + // granularity. + TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" +) +``` + -## type [Usage]() +## type [Usage]() Usage Represents the total token usage per request to Groq. @@ -1347,15 +1100,19 @@ type Usage struct { ``` -## type [Words]() +## type [Words]() -Words is the words of the response. +Words is the words of the audio response. ```go type Words []struct { - Word string `json:"word"` // Word is the word of the words. - Start float64 `json:"start"` // Start is the start of the words. - End float64 `json:"end"` // End is the end of the words. + // Word is the textual representation of a word in the audio + // response. + Word string `json:"word"` + // Start is the start of the words in seconds. + Start float64 `json:"start"` + // End is the end of the words in seconds. + End float64 `json:"end"` } ``` diff --git a/agents.go b/agents.go new file mode 100644 index 0000000..df5693e --- /dev/null +++ b/agents.go @@ -0,0 +1,63 @@ +package groq + +import ( + "context" + "log/slog" + + "github.com/conneroisu/groq-go/pkg/tools" +) + +type ( + // Agenter is an interface for an agent. + Agenter interface { + ToolManager + } + // ToolManager is an interface for a tool manager. + ToolManager interface { + ToolGetter + ToolRunner + } + // ToolGetter is an interface for a tool getter. + ToolGetter interface { + Get( + ctx context.Context, + params ToolGetParams, + ) ([]tools.Tool, error) + } + // ToolRunner is an interface for a tool runner. + ToolRunner interface { + Run( + ctx context.Context, + response ChatCompletionResponse, + ) ([]ChatCompletionMessage, error) + } + // ToolGetParams are the parameters for getting tools. + ToolGetParams struct { + } + // Router is an agent router. + // + // It is used to route messages to the appropriate model. + Router struct { + // Agents is the agents of the router. + Agents []Agent + // Logger is the logger of the router. + Logger *slog.Logger + } +) + +// Agent is an agent. +type Agent struct { + client *Client + logger *slog.Logger +} + +// NewAgent creates a new agent. +func NewAgent( + client *Client, + logger *slog.Logger, +) *Agent { + return &Agent{ + client: client, + logger: logger, + } +} diff --git a/audio.go b/audio.go index aecffee..a889b00 100644 --- a/audio.go +++ b/audio.go @@ -9,76 +9,103 @@ import ( "os" "github.com/conneroisu/groq-go/pkg/builders" + "github.com/conneroisu/groq-go/pkg/models" ) const ( - AudioResponseFormatJSON AudioResponseFormat = "json" // AudioResponseFormatJSON is the JSON response format of some audio. - AudioResponseFormatText AudioResponseFormat = "text" // AudioResponseFormatText is the text response format of some audio. - AudioResponseFormatSRT AudioResponseFormat = "srt" // AudioResponseFormatSRT is the SRT response format of some audio. - AudioResponseFormatVerboseJSON AudioResponseFormat = "verbose_json" // AudioResponseFormatVerboseJSON is the verbose JSON response format of some audio. - AudioResponseFormatVTT AudioResponseFormat = "vtt" // AudioResponseFormatVTT is the VTT response format of some audio. - - TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word" // TranscriptionTimestampGranularityWord is the word timestamp granularity. - TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" // TranscriptionTimestampGranularitySegment is the segment timestamp granularity. + // TranscriptionTimestampGranularityWord is the word timestamp + // granularity. + TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word" + // TranscriptionTimestampGranularitySegment is the segment timestamp + // granularity. + TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" ) type ( - // AudioResponseFormat is the response format for the audio API. - // - // Response formatted using AudioResponseFormatJSON by default. - // - // string - AudioResponseFormat string - // TranscriptionTimestampGranularity is the timestamp granularity for the transcription. + // TranscriptionTimestampGranularity is the timestamp granularity for + // the transcription. // // string TranscriptionTimestampGranularity string // AudioRequest represents a request structure for audio API. AudioRequest struct { - Model AudioModel // Model is the model to use for the transcription. - FilePath string // FilePath is either an existing file in your filesystem or a filename representing the contents of Reader. - Reader io.Reader // Reader is an optional io.Reader when you do not want to use an existing file. - Prompt string // Prompt is the prompt for the transcription. - Temperature float32 // Temperature is the temperature for the transcription. - Language string // Language is the language for the transcription. Only for transcription. - Format AudioResponseFormat // Format is the format for the response. + // Model is the model to use for the transcription. + Model models.AudioModel + // FilePath is either an existing file in your filesystem or a + // filename representing the contents of Reader. + FilePath string + // Reader is an optional io.Reader when you do not want to use + // an existing file. + Reader io.Reader + // Prompt is the prompt for the transcription. + Prompt string + // Temperature is the temperature for the transcription. + Temperature float32 + // Language is the language for the transcription. Only for + // transcription. + Language string + // Format is the format for the response. + Format Format } // AudioResponse represents a response structure for audio API. AudioResponse struct { - Task string `json:"task"` // Task is the task of the response. - Language string `json:"language"` // Language is the language of the response. - Duration float64 `json:"duration"` // Duration is the duration of the response. - Segments Segments `json:"segments"` // Segments is the segments of the response. - Words Words `json:"words"` // Words is the words of the response. - Text string `json:"text"` // Text is the text of the response. + // Task is the task of the response. + Task string `json:"task"` + // Language is the language of the response. + Language string `json:"language"` + // Duration is the duration of the response. + Duration float64 `json:"duration"` + // Segments is the segments of the response. + Segments Segments `json:"segments"` + // Words is the words of the response. + Words Words `json:"words"` + // Text is the text of the response. + Text string `json:"text"` Header http.Header // Header is the header of the response. } - // Words is the words of the response. + // Words is the words of the audio response. Words []struct { - Word string `json:"word"` // Word is the word of the words. - Start float64 `json:"start"` // Start is the start of the words. - End float64 `json:"end"` // End is the end of the words. + // Word is the textual representation of a word in the audio + // response. + Word string `json:"word"` + // Start is the start of the words in seconds. + Start float64 `json:"start"` + // End is the end of the words in seconds. + End float64 `json:"end"` } // Segments is the segments of the response. Segments []struct { - ID int `json:"id"` // ID is the ID of the segment. - Seek int `json:"seek"` // Seek is the seek of the segment. - Start float64 `json:"start"` // Start is the start of the segment. - End float64 `json:"end"` // End is the end of the segment. - Text string `json:"text"` // Text is the text of the segment. - Tokens []int `json:"tokens"` // Tokens is the tokens of the segment. - Temperature float64 `json:"temperature"` // Temperature is the temperature of the segment. - AvgLogprob float64 `json:"avg_logprob"` // AvgLogprob is the avg log prob of the segment. - CompressionRatio float64 `json:"compression_ratio"` // CompressionRatio is the compression ratio of the segment. - NoSpeechProb float64 `json:"no_speech_prob"` // NoSpeechProb is the no speech prob of the segment. - Transient bool `json:"transient"` // Transient is the transient of the segment. + // ID is the ID of the segment. + ID int `json:"id"` + // Seek is the seek of the segment. + Seek int `json:"seek"` + // Start is the start of the segment. + Start float64 `json:"start"` + // End is the end of the segment. + End float64 `json:"end"` + // Text is the text of the segment. + Text string `json:"text"` + // Tokens is the tokens of the segment. + Tokens []int `json:"tokens"` + // Temperature is the temperature of the segment. + Temperature float64 `json:"temperature"` + // AvgLogprob is the avg log prob of the segment. + AvgLogprob float64 `json:"avg_logprob"` + // CompressionRatio is the compression ratio of the segment. + CompressionRatio float64 `json:"compression_ratio"` + // NoSpeechProb is the no speech prob of the segment. + NoSpeechProb float64 `json:"no_speech_prob"` + // Transient is the transient of the segment. + Transient bool `json:"transient"` } // audioTextResponse is the response structure for the audio API when the // response format is text. audioTextResponse struct { - Text string `json:"text"` // Text is the text of the response. - header http.Header `json:"-"` // Header is the response header. + // Text is the text of the response. + Text string `json:"text"` + // Header is the response header. + header http.Header `json:"-"` } ) @@ -122,7 +149,7 @@ func (c *Client) callAudioAPI( endpointSuffix Endpoint, ) (response AudioResponse, err error) { var formBody bytes.Buffer - c.requestFormBuilder = c.createFormBuilder(&formBody) + c.requestFormBuilder = builders.NewFormBuilder(&formBody) err = AudioMultipartForm(request, c.requestFormBuilder) if err != nil { return AudioResponse{}, err @@ -131,7 +158,7 @@ func (c *Client) callAudioAPI( ctx, c.header, http.MethodPost, - c.fullURL(endpointSuffix, withModel(model(request.Model))), + c.fullURL(endpointSuffix, withModel(request.Model)), builders.WithBody(&formBody), builders.WithContentType(c.requestFormBuilder.FormDataContentType()), ) @@ -153,14 +180,14 @@ func (c *Client) callAudioAPI( } func (r AudioRequest) hasJSONResponse() bool { - return r.Format == "" || r.Format == AudioResponseFormatJSON || - r.Format == AudioResponseFormatVerboseJSON + return r.Format == "" || r.Format == FormatJSON || + r.Format == FormatVerboseJSON } // AudioMultipartForm creates a form with audio file contents and the name of // the model to use for audio processing. func AudioMultipartForm(request AudioRequest, b builders.FormBuilder) error { - err := CreateFileField(request, b) + err := createFileField(request, b) if err != nil { return err } @@ -202,9 +229,7 @@ func AudioMultipartForm(request AudioRequest, b builders.FormBuilder) error { return b.Close() } -// CreateFileField creates the "file" form field from either an existing file -// or by using the reader. -func CreateFileField( +func createFileField( request AudioRequest, b builders.FormBuilder, ) (err error) { diff --git a/audio_test.go b/audio_test.go index dabb781..02c126c 100644 --- a/audio_test.go +++ b/audio_test.go @@ -26,7 +26,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { Prompt: "test", Temperature: 0.5, Language: "en", - Format: groq.AudioResponseFormatSRT, + Format: groq.FormatSRT, } mockFailedErr := fmt.Errorf("mock form builder fail") @@ -94,7 +94,7 @@ func TestCreateFileField(t *testing.T) { return mockFailedErr }, } - err := groq.CreateFileField(req, mockBuilder) + err := groq.AudioMultipartForm(req, mockBuilder) a.ErrorIs( err, mockFailedErr, @@ -116,7 +116,7 @@ func TestCreateFileField(t *testing.T) { }, } - err := groq.CreateFileField(req, mockBuilder) + err := groq.AudioMultipartForm(req, mockBuilder) a.ErrorIs( err, mockFailedErr, @@ -130,7 +130,7 @@ func TestCreateFileField(t *testing.T) { FilePath: "non_existing_file.wav", } mockBuilder := builders.NewFormBuilder(&test.FailingErrorBuffer{}) - err := groq.CreateFileField(req, mockBuilder) + err := groq.AudioMultipartForm(req, mockBuilder) a.Error( err, "createFileField using file should return error when open file fails", diff --git a/chat.go b/chat.go index 04a4a12..0b9e26e 100644 --- a/chat.go +++ b/chat.go @@ -1,46 +1,74 @@ package groq import ( - "bufio" - "bytes" "context" "encoding/json" "fmt" - "io" "net/http" "reflect" "strings" "time" "github.com/conneroisu/groq-go/pkg/builders" + "github.com/conneroisu/groq-go/pkg/groqerr" + "github.com/conneroisu/groq-go/pkg/models" "github.com/conneroisu/groq-go/pkg/schema" + "github.com/conneroisu/groq-go/pkg/streams" "github.com/conneroisu/groq-go/pkg/tools" ) const ( - ChatMessageRoleSystem Role = "system" // ChatMessageRoleSystem is the system chat message role. - ChatMessageRoleUser Role = "user" // ChatMessageRoleUser is the user chat message role. - ChatMessageRoleAssistant Role = "assistant" // ChatMessageRoleAssistant is the assistant chat message role. - ChatMessageRoleFunction Role = "function" // ChatMessageRoleFunction is the function chat message role. - ChatMessageRoleTool Role = "tool" // ChatMessageRoleTool is the tool chat message role. - ImageURLDetailHigh ImageURLDetail = "high" // ImageURLDetailHigh is the high image url detail. - ImageURLDetailLow ImageURLDetail = "low" // ImageURLDetailLow is the low image url detail. - ImageURLDetailAuto ImageURLDetail = "auto" // ImageURLDetailAuto is the auto image url detail. - ChatMessagePartTypeText ChatMessagePartType = "text" // ChatMessagePartTypeText is the text chat message part type. - ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" // ChatMessagePartTypeImageURL is the image url chat message part type. - ChatCompletionResponseFormatTypeJSONObject ChatCompletionResponseFormatType = "json_object" // ChatCompletionResponseFormatTypeJSONObject is the json object chat completion response format type. - ChatCompletionResponseFormatTypeJSONSchema ChatCompletionResponseFormatType = "json_schema" // ChatCompletionResponseFormatTypeJSONSchema is the json schema chat completion response format type. - ChatCompletionResponseFormatTypeText ChatCompletionResponseFormatType = "text" // ChatCompletionResponseFormatTypeText is the text chat completion response format type. - FinishReasonStop FinishReason = "stop" // FinishReasonStop is the stop finish reason. - FinishReasonLength FinishReason = "length" // FinishReasonLength is the length finish reason. - FinishReasonFunctionCall FinishReason = "function_call" // FinishReasonFunctionCall is the function call finish reason. - FinishReasonToolCalls FinishReason = "tool_calls" // FinishReasonToolCalls is the tool calls finish reason. - FinishReasonContentFilter FinishReason = "content_filter" // FinishReasonContentFilter is the content filter finish reason. - FinishReasonNull FinishReason = "null" // FinishReasonNull is the null finish reason. + // ChatMessageRoleSystem is the system chat message role. + ChatMessageRoleSystem Role = "system" + // ChatMessageRoleUser is the user chat message role. + ChatMessageRoleUser Role = "user" + // ChatMessageRoleAssistant is the assistant chat message role. + ChatMessageRoleAssistant Role = "assistant" + // ChatMessageRoleFunction is the function chat message role. + ChatMessageRoleFunction Role = "function" + // ChatMessageRoleTool is the tool chat message role. + ChatMessageRoleTool Role = "tool" + + // ImageURLDetailHigh is the high image url detail. + ImageURLDetailHigh ImageURLDetail = "high" + // ImageURLDetailLow is the low image url detail. + ImageURLDetailLow ImageURLDetail = "low" + // ImageURLDetailAuto is the auto image url detail. + ImageURLDetailAuto ImageURLDetail = "auto" + + // ChatMessagePartTypeText is the text chat message part type. + ChatMessagePartTypeText ChatMessagePartType = "text" + // ChatMessagePartTypeImageURL is the image url chat message part type. + ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" ) type ( - // ImageURLDetail is the image url detail. + // FinishReason is the finish reason. + // + // string + FinishReason string +) + +const ( + // ReasonStop is the stop finish reason for a chat completion. + ReasonStop FinishReason = "stop" + // ReasonLength is the length finish reason for a chat completion. + ReasonLength FinishReason = "length" + // ReasonFunctionCall is the function call finish reason for a chat + // completion. + ReasonFunctionCall FinishReason = "function_call" + // ReasonToolCalls is the tool calls finish reason for a chat + // completion. + ReasonToolCalls FinishReason = "tool_calls" + // ReasonContentFilter is the content filter finish reason for a chat + // completion. + ReasonContentFilter FinishReason = "content_filter" + // ReasonNull is the null finish reason for a chat completion. + ReasonNull FinishReason = "null" +) + +type ( + // ImageURLDetail is the detail of the image at the URL. // // string ImageURLDetail string @@ -58,35 +86,50 @@ type ( } // ChatMessageImageURL represents the chat message image url. ChatMessageImageURL struct { - URL string `json:"url,omitempty"` // URL is the url of the image. - Detail ImageURLDetail `json:"detail,omitempty"` // Detail is the detail of the image url. + // URL is the url of the image. + URL string `json:"url,omitempty"` + // Detail is the detail of the image url. + Detail ImageURLDetail `json:"detail,omitempty"` } // ChatMessagePart represents the chat message part of a chat completion // message. ChatMessagePart struct { - Text string `json:"text,omitempty"` // Text is the text of the chat message part. - Type ChatMessagePartType `json:"type,omitempty"` // Type is the type of the chat message part. - ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` // ImageURL is the image url of the chat message part. + // Text is the text of the chat message part. + Text string `json:"text,omitempty"` + // Type is the type of the chat message part. + Type ChatMessagePartType `json:"type,omitempty"` + // ImageURL is the image url of the chat message part. + ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` } // ChatCompletionMessage represents the chat completion message. ChatCompletionMessage struct { - Name string `json:"name"` // Name is the name of the chat completion message. - Role Role `json:"role"` // Role is the role of the chat completion message. - Content string `json:"content"` // Content is the content of the chat completion message. - MultiContent []ChatMessagePart `json:"-"` // MultiContent is the multi content of the chat completion message. - FunctionCall *tools.FunctionCall `json:"function_call,omitempty"` // FunctionCall setting for Role=assistant prompts this may be set to the function call generated by the model. - ToolCalls []tools.ToolCall `json:"tool_calls,omitempty"` // ToolCalls setting for Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. - ToolCallID string `json:"tool_call_id,omitempty"` // ToolCallID is setting for Role=tool prompts this should be set to the ID given in the assistant's prior request to call a tool. + // Name is the name of the chat completion message. + Name string `json:"name"` + // Role is the role of the chat completion message. + Role Role `json:"role"` + // Content is the content of the chat completion message. + Content string `json:"content"` + // MultiContent is the multi content of the chat completion + // message. + MultiContent []ChatMessagePart `json:"-"` + // FunctionCall setting for Role=assistant prompts this may be + // set to the function call generated by the model. + FunctionCall *tools.FunctionCall `json:"function_call,omitempty"` + // ToolCalls setting for Role=assistant prompts this may be set + // to the tool calls generated by the model, such as function + // calls. + ToolCalls []tools.ToolCall `json:"tool_calls,omitempty"` + // ToolCallID is setting for Role=tool prompts this should be + // set to the ID given in the assistant's prior request to call + // a tool. + ToolCallID string `json:"tool_call_id,omitempty"` } - // ChatCompletionResponseFormatType is the chat completion response format type. - // - // string - ChatCompletionResponseFormatType string // ChatCompletionResponseFormat is the chat completion response format. ChatCompletionResponseFormat struct { // Type is the type of the chat completion response format. - Type ChatCompletionResponseFormatType `json:"type,omitempty"` - // JSONSchema is the json schema of the chat completion response format. + Type Format `json:"type,omitempty"` + // JSONSchema is the json schema of the chat completion response + // format. JSONSchema *ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` } // ChatCompletionResponseFormatJSONSchema is the chat completion @@ -97,143 +140,220 @@ type ( // // it is used to further identify the schema in the response. Name string `json:"name"` - // Description is the description of the chat completion response - // format - // json schema. + // Description is the description of the chat completion + // response format json schema. Description string `json:"description,omitempty"` - // description of the chat completion response format json schema. - // Schema is the schema of the chat completion response format json schema. + // Schema is the schema of the chat completion response format + // json schema. Schema schema.Schema `json:"schema"` - // Strict determines whether to enforce the schema upon the generated - // content. + // Strict determines whether to enforce the schema upon the + // generated content. Strict bool `json:"strict"` } - // ChatCompletionRequest represents a request structure for the chat completion API. + // ChatCompletionRequest represents a request structure for the chat + // completion API. ChatCompletionRequest struct { - Model ChatModel `json:"model"` // Model is the model of the chat completion request. - Messages []ChatCompletionMessage `json:"messages"` // Messages is the messages of the chat completion request. These act as the prompt for the model. - MaxTokens int `json:"max_tokens,omitempty"` // MaxTokens is the max tokens of the chat completion request. - Temperature float32 `json:"temperature,omitempty"` // Temperature is the temperature of the chat completion request. - TopP float32 `json:"top_p,omitempty"` // TopP is the top p of the chat completion request. - N int `json:"n,omitempty"` // N is the n of the chat completion request. - Stream bool `json:"stream,omitempty"` // Stream is the stream of the chat completion request. - Stop []string `json:"stop,omitempty"` // Stop is the stop of the chat completion request. - PresencePenalty float32 `json:"presence_penalty,omitempty"` // PresencePenalty is the presence penalty of the chat completion request. - ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` // ResponseFormat is the response format of the chat completion request. - Seed *int `json:"seed,omitempty"` // Seed is the seed of the chat completion request. - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // FrequencyPenalty is the frequency penalty of the chat completion request. - LogitBias map[string]int `json:"logit_bias,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. incorrect: `"logit_bias":{ "You": 6}`, correct: `"logit_bias":{"1639": 6}` refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias - LogProbs bool `json:"logprobs,omitempty"` // LogProbs indicates whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message. This option is currently not available on the gpt-4-vision-preview model. - TopLogProbs int `json:"top_logprobs,omitempty"` // TopLogProbs is an integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. - User string `json:"user,omitempty"` // User is the user of the chat completion request. - Tools []tools.Tool `json:"tools,omitempty"` // Tools is the tools of the chat completion request. - ToolChoice any `json:"tool_choice,omitempty"` // This can be either a string or an ToolChoice object. - StreamOptions *StreamOptions `json:"stream_options,omitempty"` // Options for streaming response. Only set this when you set stream: true. - ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` // Disable the default behavior of parallel tool calls by setting it: false. - RetryDelay time.Duration `json:"-"` // RetryDelay is the delay between retries. + // Model is the model of the chat completion request. + Model models.ChatModel `json:"model"` + // Messages is the messages of the chat completion request. + // + // These act as the prompt for the model. + Messages []ChatCompletionMessage `json:"messages"` + // MaxTokens is the max tokens of the chat completion request. + MaxTokens int `json:"max_tokens,omitempty"` + // Temperature is the temperature of the chat completion + // request. + Temperature float32 `json:"temperature,omitempty"` + // TopP is the top p of the chat completion request. + TopP float32 `json:"top_p,omitempty"` + // N is the n of the chat completion request. + N int `json:"n,omitempty"` + // Stream is the stream of the chat completion request. + Stream bool `json:"stream,omitempty"` + // Stop is the stop of the chat completion request. + Stop []string `json:"stop,omitempty"` + // PresencePenalty is the presence penalty of the chat + // completion request. + PresencePenalty float32 `json:"presence_penalty,omitempty"` + // ResponseFormat is the response format of the chat completion + // request. + ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` + // Seed is the seed of the chat completion request. + Seed *int `json:"seed,omitempty"` + // FrequencyPenalty is the frequency penalty of the chat + // completion request. + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + // LogitBias is must be a token id string (specified by their + // token ID in the tokenizer), not a word string. incorrect: `"logit_bias":{ "You": 6}`, correct: `"logit_bias":{"1639": 6}` refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias + LogitBias map[string]int `json:"logit_bias,omitempty"` + // LogProbs indicates whether to return log probabilities of the + // output tokens or not. If true, returns the log probabilities + // of each output token returned in the content of message. + // + // This option is currently not available on the + // gpt-4-vision-preview model. + LogProbs bool `json:"logprobs,omitempty"` + // TopLogProbs is an integer between 0 and 5 specifying the + // number of most likely tokens to return at each token + // position, each with an associated log probability. Logprobs + // must be set to true if this parameter is used. + TopLogProbs int `json:"top_logprobs,omitempty"` + // User is the user of the chat completion request. + User string `json:"user,omitempty"` + // Tools is the tools of the chat completion request. + Tools []tools.Tool `json:"tools,omitempty"` + // This can be either a string or an ToolChoice object. + ToolChoice any `json:"tool_choice,omitempty"` + // Options for streaming response. Only set this when you set stream: true. + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` + // RetryDelay is the delay between retries. + RetryDelay time.Duration `json:"-"` } // LogProbs is the top-level structure containing the log probability information. LogProbs struct { + // Content is a list of message content tokens with log + // probability information. Content []struct { - Token string `json:"token"` // Token is the token of the log prob. - LogProb float64 `json:"logprob"` // LogProb is the log prob of the log prob. - Bytes []byte `json:"bytes,omitempty"` // Omitting the field if it is null - TopLogProbs []TopLogProbs `json:"top_logprobs"` // TopLogProbs is a list of the most likely tokens and their log probability, at this token position. In rare cases, there may be fewer than the number of requested top_logprobs returned. - } `json:"content"` // Content is a list of message content tokens with log probability information. + // Token is the token of the log prob. + Token string `json:"token"` + // LogProb is the log prob of the log prob. + LogProb float64 `json:"logprob"` + // Omitting the field if it is null + Bytes []byte `json:"bytes,omitempty"` + // TopLogProbs is a list of the most likely tokens and + // their log probability, at this token position. In + // rare cases, there may be fewer than the number of + // requested top_logprobs returned. + TopLogProbs []TopLogProbs `json:"top_logprobs"` + } `json:"content"` } // TopLogProbs represents the top log probs. TopLogProbs struct { - Token string `json:"token"` // Token is the token of the top log probs. - LogProb float64 `json:"logprob"` // LogProb is the log prob of the top log probs. - Bytes []byte `json:"bytes,omitempty"` // Bytes is the bytes of the top log probs. + // Token is the token of the top log probs. + Token string `json:"token"` + // LogProb is the log prob of the top log probs. + LogProb float64 `json:"logprob"` + // Bytes is the bytes of the top log probs. + Bytes []byte `json:"bytes,omitempty"` } - // FinishReason is the finish reason. - // string - FinishReason string // ChatCompletionChoice represents the chat completion choice. ChatCompletionChoice struct { Index int `json:"index"` // Index is the index of the choice. // Message is the chat completion message of the choice. - Message ChatCompletionMessage `json:"message"` // Message is the chat completion message of the choice. + Message ChatCompletionMessage `json:"message"` // FinishReason is the finish reason of the choice. - // - // stop: API returned complete message, - // or a message terminated by one of the stop sequences provided via the stop parameter - // length: Incomplete model output due to max_tokens parameter or token limit - // function_call: The model decided to call a function - // content_filter: Omitted content due to a flag from our content filters - // null: API response still in progress or incomplete - FinishReason FinishReason `json:"finish_reason"` // FinishReason is the finish reason of the choice. + FinishReason FinishReason `json:"finish_reason"` // LogProbs is the log probs of the choice. // - // This is basically the probability of the model choosing the token. - LogProbs *LogProbs `json:"logprobs,omitempty"` // LogProbs is the log probs of the choice. + // This is basically the probability of the model choosing the + // token. + LogProbs *LogProbs `json:"logprobs,omitempty"` } - // ChatCompletionResponse represents a response structure for chat completion API. + // ChatCompletionResponse represents a response structure for chat + // completion API. ChatCompletionResponse struct { - ID string `json:"id"` // ID is the id of the response. - Object string `json:"object"` // Object is the object of the response. - Created int64 `json:"created"` // Created is the created time of the response. - Model ChatModel `json:"model"` // Model is the model of the response. - Choices []ChatCompletionChoice `json:"choices"` // Choices is the choices of the response. - Usage Usage `json:"usage"` // Usage is the usage of the response. - SystemFingerprint string `json:"system_fingerprint"` // SystemFingerprint is the system fingerprint of the response. - http.Header // Header is the header of the response. - } - // ChatCompletionStreamChoiceDelta represents a response structure for chat completion API. + // ID is the id of the response. + ID string `json:"id"` + // Object is the object of the response. + Object string `json:"object"` + // Created is the created time of the response. + Created int64 `json:"created"` + // Model is the model of the response. + Model models.ChatModel `json:"model"` + // Choices is the choices of the response. + Choices []ChatCompletionChoice `json:"choices"` + // Usage is the usage of the response. + Usage Usage `json:"usage"` + // SystemFingerprint is the system fingerprint of the response. + SystemFingerprint string `json:"system_fingerprint"` + // Header is the header of the response. + http.Header + } + // ChatCompletionStreamChoiceDelta represents a response structure for + // chat completion API. ChatCompletionStreamChoiceDelta struct { - Content string `json:"content,omitempty"` - Role string `json:"role,omitempty"` + // Content is the content of the response. + Content string `json:"content,omitempty"` + // Role is the role of the creator of the completion. + Role string `json:"role,omitempty"` + // FunctionCall is the function call of the response. FunctionCall *tools.FunctionCall `json:"function_call,omitempty"` - ToolCalls []tools.ToolCall `json:"tool_calls,omitempty"` + // ToolCalls are the tool calls of the response. + ToolCalls []tools.ToolCall `json:"tool_calls,omitempty"` } - // ChatCompletionStreamChoice represents a response structure for chat completion API. + // ChatCompletionStreamChoice represents a response structure for chat + // completion API. ChatCompletionStreamChoice struct { - Index int `json:"index"` - Delta ChatCompletionStreamChoiceDelta `json:"delta"` - FinishReason FinishReason `json:"finish_reason"` - } - streamer interface { - ChatCompletionStreamResponse + // Index is the index of the choice. + Index int `json:"index"` + // Delta is the delta of the choice. + Delta ChatCompletionStreamChoiceDelta `json:"delta"` + // FinishReason is the finish reason of the choice. + FinishReason FinishReason `json:"finish_reason"` } // StreamOptions represents the stream options. StreamOptions struct { - // If set, an additional chunk will be streamed before the data: [DONE] message. - // The usage field on this chunk shows the token usage statistics for the entire request, - // and the choices field will always be an empty array. - // All other chunks will also include a usage field, but with a null value. + // IncludeUsage is the include usage option of the stream + // options. + // + // If set, an additional chunk will be streamed before the data: + // [DONE] message. + // The usage field on this chunk shows the token usage + // statistics for the entire request, and the choices field will + // always be an empty array. + // + // All other chunks will also include a usage field, but with a + // null value. IncludeUsage bool `json:"include_usage,omitempty"` } - // ChatCompletionStreamResponse represents a response structure for chat completion API. + // ChatCompletionStreamResponse represents a response structure for chat + // completion API. ChatCompletionStreamResponse struct { - ID string `json:"id"` // ID is the identifier for the chat completion stream response. - Object string `json:"object"` // Object is the object type of the chat completion stream response. - Created int64 `json:"created"` // Created is the creation time of the chat completion stream response. - Model ChatModel `json:"model"` // Model is the model used for the chat completion stream response. - Choices []ChatCompletionStreamChoice `json:"choices"` // Choices is the choices for the chat completion stream response. - SystemFingerprint string `json:"system_fingerprint"` // SystemFingerprint is the system fingerprint for the chat completion stream response. - PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` // PromptAnnotations is the prompt annotations for the chat completion stream response. + // ID is the identifier for the chat completion stream response. + ID string `json:"id"` + // Object is the object type of the chat completion stream + // response. + Object string `json:"object"` + // Created is the creation time of the chat completion stream + // response. + Created int64 `json:"created"` + // Model is the model used for the chat completion stream + // response. + Model models.ChatModel `json:"model"` + // Choices is the choices for the chat completion stream + // response. + Choices []ChatCompletionStreamChoice `json:"choices"` + // SystemFingerprint is the system fingerprint for the chat + // completion stream response. + SystemFingerprint string `json:"system_fingerprint"` + // PromptAnnotations is the prompt annotations for the chat + // completion stream response. + PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` + // PromptFilterResults is the prompt filter results for the chat + // completion stream response. PromptFilterResults []struct { Index int `json:"index"` - } `json:"prompt_filter_results,omitempty"` // PromptFilterResults is the prompt filter results for the chat completion stream response. - // Usage is an optional field that will only be present when you set stream_options: {"include_usage": true} in your request. + } `json:"prompt_filter_results,omitempty"` + // Usage is an optional field that will only be present when you + // set stream_options: {"include_usage": true} in your request. // - // When present, it contains a null value except for the last chunk which contains the token usage statistics - // for the entire request. + // When present, it contains a null value except for the last + // chunk which contains the token usage statistics for the + // entire request. Usage *Usage `json:"usage,omitempty"` } // ChatCompletionStream is a stream of ChatCompletionStreamResponse. - // - // Note: Perhaps it is more elegant to abstract Stream using generics. ChatCompletionStream struct { - *streamReader[ChatCompletionStreamResponse] + *streams.StreamReader[*ChatCompletionStreamResponse] } ) // MarshalJSON method implements the json.Marshaler interface. func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { if m.Content != "" && m.MultiContent != nil { - return nil, &ErrContentFieldsMisused{field: "Content"} + return nil, &groqerr.ErrContentFieldsMisused{} } if len(m.MultiContent) > 0 { msg := struct { @@ -294,7 +414,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) (err error) { // MarshalJSON implements the json.Marshaler interface. func (r FinishReason) MarshalJSON() ([]byte, error) { - if r == FinishReasonNull || r == "" { + if r == ReasonNull || r == "" { return []byte("null"), nil } return []byte( @@ -303,23 +423,7 @@ func (r FinishReason) MarshalJSON() ([]byte, error) { } // SetHeader sets the header of the response. -func (r *ChatCompletionResponse) SetHeader(h http.Header) { - r.Header = h -} - -// MustCreateChatCompletion method is an API call to create a chat completion. -// -// It panics if an error occurs. -func (c *Client) MustCreateChatCompletion( - ctx context.Context, - request ChatCompletionRequest, -) (response ChatCompletionResponse) { - response, err := c.CreateChatCompletion(ctx, request) - if err != nil { - panic(err) - } - return response -} +func (r *ChatCompletionResponse) SetHeader(h http.Header) { r.Header = h } // CreateChatCompletion method is an API call to create a chat completion. func (c *Client) CreateChatCompletion( @@ -331,22 +435,23 @@ func (c *Client) CreateChatCompletion( ctx, c.header, http.MethodPost, - c.fullURL(chatCompletionsSuffix, withModel(model(request.Model))), + c.fullURL(chatCompletionsSuffix, withModel(request.Model)), builders.WithBody(request)) if err != nil { return } err = c.sendRequest(req, &response) - reqErr, ok := err.(*APIError) - if ok && (reqErr.HTTPStatusCode == http.StatusServiceUnavailable || reqErr.HTTPStatusCode == http.StatusInternalServerError) { + reqErr, ok := err.(*groqerr.APIError) + if ok && (reqErr.HTTPStatusCode == http.StatusServiceUnavailable || + reqErr.HTTPStatusCode == http.StatusInternalServerError) { time.Sleep(request.RetryDelay) return c.CreateChatCompletion(ctx, request) } return } -// CreateChatCompletionStream method is an API call to create a chat completion w/ streaming -// support. +// CreateChatCompletionStream method is an API call to create a chat completion +// w/ streaming support. // // If set, tokens will be sent as data-only server-sent events as they become // available, with the stream terminated by a data: [DONE] message. @@ -359,23 +464,25 @@ func (c *Client) CreateChatCompletionStream( ctx, c.header, http.MethodPost, - c.fullURL(chatCompletionsSuffix, withModel(model(request.Model))), + c.fullURL( + chatCompletionsSuffix, + withModel(request.Model)), builders.WithBody(request), ) if err != nil { return nil, err } - resp, err := sendRequestStream[ChatCompletionStreamResponse](c, req) + resp, err := sendRequestStream(c, req) if err != nil { return } - stream = &ChatCompletionStream{ - streamReader: resp, - } - return + return &ChatCompletionStream{ + StreamReader: resp, + }, nil } -// CreateChatCompletionJSON method is an API call to create a chat completion w/ object output. +// CreateChatCompletionJSON method is an API call to create a chat completion +// w/ object output. func (c *Client) CreateChatCompletionJSON( ctx context.Context, request ChatCompletionRequest, @@ -392,20 +499,9 @@ func (c *Client) CreateChatCompletionJSON( Schema: *schema, Strict: true, } - req, err := builders.NewRequest( - ctx, - c.header, - http.MethodPost, - c.fullURL(chatCompletionsSuffix, withModel(model(request.Model))), - builders.WithBody(request), - ) - if err != nil { - return - } - var response ChatCompletionResponse - err = c.sendRequest(req, &response) + response, err := c.CreateChatCompletion(ctx, request) if err != nil { - reqErr, ok := err.(*APIError) + reqErr, ok := err.(*groqerr.APIError) if ok && (reqErr.HTTPStatusCode == http.StatusServiceUnavailable || reqErr.HTTPStatusCode == http.StatusInternalServerError) { time.Sleep(request.RetryDelay) @@ -427,85 +523,3 @@ func (c *Client) CreateChatCompletionJSON( } return } - -type streamReader[T streamer] struct { - emptyMessagesLimit uint - isFinished bool - reader *bufio.Reader - response *http.Response - errAccumulator errorAccumulator - Header http.Header // Header is the header of the response. -} - -// Recv receives a response from the stream. -func (stream *streamReader[T]) Recv() (response T, err error) { - if stream.isFinished { - err = io.EOF - return response, err - } - return stream.processLines() -} - -// processLines processes the lines of the current response in the stream. -func (stream *streamReader[T]) processLines() (T, error) { - var ( - headerData = []byte("data: ") - errorPrefix = []byte(`data: {"error":`) - emptyMessagesCount uint - hasErrorPrefix bool - ) - for { - rawLine, err := stream.reader.ReadBytes('\n') - if err != nil || hasErrorPrefix { - respErr := stream.unmarshalError() - if respErr != nil { - return *new(T), - fmt.Errorf("error, %w", respErr.Error) - } - return *new(T), err - } - noSpaceLine := bytes.TrimSpace(rawLine) - if bytes.HasPrefix(noSpaceLine, errorPrefix) { - hasErrorPrefix = true - } - if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix { - if hasErrorPrefix { - noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData) - } - err := stream.errAccumulator.Write(noSpaceLine) - if err != nil { - return *new(T), err - } - emptyMessagesCount++ - if emptyMessagesCount > stream.emptyMessagesLimit { - return *new(T), ErrTooManyEmptyStreamMessages{} - } - continue - } - noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData) - if string(noPrefixLine) == "[DONE]" { - stream.isFinished = true - return *new(T), io.EOF - } - var response T - unmarshalErr := json.Unmarshal(noPrefixLine, &response) - if unmarshalErr != nil { - return *new(T), unmarshalErr - } - return response, nil - } -} -func (stream *streamReader[T]) unmarshalError() (errResp *errorResponse) { - errBytes := stream.errAccumulator.Bytes() - if len(errBytes) == 0 { - return - } - err := json.Unmarshal(errBytes, &errResp) - if err != nil { - errResp = nil - } - return -} -func (stream *streamReader[T]) Close() error { - return stream.response.Body.Close() -} diff --git a/chat_test.go b/chat_test.go index 1d51d1c..2e1c596 100644 --- a/chat_test.go +++ b/chat_test.go @@ -1,105 +1,65 @@ -package groq +package groq_test import ( - "bufio" - "bytes" - "io" + "context" + "encoding/json" "net/http" "testing" + "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/models" "github.com/conneroisu/groq-go/pkg/test" + "github.com/conneroisu/groq-go/pkg/tools" "github.com/stretchr/testify/assert" ) -// TestStreamReaderReturnsUnmarshalerErrors tests the stream reader returns an unmarshaler error. -func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) { - stream := &streamReader[ChatCompletionStreamResponse]{ - errAccumulator: newErrorAccumulator(), - } - - respErr := stream.unmarshalError() - if respErr != nil { - t.Fatalf("Did not return nil with empty buffer: %v", respErr) - } - - err := stream.errAccumulator.Write([]byte("{")) - if err != nil { - t.Fatalf("%+v", err) - } - - respErr = stream.unmarshalError() - if respErr != nil { - t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr) - } -} - -// TestStreamReaderReturnsErrTooManyEmptyStreamMessages tests the stream reader returns an error when the stream has too many empty messages. -func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) { - a := assert.New(t) - stream := &streamReader[ChatCompletionStreamResponse]{ - emptyMessagesLimit: 3, - reader: bufio.NewReader( - bytes.NewReader([]byte("\n\n\n\n")), - ), - errAccumulator: newErrorAccumulator(), - } - _, err := stream.Recv() - a.ErrorIs( - err, - ErrTooManyEmptyStreamMessages{}, - "Did not return error when recv failed", - err.Error(), - ) -} - -// TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed tests the stream reader returns an error when the error accumulator fails to write. -func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) { +func TestChat(t *testing.T) { + ctx := context.Background() a := assert.New(t) - stream := &streamReader[ChatCompletionStreamResponse]{ - reader: bufio.NewReader(bytes.NewReader([]byte("\n"))), - errAccumulator: &DefaultErrorAccumulator{ - Buffer: &test.FailingErrorBuffer{}, + ts := test.NewTestServer() + returnObj := groq.ChatCompletionResponse{ + ID: "chatcmpl-123", + Object: "chat.completion.chunk", + Created: 1693721698, + Model: "llama3-groq-70b-8192-tool-use-preview", + Choices: []groq.ChatCompletionChoice{ + { + Index: 0, + Message: groq.ChatCompletionMessage{ + Role: groq.ChatMessageRoleAssistant, + Content: "Hello!", + }, + }, }, } - _, err := stream.Recv() - a.ErrorIs( - err, - test.ErrTestErrorAccumulatorWriteFailed{}, - "Did not return error when write failed", - err.Error(), + ts.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + jsval, err := json.Marshal(returnObj) + a.NoError(err) + _, err = w.Write(jsval) + if err != nil { + t.Fatal(err) + } + }) + testS := ts.GroqTestServer() + testS.Start() + client, err := groq.NewClient( + test.GetTestToken(), + groq.WithBaseURL(testS.URL+"/v1"), ) -} - -// Helper function to create a new `streamReader` for testing -func newStreamReader[T streamer](data string) *streamReader[T] { - resp := &http.Response{ - Body: io.NopCloser(bytes.NewBufferString(data)), - } - reader := bufio.NewReader(resp.Body) - - return &streamReader[T]{ - emptyMessagesLimit: 5, - isFinished: false, - reader: reader, - response: resp, - errAccumulator: newErrorAccumulator(), - Header: resp.Header, - } -} - -// Test the `Recv` method with multiple empty messages triggering an error -func TestStreamReader_TooManyEmptyMessages(t *testing.T) { - data := "\n\n\n\n\n\n" - stream := newStreamReader[ChatCompletionStreamResponse](data) - - _, err := stream.Recv() - assert.ErrorIs(t, err, ErrTooManyEmptyStreamMessages{}) -} - -// Test the `Close` method -func TestStreamReader_Close(t *testing.T) { - stream := newStreamReader[ChatCompletionStreamResponse]("") - - err := stream.Close() - assert.NoError(t, err) + a.NoError(err) + resp, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlama3Groq70B8192ToolUsePreview, + Messages: []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + MaxTokens: 2000, + Tools: []tools.Tool{}, + }) + a.NoError(err) + a.NotEmpty(resp.Choices[0].Message.Content) } diff --git a/client.go b/client.go index 2db9617..d8c3c9c 100644 --- a/client.go +++ b/client.go @@ -1,7 +1,6 @@ package groq import ( - "bufio" "encoding/json" "fmt" "io" @@ -11,55 +10,91 @@ import ( "time" "github.com/conneroisu/groq-go/pkg/builders" + "github.com/conneroisu/groq-go/pkg/groqerr" + "github.com/conneroisu/groq-go/pkg/models" + "github.com/conneroisu/groq-go/pkg/streams" ) //go:generate go run ./scripts/generate-models/ -//go:generate gomarkdoc -o README.md -e . +//go:generate go run github.com/princjef/gomarkdoc/cmd/gomarkdoc@v1.1.0 -o README.md -e . type ( - // Format is the format of a response. - // string - Format string // Client is a Groq api client. Client struct { - groqAPIKey string // Groq API key - orgID string // OrgID is the organization ID for the client. - baseURL string // Base URL for the client. - emptyMessagesLimit uint // EmptyMessagesLimit is the limit for the empty messages. + // Groq API key + groqAPIKey string + // OrgID is the organization ID for the client. + orgID string + // Base URL for the client. + baseURL string + // EmptyMessagesLimit is the limit for the empty messages. + emptyMessagesLimit uint header builders.Header requestFormBuilder builders.FormBuilder - createFormBuilder func(body io.Writer) builders.FormBuilder - client *http.Client // Client is the HTTP client to use - logger *slog.Logger // Logger is the logger for the client. + // Client is the HTTP client to use + client *http.Client + // Logger is the logger for the client. + logger *slog.Logger } + // Opts is a function that sets options for a Groq client. + Opts func(*Client) +) + +// WithClient sets the client for the Groq client. +func WithClient(client *http.Client) Opts { + return func(c *Client) { c.client = client } +} + +// WithBaseURL sets the base URL for the Groq client. +func WithBaseURL(baseURL string) Opts { + return func(c *Client) { c.baseURL = baseURL } +} + +// WithLogger sets the logger for the Groq client. +func WithLogger(logger *slog.Logger) Opts { + return func(c *Client) { c.logger = logger } +} + +type ( + // Format is the format of a response. + // string + Format string // RateLimitHeaders struct represents Groq rate limits headers. RateLimitHeaders struct { - LimitRequests int `json:"x-ratelimit-limit-requests"` // LimitRequests is the limit requests of the rate limit headers. - LimitTokens int `json:"x-ratelimit-limit-tokens"` // LimitTokens is the limit tokens of the rate limit headers. - RemainingRequests int `json:"x-ratelimit-remaining-requests"` // RemainingRequests is the remaining requests of the rate limit headers. - RemainingTokens int `json:"x-ratelimit-remaining-tokens"` // RemainingTokens is the remaining tokens of the rate limit headers. - ResetRequests ResetTime `json:"x-ratelimit-reset-requests"` // ResetRequests is the reset requests of the rate limit headers. - ResetTokens ResetTime `json:"x-ratelimit-reset-tokens"` // ResetTokens is the reset tokens of the rate limit headers. + // LimitRequests is the limit requests of the rate limit + // headers. + LimitRequests int `json:"x-ratelimit-limit-requests"` + // LimitTokens is the limit tokens of the rate limit headers. + LimitTokens int `json:"x-ratelimit-limit-tokens"` + // RemainingRequests is the remaining requests of the rate + // limit headers. + RemainingRequests int `json:"x-ratelimit-remaining-requests"` + // RemainingTokens is the remaining tokens of the rate limit + // headers. + RemainingTokens int `json:"x-ratelimit-remaining-tokens"` + // ResetRequests is the reset requests of the rate limit + // headers. + ResetRequests ResetTime `json:"x-ratelimit-reset-requests"` + // ResetTokens is the reset tokens of the rate limit headers. + ResetTokens ResetTime `json:"x-ratelimit-reset-tokens"` } // ResetTime is a time.Time wrapper for the rate limit reset time. // string ResetTime string - - // Opts is a function that sets options for a Groq client. - Opts func(*Client) // Usage Represents the total token usage per request to Groq. Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` } + // Endpoint is an endpoint for the groq api. + Endpoint string - fullURLOptions struct { - model string - } - fullURLOption func(*fullURLOptions) + fullURLOptions struct{ model string } + fullURLOption func(*fullURLOptions) + response interface{ SetHeader(http.Header) } ) const ( @@ -69,9 +104,31 @@ const ( // FormatJSON is the JSON format. There is no support for streaming with // JSON format selected. FormatJSON Format = "json" + // FormatSRT is the SRT format. This is a text format that is only + // supported for the transcription API. + // SRT format selected. + FormatSRT Format = "srt" + // FormatVTT is the VTT format. This is a text format that is only + // supported for the transcription API. + FormatVTT Format = "vtt" + // FormatVerboseJSON is the verbose JSON format. This is a JSON format + // that is only supported for the transcription API. + FormatVerboseJSON Format = "verbose_json" + // FormatJSONObject is the json object chat + // completion response format type. + FormatJSONObject Format = "json_object" + // FormatJSONSchema is the json schema chat + // completion response format type. + FormatJSONSchema Format = "json_schema" // groqAPIURLv1 is the base URL for the Groq API. groqAPIURLv1 = "https://api.groq.com/openai/v1" + + chatCompletionsSuffix Endpoint = "/chat/completions" + transcriptionsSuffix Endpoint = "/audio/transcriptions" + translationsSuffix Endpoint = "/audio/translations" + embeddingsSuffix Endpoint = "/embeddings" + moderationsSuffix Endpoint = "/moderations" ) // NewClient creates a new Groq client. @@ -85,9 +142,6 @@ func NewClient(groqAPIKey string, opts ...Opts) (*Client, error) { logger: slog.Default(), baseURL: groqAPIURLv1, emptyMessagesLimit: 10, - createFormBuilder: func(body io.Writer) builders.FormBuilder { - return builders.NewFormBuilder(body) - }, } for _, opt := range opts { opt(c) @@ -111,32 +165,6 @@ func (c *Client) fullURL(suffix Endpoint, setters ...fullURLOption) string { return fmt.Sprintf("%s%s", baseURL, suffix) } -// WithClient sets the client for the Groq client. -func WithClient(client *http.Client) Opts { - return func(c *Client) { - c.client = client - } -} - -// WithBaseURL sets the base URL for the Groq client. -func WithBaseURL(baseURL string) Opts { - return func(c *Client) { - c.baseURL = baseURL - } -} - -// WithLogger sets the logger for the Groq client. -func WithLogger(logger *slog.Logger) Opts { - return func(c *Client) { - c.logger = logger - } -} - -// response is an interface for a response. -type response interface { - SetHeader(http.Header) -} - func (c *Client) sendRequest(req *http.Request, v response) error { req.Header.Set("Accept", "application/json") @@ -165,10 +193,10 @@ func (c *Client) sendRequest(req *http.Request, v response) error { return decodeResponse(res.Body, v) } -func sendRequestStream[T streamer]( +func sendRequestStream[T streams.Streamer[ChatCompletionStreamResponse]]( client *Client, req *http.Request, -) (*streamReader[T], error) { +) (*streams.StreamReader[*ChatCompletionStreamResponse], error) { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "text/event-stream") req.Header.Set("Cache-Control", "no-cache") @@ -178,18 +206,16 @@ func sendRequestStream[T streamer]( req, ) //nolint:bodyclose // body is closed in stream.Close() if err != nil { - return new(streamReader[T]), err + return new(streams.StreamReader[*ChatCompletionStreamResponse]), err } if isFailureStatusCode(resp) { - return new(streamReader[T]), client.handleErrorResp(resp) + return new(streams.StreamReader[*ChatCompletionStreamResponse]), client.handleErrorResp(resp) } - return &streamReader[T]{ - emptyMessagesLimit: client.emptyMessagesLimit, - reader: bufio.NewReader(resp.Body), - response: resp, - errAccumulator: newErrorAccumulator(), - Header: resp.Header, - }, nil + return streams.NewStreamReader[ChatCompletionStreamResponse]( + resp.Body, + resp.Header, + client.emptyMessagesLimit, + ), nil } func isFailureStatusCode(resp *http.Response) bool { @@ -221,17 +247,19 @@ func decodeString(body io.Reader, output *string) error { return nil } -func withModel(model model) fullURLOption { +func withModel[ + T models.ChatModel | models.AudioModel | models.ModerationModel, +](model T) fullURLOption { return func(args *fullURLOptions) { args.model = string(model) } } func (c *Client) handleErrorResp(resp *http.Response) error { - var errRes errorResponse + var errRes groqerr.ErrorResponse err := json.NewDecoder(resp.Body).Decode(&errRes) if err != nil || errRes.Error == nil { - reqErr := &requestError{ + reqErr := &groqerr.ErrRequest{ HTTPStatusCode: resp.StatusCode, Err: err, } diff --git a/client_test.go b/client_test.go index 7784de2..1a32116 100644 --- a/client_test.go +++ b/client_test.go @@ -1,21 +1,22 @@ -package groq +package groq_test import ( "log/slog" "net/http" "testing" + groq "github.com/conneroisu/groq-go" "github.com/stretchr/testify/assert" ) // TestClient tests the creation of a new client. func TestClient(t *testing.T) { a := assert.New(t) - client, err := NewClient( + client, err := groq.NewClient( "test", - WithBaseURL("http://localhost/v1"), - WithClient(http.DefaultClient), - WithLogger(slog.Default()), + groq.WithBaseURL("http://localhost/v1"), + groq.WithClient(http.DefaultClient), + groq.WithLogger(slog.Default()), ) a.NoError(err) a.NotNil(client) diff --git a/errors.go b/errors.go deleted file mode 100644 index 93b0fe0..0000000 --- a/errors.go +++ /dev/null @@ -1,151 +0,0 @@ -package groq - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "strings" -) - -type ( - // DefaultErrorAccumulator is a default implementation of ErrorAccumulator - DefaultErrorAccumulator struct { - Buffer errorBuffer - } - // APIError provides error information returned by the Groq API. - APIError struct { - Code any `json:"code,omitempty"` // Code is the code of the error. - Message string `json:"message"` // Message is the message of the error. - Param *string `json:"param,omitempty"` // Param is the param of the error. - Type string `json:"type"` // Type is the type of the error. - HTTPStatusCode int `json:"-"` // HTTPStatusCode is the status code of the error. - } - // ErrContentFieldsMisused is an error that occurs when both Content and - // MultiContent properties are set. - ErrContentFieldsMisused struct { - field string - } - // ErrTooManyEmptyStreamMessages is returned when the stream has sent too many - // empty messages. - ErrTooManyEmptyStreamMessages struct{} - errorAccumulator interface { - // Write method writes bytes to the error accumulator - // - // It implements the io.Writer interface. - Write(p []byte) error - // Bytes method returns the bytes of the error accumulator. - Bytes() []byte - } - errorBuffer interface { - io.Writer - Len() int - Bytes() []byte - } - requestError struct { - HTTPStatusCode int - Err error - } - errorResponse struct { - Error *APIError `json:"error,omitempty"` - } -) - -// Error implements the error interface. -func (e ErrContentFieldsMisused) Error() string { - return fmt.Errorf("can't use both Content and MultiContent properties simultaneously"). - Error() -} - -// Error returns the error message. -func (e ErrTooManyEmptyStreamMessages) Error() string { - return "stream has sent too many empty messages" -} - -// newErrorAccumulator creates a new error accumulator -func newErrorAccumulator() errorAccumulator { - return &DefaultErrorAccumulator{ - Buffer: &bytes.Buffer{}, - } -} - -// Write method writes bytes to the error accumulator. -func (e *DefaultErrorAccumulator) Write(p []byte) error { - _, err := e.Buffer.Write(p) - if err != nil { - return fmt.Errorf("error accumulator write error, %w", err) - } - return nil -} - -// Bytes method returns the bytes of the error accumulator. -func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) { - if e.Buffer.Len() == 0 { - return - } - errBytes = e.Buffer.Bytes() - return -} - -// Error method implements the error interface on APIError. -func (e *APIError) Error() string { - if e.HTTPStatusCode > 0 { - return fmt.Sprintf( - "error, status code: %d, message: %s", - e.HTTPStatusCode, - e.Message, - ) - } - return e.Message -} - -// UnmarshalJSON implements the json.Unmarshaler interface. -func (e *APIError) UnmarshalJSON(data []byte) (err error) { - var rawMap map[string]json.RawMessage - err = json.Unmarshal(data, &rawMap) - if err != nil { - return - } - err = json.Unmarshal(rawMap["message"], &e.Message) - if err != nil { - var messages []string - err = json.Unmarshal(rawMap["message"], &messages) - if err != nil { - return - } - e.Message = strings.Join(messages, ", ") - } - // optional fields - if _, ok := rawMap["param"]; ok { - err = json.Unmarshal(rawMap["param"], &e.Param) - if err != nil { - return - } - } - if _, ok := rawMap["code"]; !ok { - return nil - } - // if the api returned a number, we need to force an integer - // since the json package defaults to float64 - var intCode int - err = json.Unmarshal(rawMap["code"], &intCode) - if err == nil { - e.Code = intCode - return nil - } - return json.Unmarshal(rawMap["code"], &e.Code) -} - -// Error implements the error interface. -func (e *requestError) Error() string { - return fmt.Sprintf( - "error, status code: %d, message: %s", - e.HTTPStatusCode, - e.Err, - ) -} - -// Unwrap unwraps the error. -func (e *requestError) Unwrap() error { - return e.Err -} diff --git a/errors_test.go b/errors_test.go deleted file mode 100644 index 94d6868..0000000 --- a/errors_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package groq_test - -import ( - "bytes" - "errors" - "testing" - - groq "github.com/conneroisu/groq-go" - "github.com/conneroisu/groq-go/pkg/test" -) - -func TestErrorAccumulatorBytes(t *testing.T) { - accumulator := &groq.DefaultErrorAccumulator{ - Buffer: &bytes.Buffer{}, - } - - errBytes := accumulator.Bytes() - if len(errBytes) != 0 { - t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes)) - } - - err := accumulator.Write([]byte("{}")) - if err != nil { - t.Fatalf("%+v", err) - } - - errBytes = accumulator.Bytes() - if len(errBytes) == 0 { - t.Fatalf( - "Did not return error bytes when has error: %s", - string(errBytes), - ) - } -} - -func TestErrorByteWriteErrors(t *testing.T) { - accumulator := &groq.DefaultErrorAccumulator{ - Buffer: &test.FailingErrorBuffer{}, - } - err := accumulator.Write([]byte("{")) - if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed{}) { - t.Fatalf("Did not return error when write failed: %v", err) - } -} diff --git a/examples/audio-house-translation/main.go b/examples/audio-house-translation/main.go index b98347a..dc2c18e 100644 --- a/examples/audio-house-translation/main.go +++ b/examples/audio-house-translation/main.go @@ -1,3 +1,5 @@ +// Package main is an example of using the groq-go library to create a +// transcription/translation using the whisper-large-v3 model. package main import ( @@ -6,6 +8,7 @@ import ( "os" "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/models" ) func main() { @@ -23,7 +26,7 @@ func run( return err } response, err := client.CreateTranslation(ctx, groq.AudioRequest{ - Model: groq.ModelWhisperLargeV3, + Model: models.ModelWhisperLargeV3, FilePath: "./house-speaks-mandarin.mp3", Prompt: "english and mandarin", }) diff --git a/examples/audio-lex-fridman/main.go b/examples/audio-lex-fridman/main.go index be662a1..41b025b 100644 --- a/examples/audio-lex-fridman/main.go +++ b/examples/audio-lex-fridman/main.go @@ -8,6 +8,7 @@ import ( "os" "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/models" ) func main() { @@ -25,7 +26,7 @@ func run( return err } response, err := client.CreateTranscription(ctx, groq.AudioRequest{ - Model: groq.ModelWhisperLargeV3, + Model: models.ModelWhisperLargeV3, FilePath: "./The Roman Emperors who went insane Gregory Aldrete and Lex Fridman.mp3", }) if err != nil { diff --git a/examples/chat-terminal/main.go b/examples/chat-terminal/main.go index 5f063ca..17d216c 100644 --- a/examples/chat-terminal/main.go +++ b/examples/chat-terminal/main.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/models" ) var ( @@ -74,7 +75,7 @@ func input(ctx context.Context, r io.Reader, w io.Writer, client *groq.Client) e output, err := client.CreateChatCompletionStream( ctx, groq.ChatCompletionRequest{ - Model: groq.ModelGemma29BIt, + Model: models.ModelGemma29BIt, Messages: history, MaxTokens: 2000, }, @@ -89,7 +90,7 @@ func input(ctx context.Context, r io.Reader, w io.Writer, client *groq.Client) e if err != nil { return err } - if response.Choices[0].FinishReason == groq.FinishReasonStop { + if response.Choices[0].FinishReason == groq.ReasonStop { break } fmt.Fprint(writer, response.Choices[0].Delta.Content) diff --git a/examples/composio-github-star/main.go b/examples/composio-github-star/main.go index 19a230a..92ff50f 100644 --- a/examples/composio-github-star/main.go +++ b/examples/composio-github-star/main.go @@ -1,3 +1,6 @@ +// Package main is an example of using the composio client. +// +// It shows how to use the composio client to star a github repository. package main import ( @@ -8,6 +11,7 @@ import ( "github.com/conneroisu/groq-go" "github.com/conneroisu/groq-go/extensions/composio" + "github.com/conneroisu/groq-go/pkg/models" "github.com/conneroisu/groq-go/pkg/test" ) @@ -49,7 +53,7 @@ func run( return err } chat, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ - Model: groq.ModelLlama3Groq70B8192ToolUsePreview, + Model: models.ModelLlama3Groq70B8192ToolUsePreview, Messages: []groq.ChatCompletionMessage{ { Role: groq.ChatMessageRoleUser, @@ -65,7 +69,11 @@ Star the repo conneroisu/groq-go on GitHub. if err != nil { return err } - resp, err := comp.Run(ctx, chat) + user, err := comp.GetConnectedAccounts(ctx) + if err != nil { + return err + } + resp, err := comp.Run(ctx, user[0], chat) if err != nil { return err } diff --git a/examples/e2b-go-project/main.go b/examples/e2b-go-project/main.go index 9f31da2..8a2e51a 100644 --- a/examples/e2b-go-project/main.go +++ b/examples/e2b-go-project/main.go @@ -1,3 +1,4 @@ +// Package main shows an example of using the e2b extension. package main import ( @@ -7,6 +8,38 @@ import ( "github.com/conneroisu/groq-go" "github.com/conneroisu/groq-go/extensions/e2b" + "github.com/conneroisu/groq-go/pkg/models" + "github.com/conneroisu/groq-go/pkg/tools" +) + +var ( + history = []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: ` +Given the callable tools provided, create a python project with the following files: + + +main.py +utils.py + + +The main function should call the "utils.run()" function. + +The project should, when run, print the following to stdout: + + +Hello, World! + + +You should finish with the following shell command: + + +python main.py + +`, + }, + } ) func main() { @@ -19,49 +52,60 @@ func main() { func run( ctx context.Context, ) error { - key := os.Getenv("GROQ_KEY") + groqKey := os.Getenv("GROQ_KEY") e2bKey := os.Getenv("E2B_API_KEY") - client, err := groq.NewClient(key) + client, err := groq.NewClient(groqKey) if err != nil { return err } - sb, err := e2b.NewSandbox( - ctx, - e2bKey, - ) + sb, err := e2b.NewSandbox(ctx, e2bKey) if err != nil { return err } - chat, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ - Model: groq.ModelLlama3Groq70B8192ToolUsePreview, - Messages: []groq.ChatCompletionMessage{ - { - Role: groq.ChatMessageRoleUser, - Content: ` - -Given the tools given to you, create a golang project with the following files: - - -main.go -utils.go - - -The main function should call the "utils.run() error" function. - -The project should, when run, print the following: - - -Hello, World! - -`, - }, - }, - MaxTokens: 2000, - Tools: sb.GetTools(), - }) - if err != nil { - return err + defer func() { + err := sb.Stop(ctx) + if err != nil { + fmt.Println(err) + } + }() + ts := sb.GetTools() + ts = append(ts, tools.Tool{ + Type: tools.ToolTypeFunction, + Function: tools.FunctionDefinition{ + Name: "complete", + Description: "Signify that the assigned task is complete.", + Parameters: tools.FunctionParameters{ + Type: "object", + Properties: map[string]tools.PropertyDefinition{ + "task": { + Type: "string", + Description: "The task that is complete.", + }}}}}) + for { + chat, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlama3Groq8B8192ToolUsePreview, + Messages: history, + MaxTokens: 3000, + Tools: ts, + }) + if err != nil { + return err + } + if chat.Choices[0].FinishReason == groq.ReasonFunctionCall { + if chat.Choices[0].Message.FunctionCall.Name == "complete" { + break + } + } + resp, err := sb.RunTooling(ctx, chat) + if err != nil { + history = append(history, + groq.ChatCompletionMessage{ + Role: groq.ChatMessageRoleUser, + Content: err.Error(), + }) + continue + } + history = append(history, resp...) } - fmt.Println(chat.Choices[0].Message.Content) return nil } diff --git a/examples/json-chat/main.go b/examples/json-chat/main.go index 7071fcf..eec88c7 100644 --- a/examples/json-chat/main.go +++ b/examples/json-chat/main.go @@ -10,6 +10,7 @@ import ( "os" "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/models" ) func main() { @@ -34,7 +35,7 @@ func run( } resp := &Responses{} err = client.CreateChatCompletionJSON(ctx, groq.ChatCompletionRequest{ - Model: groq.ModelLlama3Groq70B8192ToolUsePreview, + Model: models.ModelLlama3Groq70B8192ToolUsePreview, Messages: []groq.ChatCompletionMessage{ { Role: groq.ChatMessageRoleUser, diff --git a/examples/llava-blind/main.go b/examples/llava-blind/main.go index 0fea70a..a415874 100644 --- a/examples/llava-blind/main.go +++ b/examples/llava-blind/main.go @@ -7,6 +7,7 @@ import ( "os" "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/models" ) func main() { @@ -29,7 +30,7 @@ func run( response, err := client.CreateChatCompletion( ctx, groq.ChatCompletionRequest{ - Model: groq.ModelLlavaV157B4096Preview, + Model: models.ModelLlavaV157B4096Preview, Messages: []groq.ChatCompletionMessage{ { Role: groq.ChatMessageRoleUser, diff --git a/examples/moderation/main.go b/examples/moderation/main.go index 4f6a051..bf62c71 100644 --- a/examples/moderation/main.go +++ b/examples/moderation/main.go @@ -8,6 +8,7 @@ import ( "os" "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/models" ) func main() { @@ -26,15 +27,15 @@ func run( if err != nil { return err } - response, err := client.Moderate(ctx, groq.ModerationRequest{ - Model: groq.ModelLlamaGuard38B, - Messages: []groq.ChatCompletionMessage{ + response, err := client.Moderate(ctx, + []groq.ChatCompletionMessage{ { Role: groq.ChatMessageRoleUser, Content: "I want to kill them.", }, }, - }) + models.ModelLlamaGuard38B, + ) if err != nil { return err } diff --git a/examples/toolhouse-python-code-interpreter/main.go b/examples/toolhouse-python-code-interpreter/main.go index 375abb6..17f8130 100644 --- a/examples/toolhouse-python-code-interpreter/main.go +++ b/examples/toolhouse-python-code-interpreter/main.go @@ -12,6 +12,8 @@ import ( "github.com/conneroisu/groq-go" "github.com/conneroisu/groq-go/extensions/toolhouse" + "github.com/conneroisu/groq-go/pkg/models" + "github.com/conneroisu/groq-go/pkg/test" ) func main() { @@ -23,7 +25,11 @@ func main() { } func run(ctx context.Context) error { - ext, err := toolhouse.NewExtension(os.Getenv("TOOLHOUSE_API_KEY"), + toolhouseKey, err := test.GetAPIKey("TOOLHOUSE_API_KEY") + if err != nil { + return err + } + ext, err := toolhouse.NewExtension(toolhouseKey, toolhouse.WithMetadata(map[string]any{ "id": "conner", "timezone": 5, @@ -31,7 +37,11 @@ func run(ctx context.Context) error { if err != nil { return err } - client, err := groq.NewClient(os.Getenv("GROQ_KEY")) + groqKey, err := test.GetAPIKey("GROQ_KEY") + if err != nil { + return err + } + client, err := groq.NewClient(groqKey) if err != nil { return err } @@ -41,25 +51,27 @@ func run(ctx context.Context) error { Content: "Write a python function to print the first 10 prime numbers containing the number 3 then respond with the answer. DO NOT GUESS WHAT THE OUTPUT SHOULD BE. MAKE SURE TO CALL THE TOOL GIVEN.", }, } - print(history[len(history)-1].Content) + tools, err := ext.GetTools(ctx) + if err != nil { + return err + } re, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ - Model: groq.ModelLlama3Groq70B8192ToolUsePreview, + Model: models.ModelLlama3Groq70B8192ToolUsePreview, Messages: history, - Tools: ext.MustGetTools(ctx), + Tools: tools, ToolChoice: "required", }) if err != nil { return fmt.Errorf("failed to create 1 chat completion: %w", err) } history = append(history, re.Choices[0].Message) - print(history[len(history)-1].ToolCalls[len(history[len(history)-1].ToolCalls)-1].Function.Arguments) r, err := ext.Run(ctx, re) if err != nil { return fmt.Errorf("failed to run tool: %w", err) } history = append(history, r...) finalr, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ - Model: groq.ModelLlama3Groq70B8192ToolUsePreview, + Model: models.ModelLlama3Groq70B8192ToolUsePreview, Messages: history, MaxTokens: 2000, }) @@ -67,7 +79,6 @@ func run(ctx context.Context) error { return fmt.Errorf("failed to create 2 chat completion: %w", err) } history = append(history, finalr.Choices[0].Message) - print(history[len(history)-1].Content) jsnHistory, err := json.MarshalIndent(history, "", " ") if err != nil { return fmt.Errorf("failed to marshal history: %w", err) diff --git a/examples/vhdl-documentor-json/main.go b/examples/vhdl-documentor-json/main.go index 5e5d556..dd65a59 100644 --- a/examples/vhdl-documentor-json/main.go +++ b/examples/vhdl-documentor-json/main.go @@ -14,6 +14,7 @@ import ( "time" "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/models" ) var ( @@ -68,7 +69,7 @@ func run( err = client.CreateChatCompletionJSON( ctx, groq.ChatCompletionRequest{ - Model: groq.ModelLlama3Groq70B8192ToolUsePreview, + Model: models.ModelLlama3Groq70B8192ToolUsePreview, Messages: []groq.ChatCompletionMessage{{ Role: groq.ChatMessageRoleSystem, Content: prompt, diff --git a/extensions/composio/.go-version b/extensions/composio/.go-version deleted file mode 100644 index 49e0a31..0000000 --- a/extensions/composio/.go-version +++ /dev/null @@ -1 +0,0 @@ -1.23.1 diff --git a/extensions/composio/auth.go b/extensions/composio/auth.go index 4496185..74db49d 100644 --- a/extensions/composio/auth.go +++ b/extensions/composio/auth.go @@ -11,8 +11,8 @@ import ( ) type ( - // Auther is an interface for composio auth. - Auther interface { + // Authorizer is an interface for composio auth. + Authorizer interface { GetConnectedAccounts(ctx context.Context, opts ...AuthOption) ([]ConnectedAccount, error) } // ConnectedAccount represents a composio connected account. diff --git a/extensions/composio/auth_test.go b/extensions/composio/auth_test.go index f71d7ce..4cc8f8f 100644 --- a/extensions/composio/auth_test.go +++ b/extensions/composio/auth_test.go @@ -51,7 +51,7 @@ func TestAuth(t *testing.T) { // TestUnitGetConnectedAccounts is an Unit test using a real composio server and api key. func TestUnitGetConnectedAccounts(t *testing.T) { - if !test.IsUnitTest() { + if !test.IsIntegrationTest() { t.Skip() } a := assert.New(t) diff --git a/extensions/composio/composio.go b/extensions/composio/composio.go index 039150d..ea28d66 100644 --- a/extensions/composio/composio.go +++ b/extensions/composio/composio.go @@ -14,15 +14,7 @@ const ( composioBaseURL = "https://backend.composio.dev/api" ) -var _ Composer = &Composio{} - type ( - // Composer is an interface for composio client. - Composer interface { - Tooler - Runner - Auther - } // Composio is a composio client. Composio struct { apiKey string diff --git a/extensions/composio/run.go b/extensions/composio/run.go index 73585a7..fe6285f 100644 --- a/extensions/composio/run.go +++ b/extensions/composio/run.go @@ -13,7 +13,9 @@ import ( type ( // Runner is an interface for composio run. Runner interface { - Run(ctx context.Context, response groq.ChatCompletionResponse) ( + Run(ctx context.Context, + user ConnectedAccount, + response groq.ChatCompletionResponse) ( []groq.ChatCompletionMessage, error) } request struct { @@ -29,21 +31,18 @@ type ( // Run runs the composio client on a chat completion response. func (c *Composio) Run( ctx context.Context, + user ConnectedAccount, response groq.ChatCompletionResponse, ) ([]groq.ChatCompletionMessage, error) { var respH []groq.ChatCompletionMessage - if response.Choices[0].FinishReason != groq.FinishReasonFunctionCall && + if response.Choices[0].FinishReason != groq.ReasonFunctionCall && response.Choices[0].FinishReason != "tool_calls" { - return nil, fmt.Errorf("Not a function call") - } - connectedAccount, err := c.GetConnectedAccounts(ctx, WithShowActiveOnly(true)) - if err != nil { - return nil, err + return nil, fmt.Errorf("not a function call") } for _, toolCall := range response.Choices[0].Message.ToolCalls { var args map[string]any if json.Valid([]byte(toolCall.Function.Arguments)) { - err = json.Unmarshal([]byte(toolCall.Function.Arguments), &args) + err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args) if err != nil { return nil, err } @@ -55,11 +54,10 @@ func (c *Composio) Run( http.MethodPost, fmt.Sprintf("%s/v2/actions/%s/execute", c.baseURL, toolCall.Function.Name), builders.WithBody(&request{ - ConnectedAccountID: connectedAccount[0].ID, + ConnectedAccountID: user.ID, EntityID: "default", AppName: toolCall.Function.Name, Input: args, - Text: "", AuthConfig: map[string]any{}, }), ) diff --git a/extensions/composio/run_test.go b/extensions/composio/run_test.go index 6db1368..a013084 100644 --- a/extensions/composio/run_test.go +++ b/extensions/composio/run_test.go @@ -9,6 +9,7 @@ import ( "github.com/conneroisu/groq-go" "github.com/conneroisu/groq-go/extensions/composio" + "github.com/conneroisu/groq-go/pkg/models" "github.com/conneroisu/groq-go/pkg/test" "github.com/conneroisu/groq-go/pkg/tools" "github.com/stretchr/testify/assert" @@ -53,7 +54,9 @@ func TestRun(t *testing.T) { composio.WithBaseURL(testS.URL), ) a.NoError(err) - resp, err := client.Run(ctx, groq.ChatCompletionResponse{ + ca, err := client.GetConnectedAccounts(ctx, composio.WithShowActiveOnly(true)) + a.NoError(err) + resp, err := client.Run(ctx, ca[0], groq.ChatCompletionResponse{ Choices: []groq.ChatCompletionChoice{{ Message: groq.ChatCompletionMessage{ Role: groq.ChatMessageRoleUser, @@ -63,14 +66,14 @@ func TestRun(t *testing.T) { Name: "TOOL", Arguments: `{ "foo": "bar", }`, }}}}, - FinishReason: groq.FinishReasonFunctionCall, + FinishReason: groq.ReasonFunctionCall, }}}) a.NoError(err) assert.Equal(t, "response1", resp[0].Content) } func TestUnitRun(t *testing.T) { - if !test.IsUnitTest() { + if !test.IsIntegrationTest() { t.Skip() } a := assert.New(t) @@ -91,7 +94,7 @@ func TestUnitRun(t *testing.T) { ) a.NoError(err, "NewClient error") response, err := groqClient.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ - Model: groq.ModelLlama3Groq8B8192ToolUsePreview, + Model: models.ModelLlama3Groq8B8192ToolUsePreview, Messages: []groq.ChatCompletionMessage{ { Role: groq.ChatMessageRoleUser, @@ -103,7 +106,9 @@ func TestUnitRun(t *testing.T) { }) a.NoError(err) a.NotEmpty(response.Choices[0].Message.ToolCalls) - resp2, err := client.Run(ctx, response) + users, err := client.GetConnectedAccounts(ctx) + a.NoError(err) + resp2, err := client.Run(ctx, users[0], response) a.NoError(err) a.NotEmpty(resp2) t.Logf("%+v\n", resp2) diff --git a/extensions/composio/tools.go b/extensions/composio/tools.go index 476ecfa..4b05b38 100644 --- a/extensions/composio/tools.go +++ b/extensions/composio/tools.go @@ -11,32 +11,61 @@ import ( ) type ( - // Tooler is an interface for retreiving composio tools. - Tooler interface { - GetTools(ctx context.Context, opts ...ToolsOption) ( - []tools.Tool, error) - } // Tool represents a composio tool as returned by the api. Tool struct { - Name string `json:"name"` - Enum string `json:"enum"` - Tags []string `json:"tags"` - Logo string `json:"logo"` - AppID string `json:"appId"` - AppName string `json:"appName"` - DisplayName string `json:"displayName"` - Description string `json:"description"` - Parameters tools.FunctionParameters `json:"parameters"` - Response struct { + // Name is the name of the tool returned by the composio api. + Name string `json:"name"` + // Enum is the enum of the tool returned by the composio api. + Enum string `json:"enum"` + // Tags are the tags of the tool returned by the composio api. + Tags []string `json:"tags"` + // Logo is the logo of the tool returned by the composio api. + Logo string `json:"logo"` + // AppID is the app id of the tool returned by the composio api. + AppID string `json:"appId"` + // AppName is the app name of the tool returned by the composio + // api. + AppName string `json:"appName"` + // DisplayName is the display name of the tool returned by the + // composio api. + DisplayName string `json:"displayName"` + // Description is the description of the tool returned by the + // composio api. + Description string `json:"description"` + // Parameters are the parameters of the tool returned by the + // composio api. + Parameters tools.FunctionParameters `json:"parameters"` + // Response is the response of the tool returned by the + // composio api. + Response struct { + // Properties are the properties of the response + // returned by the composio api. Properties struct { + // Data is the data of the response returned by + // the composio api. Data struct { + // Title is the title of the data in the + // response returned by the composio + // api. Title string `json:"title"` - Type string `json:"type"` + // Type is the type of the data in the + // response returned by the composio + // api. + Type string `json:"type"` } `json:"data"` + // Successful is the successful response of the + // composio api. Successful struct { + // Description is the description of the + // successful response of the composio + // api. Description string `json:"description"` - Title string `json:"title"` - Type string `json:"type"` + // Title is the title of the successful + // response of the composio api. + Title string `json:"title"` + // Type is the type of the successful + // response of the composio api. + Type string `json:"type"` } `json:"successful"` Error struct { AnyOf []struct { diff --git a/extensions/composio/tools_test.go b/extensions/composio/tools_test.go index 30be6bf..3efc101 100644 --- a/extensions/composio/tools_test.go +++ b/extensions/composio/tools_test.go @@ -48,7 +48,7 @@ func TestGetTools(t *testing.T) { // TestUnitGetTools tests the ability of the composio client to get tools. func TestUnitGetTools(t *testing.T) { - if !test.IsUnitTest() { + if !test.IsIntegrationTest() { t.Skip() } a := assert.New(t) diff --git a/extensions/e2b/model.go b/extensions/e2b/model.go index acecbed..80e26c7 100644 --- a/extensions/e2b/model.go +++ b/extensions/e2b/model.go @@ -7,7 +7,8 @@ import ( ) type ( - // Receiver is an interface for a constantly receiving instance. + // Receiver is an interface for a constantly receiving instance that + // can closed. // // Implementations should be conccurent safe. Receiver interface { @@ -25,7 +26,10 @@ type ( // // If the context is cancelled before requesting the timeout, // the error will be ctx.Err(). - KeepAlive(ctx context.Context, timeout time.Duration) error + KeepAlive( + ctx context.Context, + timeout time.Duration, + ) error // NewProcess creates a new process. NewProcess( cmd string, diff --git a/extensions/e2b/options.go b/extensions/e2b/options.go index 4a754b9..03f7690 100644 --- a/extensions/e2b/options.go +++ b/extensions/e2b/options.go @@ -5,7 +5,7 @@ import ( "net/http" ) -// E2B Options +// E2B Sandbox Options // WithBaseURL sets the base URL for the e2b sandbox. func WithBaseURL(baseURL string) Option { @@ -37,7 +37,7 @@ func WithCwd(cwd string) Option { return func(s *Sandbox) { s.Cwd = cwd } } -// WithWsURL sets the websocket url for the e2b sandbox. +// WithWsURL sets the websocket url resolving function for the e2b sandbox. func WithWsURL(wsURL func(s *Sandbox) string) Option { return func(s *Sandbox) { s.wsURL = wsURL } } diff --git a/extensions/e2b/sandbox.go b/extensions/e2b/sandbox.go index 480204e..4e6d235 100644 --- a/extensions/e2b/sandbox.go +++ b/extensions/e2b/sandbox.go @@ -102,14 +102,7 @@ type ( Message string `json:"message"` // Message is the message of the error. } // Method is a JSON-RPC method. - Method string - decResp struct { - Method string `json:"method"` - ID int `json:"id"` - Params struct { - Subscription string `json:"subscription"` - } - } + Method string ) const ( @@ -485,21 +478,18 @@ func (p *Process) Done() <-chan struct{} { } // SubscribeStdout subscribes to the process's stdout. -func (p *Process) SubscribeStdout(events chan Event) (err error) { - err = p.subscribe(p.ctx, OnStdout, events) - return +func (p *Process) SubscribeStdout() (chan Event, chan error) { + return p.subscribe(p.ctx, OnStdout) } // SubscribeStderr subscribes to the process's stderr. -func (p *Process) SubscribeStderr(events chan Event) (err error) { - err = p.subscribe(p.ctx, OnStderr, events) - return +func (p *Process) SubscribeStderr() (chan Event, chan error) { + return p.subscribe(p.ctx, OnStderr) } // SubscribeExit subscribes to the process's exit. -func (p *Process) SubscribeExit(events chan Event) (err error) { - err = p.subscribe(p.ctx, OnExit, events) - return +func (p *Process) SubscribeExit() (chan Event, chan error) { + return p.subscribe(p.ctx, OnExit) } // Subscribe subscribes to a process event. @@ -508,20 +498,19 @@ func (p *Process) SubscribeExit(events chan Event) (err error) { func (p *Process) subscribe( ctx context.Context, event ProcessEvents, - eCh chan<- Event, -) error { - errCh := make(chan error) +) (chan Event, chan error) { + events := make(chan Event) + errs := make(chan error) go func(errCh chan error) { respCh := make(chan []byte) defer close(respCh) err := p.sb.writeRequest(ctx, processSubscribe, []any{event, p.id}, respCh) - if err != nil || respCh == nil { + if err != nil { errCh <- err } res, err := decodeResponse[string, any](<-respCh) - errCh <- err if err != nil { - return + errCh <- err } p.sb.Map.Store(res.Result, respCh) for { @@ -538,7 +527,7 @@ func (p *Process) subscribe( p.sb.logger.Debug("subscription id mismatch", "expected", res.Result, "got", event.Params.Subscription) continue } - eCh <- event + events <- event case <-ctx.Done(): p.sb.Map.Delete(res.Result) finishCtx, cancel := context.WithCancel(context.Background()) @@ -554,8 +543,8 @@ func (p *Process) subscribe( return } } - }(errCh) - return <-errCh + }(errs) + return events, errs } func (s *Sandbox) sendRequest(req *http.Request, v interface{}) error { req.Header.Set("Accept", "application/json") @@ -608,6 +597,13 @@ func (s *Sandbox) identify(ctx context.Context) { } } func (s *Sandbox) read(ctx context.Context) (err error) { + type decResp struct { + Method string `json:"method"` + ID int `json:"id"` + Params struct { + Subscription string `json:"subscription"` + } + } var body []byte defer func() { err = s.ws.Close() diff --git a/extensions/e2b/sandbox_test.go b/extensions/e2b/sandbox_test.go index 385ebf1..02806ab 100644 --- a/extensions/e2b/sandbox_test.go +++ b/extensions/e2b/sandbox_test.go @@ -164,11 +164,13 @@ func TestNewSandbox(t *testing.T) { err = proc.Start(ctx) a.NoError(err) - e := make(chan Event) - err = proc.SubscribeStdout(e) - a.NoError(err) - event := <-e - jsnBytes, err := json.MarshalIndent(&event, "", " ") - a.NoError(err) - t.Logf("test got event: %s", string(jsnBytes)) + e, errCh := proc.SubscribeStdout() + select { + case <-errCh: + t.Fatal("got error from SubscribeStdout") + case event := <-e: + jsnBytes, err := json.MarshalIndent(&event, "", " ") + a.NoError(err) + t.Logf("test got event: %s", string(jsnBytes)) + } } diff --git a/extensions/e2b/tools.go b/extensions/e2b/tools.go index 9ff4c0c..3148e2e 100644 --- a/extensions/e2b/tools.go +++ b/extensions/e2b/tools.go @@ -131,12 +131,11 @@ var ( if err != nil { return groq.ChatCompletionMessage{}, err } - e := make(chan Event, 10) - err = proc.SubscribeStdout(e) + e, errCh := proc.SubscribeStdout() if err != nil { return groq.ChatCompletionMessage{}, err } - err = proc.SubscribeStderr(e) + e2, errCh := proc.SubscribeStderr() if err != nil { return groq.ChatCompletionMessage{}, err } @@ -152,8 +151,14 @@ var ( return case event := <-e: buf.Write([]byte(event.Params.Result.Line)) + continue + case event := <-e2: + buf.Write([]byte(event.Params.Result.Line)) + continue + case <-errCh: + return case <-proc.Done(): - break + return } } }() @@ -267,8 +272,8 @@ func (s *Sandbox) RunTooling( ctx context.Context, response groq.ChatCompletionResponse, ) ([]groq.ChatCompletionMessage, error) { - if response.Choices[0].FinishReason != groq.FinishReasonFunctionCall && response.Choices[0].FinishReason != "tool_calls" { - return nil, fmt.Errorf("not a function call") + if response.Choices[0].FinishReason != groq.ReasonFunctionCall && response.Choices[0].FinishReason != "tool_calls" { + return nil, fmt.Errorf("not a function call: %v", response.Choices[0].FinishReason) } respH := []groq.ChatCompletionMessage{} for _, tool := range response.Choices[0].Message.ToolCalls { diff --git a/extensions/e2b/tools_test.go b/extensions/e2b/tools_test.go index 8b69a23..7a7fd40 100644 --- a/extensions/e2b/tools_test.go +++ b/extensions/e2b/tools_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/models" "github.com/conneroisu/groq-go/pkg/test" "github.com/stretchr/testify/assert" ) @@ -19,7 +20,7 @@ func getapiKey(t *testing.T, val string) string { } func TestSandboxTooling(t *testing.T) { - if !test.IsUnitTest() { + if !test.IsIntegrationTest() { t.Skip() } a := assert.New(t) @@ -35,8 +36,8 @@ func TestSandboxTooling(t *testing.T) { a.NoError(err, "NewClient error") tools := sb.GetTools() // ask the ai to create a file with the data "Hello World!" in file "hello.txt" - response := client.MustCreateChatCompletion(ctx, groq.ChatCompletionRequest{ - Model: groq.ModelLlama3Groq70B8192ToolUsePreview, + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlama3Groq70B8192ToolUsePreview, Messages: []groq.ChatCompletionMessage{ { Role: groq.ChatMessageRoleUser, @@ -52,6 +53,7 @@ NOTE: You are in the correct cwd. Just call the write tool with a name of hello. MaxTokens: 2000, Tools: tools, }) + a.NoError(err) sb.logger.Debug("response from model", "response", response) resps, err := sb.RunTooling(ctx, response) a.NoError(err) diff --git a/extensions/e2b/unit_test.go b/extensions/e2b/unit_test.go index 311d9d2..0346a24 100644 --- a/extensions/e2b/unit_test.go +++ b/extensions/e2b/unit_test.go @@ -3,6 +3,7 @@ package e2b_test import ( "context" "encoding/json" + "fmt" "os" "testing" "time" @@ -21,7 +22,7 @@ func getapiKey(t *testing.T) string { } func TestPostSandbox(t *testing.T) { - if !test.IsUnitTest() { + if !test.IsIntegrationTest() { t.Skip() } a := assert.New(t) @@ -52,7 +53,7 @@ func TestPostSandbox(t *testing.T) { // TestWriteRead tests the Write and Read methods of the Sandbox. func TestWriteRead(t *testing.T) { - if !test.IsUnitTest() { + if !test.IsIntegrationTest() { t.Skip() } filePath := "test.txt" @@ -78,7 +79,7 @@ func TestWriteRead(t *testing.T) { } func TestCreateProcess(t *testing.T) { - if !test.IsUnitTest() { + if !test.IsIntegrationTest() { t.Skip() } a := assert.New(t) @@ -102,20 +103,24 @@ func TestCreateProcess(t *testing.T) { a.NoError(err, "could not create process") err = proc.Start(ctx) a.NoError(err) - stdOutEvents := make(chan e2b.Event) - err = proc.SubscribeStdout(stdOutEvents) + stdOutEvents, errCh := proc.SubscribeStdout() a.NoError(err) - event := <-stdOutEvents - jsonBytes, err := json.MarshalIndent(&event, "", " ") - if err != nil { - a.Error(err) - return + select { + case <-errCh: + t.Fatal(fmt.Errorf("failed to subscribe to stdout: %w", err)) + case event := <-stdOutEvents: + jsonBytes, err := json.MarshalIndent(&event, "", " ") + if err != nil { + a.Error(err) + return + } + t.Logf("test got event: %s", string(jsonBytes)) + break } - t.Logf("test got event: %s", string(jsonBytes)) } func TestFilesystemSubscribe(t *testing.T) { - if !test.IsUnitTest() { + if !test.IsIntegrationTest() { t.Skip() } a := assert.New(t) @@ -150,7 +155,7 @@ func TestFilesystemSubscribe(t *testing.T) { } func TestKeepAlive(t *testing.T) { - if !test.IsUnitTest() { + if !test.IsIntegrationTest() { t.Skip() } a := assert.New(t) diff --git a/extensions/jigsawstack/README.md b/extensions/jigsawstack/README.md new file mode 100644 index 0000000..c42b33d --- /dev/null +++ b/extensions/jigsawstack/README.md @@ -0,0 +1 @@ +# jigsawstack diff --git a/extensions/jigsawstack/audio.go b/extensions/jigsawstack/audio.go new file mode 100644 index 0000000..7aad9f4 --- /dev/null +++ b/extensions/jigsawstack/audio.go @@ -0,0 +1,87 @@ +package jigsawstack + +import ( + "context" + "net/http" + + "github.com/conneroisu/groq-go/pkg/builders" +) + +const ( + ttsEndpoint Endpoint = "/v1/ai/tts" + accentsEndpoint Endpoint = "/v1/audio/speaker_voice_accents" +) + +type ( + // TTSOption is an option for the TTS request. + TTSOption func(*ttsRequest) + // ttsRequest represents a request structure for TTS API. + ttsRequest struct { + // Text is the text to convert to speech. + // Required. + Text string `json:"text"` + // Accent is the accent of the speaker voice to use. + // + // Not required if the FileKey or SpeakerURL is not provided. + Accent string `json:"accent,omitempty"` + // SpeakerURL is the url of the speaker voice to use. + // + // Not required if the FileKey is not provided. + SpeakerURL string `json:"speaker_clone_url,omitempty"` + // FileKey is the key of the file to use as the speaker voice. + // + // Not required if the SpeakerURL is not provided. + FileKey string `json:"speaker_clone_file_store_key,omitempty"` + } +) + +// WithAccent sets the accent of the speaker voice to use. +func WithAccent(accent string) TTSOption { + return func(r *ttsRequest) { r.Accent = accent } +} + +// WithSpeakerURL sets the url of the speaker voice to use. +func WithSpeakerURL(url string) TTSOption { + return func(r *ttsRequest) { r.SpeakerURL = url } +} + +// WithFileKey sets the file key of the speaker voice to use. +func WithFileKey(key string) TTSOption { + return func(r *ttsRequest) { r.FileKey = key } +} + +// AudioTTS creates a text to speech (TTS) audio file. +// +// It only support one option at a time, but does support no options. +// +// POST https://api.jigsawstack.com/v1/ai/tts +// +// https://docs.jigsawstack.com/api-reference/ai/text-to-speech +func (j *JigsawStack) AudioTTS( + ctx context.Context, + text string, + options ...TTSOption, +) (mp3 string, err error) { + body := ttsRequest{ + Text: text, + } + for _, option := range options { + option(&body) + } + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodPost, + j.baseURL+string(ttsEndpoint), + builders.WithBody(body), + ) + if err != nil { + return "", err + } + var resp string + err = j.sendRequest(req, &resp) + if err != nil { + return "", err + } + return resp, nil +} diff --git a/extensions/jigsawstack/audio_test.go b/extensions/jigsawstack/audio_test.go new file mode 100644 index 0000000..3bffa7a --- /dev/null +++ b/extensions/jigsawstack/audio_test.go @@ -0,0 +1,39 @@ +package jigsawstack_test + +import ( + "context" + "io" + "os" + "strings" + "testing" + + "github.com/conneroisu/groq-go/extensions/jigsawstack" + "github.com/conneroisu/groq-go/pkg/test" + "github.com/stretchr/testify/assert" +) + +func TestAudioTTS(t *testing.T) { + if !test.IsIntegrationTest() { + t.Skip() + } + a := assert.New(t) + apiKey, err := test.GetAPIKey("JIGSAWSTACK_API_KEY") + a.NoError(err) + ctx := context.Background() + client, err := jigsawstack.NewJigsawStack( + apiKey, + jigsawstack.WithLogger(test.DefaultLogger), + ) + a.NoError(err) + response, err := client.AudioTTS(ctx, + "Hello, world! Welcome to Groq!", + jigsawstack.WithAccent("zh-TW-female-19"), + ) + a.NoError(err) + // write the io.reader to a file + f, err := os.Create("tts.mp3") + a.NoError(err) + defer f.Close() + _, err = io.Copy(f, strings.NewReader(response)) + a.NoError(err) +} diff --git a/extensions/jigsawstack/doc.go b/extensions/jigsawstack/doc.go new file mode 100644 index 0000000..92b7a64 --- /dev/null +++ b/extensions/jigsawstack/doc.go @@ -0,0 +1,4 @@ +// Package jigsawstack provides a JigsawStack extension for groq-go. +// +// It gives tools for working with the JigsawStack API. +package jigsawstack diff --git a/extensions/jigsawstack/geography.go b/extensions/jigsawstack/geography.go new file mode 100644 index 0000000..3bd4c62 --- /dev/null +++ b/extensions/jigsawstack/geography.go @@ -0,0 +1,147 @@ +package jigsawstack + +import ( + "context" + "net/http" + "net/url" + "strconv" + + "github.com/conneroisu/groq-go/pkg/builders" +) + +const ( + geographyEndpoint Endpoint = "/v1/geo/search" +) + +type ( + // GeographyRequest represents a request structure for geography API. + GeographyRequest struct { + Query string `json:"query"` + Country string `json:"country,omitempty"` + Latitude float64 `json:"latitude,omitempty"` + ProximityLat float64 `json:"proximity_lat,omitempty"` + Longitude float64 `json:"longitude,omitempty"` + ProximityLng float64 `json:"proximity_lng,omitempty"` + Types string `json:"types,omitempty"` + } + // GeographyResponse represents a response structure for geography API. + GeographyResponse struct { + Success bool `json:"success"` + Data []struct { + Type string `json:"type"` + FullAddress string `json:"full_address"` + Name string `json:"name"` + PlaceFormatted string `json:"place_formatted"` + Postcode string `json:"postcode"` + Place string `json:"place"` + Region struct { + Name string `json:"name"` + RegionCode string `json:"region_code"` + RegionCodeFull string `json:"region_code_full"` + } `json:"region"` + Country struct { + Name string `json:"name"` + CountryCode string `json:"country_code"` + CountryCodeAlpha3 string `json:"country_code_alpha_3"` + } `json:"country"` + Language string `json:"language"` + Geoloc struct { + Type string `json:"type"` + Coordinates []float64 `json:"coordinates"` + } `json:"geoloc"` + PoiCategory []string `json:"poi_category"` + AddtionalProperties struct { + Phone string `json:"phone"` + Website string `json:"website"` + OpenHours struct { + } `json:"open_hours"` + } `json:"addtional_properties"` + } `json:"data"` + } +) + +// URLQuery converts the params into params on the given url. +func (r *GeographyRequest) URLQuery(url *url.URL) { + values := url.Query() + if r.Query != "" { + values.Set("query", r.Query) + } + if r.Country != "" { + values.Set("country", r.Country) + } + var strLat, strLng string + if r.Latitude != 0 { + strLat = strconv.FormatFloat(r.Latitude, 'f', -1, 64) + values.Set("latitude", strLat) + } + if r.ProximityLat != 0 { + strLat = strconv.FormatFloat(r.ProximityLat, 'f', -1, 64) + values.Set("proximity_lat", strLat) + } + if r.Longitude != 0 { + strLng = strconv.FormatFloat(r.Longitude, 'f', -1, 64) + values.Set("longitude", strLng) + } + if r.ProximityLng != 0 { + strLng = strconv.FormatFloat(r.ProximityLng, 'f', -1, 64) + values.Set("proximity_lng", strLng) + } + if r.Types != "" { + values.Set("types", r.Types) + } + url.RawQuery = values.Encode() +} + +// GeographySearch performs a geography search api call over a query string. +// +// https://api.jigsawstack.com/v1/geo/search +// +// https://docs.jigsawstack.com/api-reference/geo/search +func (j *JigsawStack) GeographySearch( + ctx context.Context, + request GeographyRequest, +) (response GeographyResponse, err error) { + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodPost, + j.baseURL+string(geographyEndpoint), + builders.WithQuerier(&request), + ) + if err != nil { + return + } + var resp GeographyResponse + err = j.sendRequest(req, &resp) + if err != nil { + return + } + return resp, nil +} + +// GeographyGeocode performs a geography geocode api call over a query string. +// +// GET https://api.jigsawstack.com/v1/geo/geocode +// +// https://docs.jigsawstack.com/api-reference/geo/geocode +func (j *JigsawStack) GeographyGeocode( + ctx context.Context, + request GeographyRequest, +) (response GeographyResponse, err error) { + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodGet, + j.baseURL+string(geographyEndpoint), + builders.WithQuerier(&request), + ) + if err != nil { + return + } + var resp GeographyResponse + err = j.sendRequest(req, &resp) + if err != nil { + return + } + return resp, nil +} diff --git a/extensions/jigsawstack/jigsawstack.go b/extensions/jigsawstack/jigsawstack.go new file mode 100644 index 0000000..3f16a6d --- /dev/null +++ b/extensions/jigsawstack/jigsawstack.go @@ -0,0 +1,103 @@ +package jigsawstack + +import ( + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + + "github.com/conneroisu/groq-go/pkg/builders" +) + +const ( + defaultBaseURL = "https://api.jigsawstack.com" +) + +type ( + // JigsawStack is a JigsawStack extension. + JigsawStack struct { + baseURL string + client *http.Client + logger *slog.Logger + header builders.Header + } + // Option is an option for the JigsawStack client. + Option func(*JigsawStack) + // Endpoint is the endpoint for the JigsawStack api. + Endpoint string +) + +// NewJigsawStack creates a new JigsawStack extension. +func NewJigsawStack(apiKey string, opts ...Option) (*JigsawStack, error) { + j := &JigsawStack{ + baseURL: defaultBaseURL, + client: http.DefaultClient, + logger: slog.Default(), + } + for _, opt := range opts { + opt(j) + } + j.header.SetCommonHeaders = func(req *http.Request) { + req.Header.Set("x-api-key", apiKey) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + } + return j, nil +} + +// WithBaseURL sets the base URL for the JigsawStack extension. +func WithBaseURL(baseURL string) Option { + return func(j *JigsawStack) { j.baseURL = baseURL } +} + +// WithClient sets the client for the JigsawStack extension. +func WithClient(client *http.Client) Option { + return func(j *JigsawStack) { j.client = client } +} + +// WithLogger sets the logger for the JigsawStack extension. +func WithLogger(logger *slog.Logger) Option { + return func(j *JigsawStack) { j.logger = logger } +} + +func (j *JigsawStack) sendRequest(req *http.Request, v any) error { + j.header.SetCommonHeaders(req) + resp, err := j.client.Do(req) + if err != nil { + return err + } + j.logger.Debug("received http response from jigsawstack", "status", resp.Status, "url", req.URL) + defer resp.Body.Close() + if resp.StatusCode < http.StatusOK || + resp.StatusCode >= http.StatusBadRequest { + read, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + return fmt.Errorf("bad status code: %d\nbdy: %s\n headers: %v", resp.StatusCode, read, resp.Header) + } + if v == nil { + return nil + } + switch o := v.(type) { + case *string: + b, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + *o = string(b) + return nil + default: + err = json.NewDecoder(resp.Body).Decode(v) + if err != nil { + read, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + j.logger.Debug("failed to decode response", "response", string(read)) + return fmt.Errorf("failed to decode response: %w\nbody: %s", err, string(read)) + } + return nil + } +} diff --git a/extensions/jigsawstack/jigsawstack_test.go b/extensions/jigsawstack/jigsawstack_test.go new file mode 100644 index 0000000..3e0126a --- /dev/null +++ b/extensions/jigsawstack/jigsawstack_test.go @@ -0,0 +1 @@ +package jigsawstack diff --git a/extensions/jigsawstack/natural_language.go b/extensions/jigsawstack/natural_language.go new file mode 100644 index 0000000..d4efc69 --- /dev/null +++ b/extensions/jigsawstack/natural_language.go @@ -0,0 +1,180 @@ +package jigsawstack + +import ( + "context" + "net/http" + + "github.com/conneroisu/groq-go/pkg/builders" +) + +const ( + summaryEndpoint Endpoint = "/v1/ai/summarize" + sentimentSuffix Endpoint = "/v1/ai/sentiment" + translateEndpoint Endpoint = "/v1/ai/translate" + + // EmotionAnger is the anger emotion. + EmotionAnger Emotion = "anger" + // EmotionFear is the fear emotion. + EmotionFear Emotion = "fear" + // EmotionSadness is the sadness emotion. + EmotionSadness Emotion = "sadness" + // EmotionHappiness is the happiness emotion. + EmotionHappiness Emotion = "happiness" + // EmotionAnxiety is the anxiety emotion. + EmotionAnxiety Emotion = "anxiety" + // EmotionDisgust is the disgust emotion. + EmotionDisgust Emotion = "disgust" + // EmotionEmbarrassment is the embarrassment emotion. + EmotionEmbarrassment Emotion = "embarrassment" + // EmotionLove is the love emotion. + EmotionLove Emotion = "love" + // EmotionSurprise is the surprise emotion. + EmotionSurprise Emotion = "surprise" + // EmotionShame is the shame emotion. + EmotionShame Emotion = "shame" + // EmotionEnvy is the envy emotion. + EmotionEnvy Emotion = "envy" + // EmotionSatisfaction is the satisfaction emotion. + EmotionSatisfaction Emotion = "satisfaction" + // EmotionSelfConfidence is the self-confidence emotion. + EmotionSelfConfidence Emotion = "self-confidence" + // EmotionAnnoyance is the annoyance emotion. + EmotionAnnoyance Emotion = "annoyance" + // EmotionBoredom is the boredom emotion. + EmotionBoredom Emotion = "boredom" + // EmotionHatred is the hatred emotion. + EmotionHatred Emotion = "hatred" + // EmotionCompassion is the compassion emotion. + EmotionCompassion Emotion = "compassion" + // EmotionGuilt is the guilt emotion. + EmotionGuilt Emotion = "guilt" + // EmotionLoneliness is the loneliness emotion. + EmotionLoneliness Emotion = "loneliness" + // EmotionDepression is the depression emotion. + EmotionDepression Emotion = "depression" + // EmotionPride is the pride emotion. + EmotionPride Emotion = "pride" + // EmotionNeutral is the neutral emotion. + EmotionNeutral Emotion = "neutral" +) + +type ( + // Language is a language. + Language string + // TranslateRequest represents a request structure for translate API. + TranslateRequest struct { + CurrentLanguage Language `json:"current_language"` + TargetLanguage Language `json:"target_language"` + Text string `json:"text"` + } + // TranslateResponse represents a response structure for translate API. + TranslateResponse struct { + Success bool `json:"success"` + TranslatedText string `json:"translated_text"` + } + // Emotion is an emotion. + Emotion string + // SentimentResponse represents a response structure for sentiment API. + SentimentResponse struct { + Success bool `json:"success"` + Sentiment struct { + Emotion Emotion `json:"emotion"` + Sentiment string `json:"sentiment"` + Score float64 `json:"score"` + Sentences []struct { + Text string `json:"text"` + Emotion Emotion `json:"emotion"` + Sentiment string `json:"sentiment"` + Score float64 `json:"score"` + } `json:"sentences"` + } `json:"sentiment"` + } +) + +// Sentiment performs a sentiment api call over a string. +func (j *JigsawStack) Sentiment( + ctx context.Context, + text string, +) (SentimentResponse, error) { + var request = struct { + Text string `json:"text"` + }{Text: text} + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodPost, + j.baseURL+string(sentimentSuffix), + builders.WithBody(request), + ) + if err != nil { + return SentimentResponse{}, err + } + var respH SentimentResponse + err = j.sendRequest(req, &respH) + if err != nil { + return SentimentResponse{}, err + } + return respH, nil +} + +type ( + // SummaryRequest represents a request structure for summary API. + SummaryRequest struct { + Text string `json:"text"` + } + // SummaryResponse represents a response structure for summary API. + SummaryResponse struct { + Success bool `json:"success"` + Summary string `json:"summary"` + } +) + +// Summarize summarizes the give text. +// +// Max text character is 5000. +func (j *JigsawStack) Summarize( + ctx context.Context, + request SummaryRequest, +) (response SummaryResponse, err error) { + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodPost, + j.baseURL+string(summaryEndpoint), + builders.WithBody(request), + ) + if err != nil { + return + } + var resp SummaryResponse + err = j.sendRequest(req, &resp) + if err != nil { + return + } + return resp, nil +} + +// Translate translates the text from the current language to the target language. +// +// Max text character is 5000. +func (j *JigsawStack) Translate( + ctx context.Context, + request TranslateRequest, +) (response TranslateResponse, err error) { + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodPost, + j.baseURL+string(translateEndpoint), + builders.WithBody(request), + ) + if err != nil { + return + } + var resp TranslateResponse + err = j.sendRequest(req, &resp) + if err != nil { + return + } + return resp, nil +} diff --git a/extensions/jigsawstack/natural_language_test.go b/extensions/jigsawstack/natural_language_test.go new file mode 100644 index 0000000..4a1b32a --- /dev/null +++ b/extensions/jigsawstack/natural_language_test.go @@ -0,0 +1,26 @@ +package jigsawstack_test + +import ( + "context" + "testing" + + "github.com/conneroisu/groq-go/extensions/jigsawstack" + "github.com/conneroisu/groq-go/pkg/test" + "github.com/stretchr/testify/assert" +) + +func TestJigsawStack_Sentiment(t *testing.T) { + if !test.IsIntegrationTest() { + t.Skip("Skipping integration test") + } + a := assert.New(t) + apiKey, err := test.GetAPIKey("JIGSAWSTACK_API_KEY") + a.NoError(err) + j, err := jigsawstack.NewJigsawStack(apiKey) + a.NoError(err) + resp, err := j.Sentiment(context.Background(), "I am a happy person") + a.NoError(err) + a.True(resp.Success) + a.Equal(jigsawstack.EmotionHappiness, resp.Sentiment.Emotion) + a.Equal("positive", resp.Sentiment.Sentiment) +} diff --git a/extensions/jigsawstack/prediction.go b/extensions/jigsawstack/prediction.go new file mode 100644 index 0000000..2f4ea15 --- /dev/null +++ b/extensions/jigsawstack/prediction.go @@ -0,0 +1,54 @@ +package jigsawstack + +import ( + "context" + "net/http" + "time" + + "github.com/conneroisu/groq-go/pkg/builders" +) + +const ( + predictEndpoint Endpoint = "v1/ai/prediction" +) + +type ( + // DatasetEntry represents a dataset entry. + DatasetEntry struct { + Date time.Time `json:"date"` + Value float64 `json:"value"` + } + // PredictResponse represents a response structure for prediction API. + PredictResponse struct { + Success bool `json:"success"` + Answer []DatasetEntry `json:"answer"` + } +) + +// Predict predicts the future values of a dataset. +// +// Max text character is 5000. +func (j *JigsawStack) Predict( + ctx context.Context, + dataset []DatasetEntry, +) (response PredictResponse, err error) { + var predictRequest = struct { + Dataset []DatasetEntry `json:"dataset"` + }{Dataset: dataset} + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodPost, + j.baseURL+string(predictEndpoint), + builders.WithBody(predictRequest), + ) + if err != nil { + return + } + var resp PredictResponse + err = j.sendRequest(req, &resp) + if err != nil { + return + } + return resp, nil +} diff --git a/extensions/jigsawstack/prompt.go b/extensions/jigsawstack/prompt.go new file mode 100644 index 0000000..270c645 --- /dev/null +++ b/extensions/jigsawstack/prompt.go @@ -0,0 +1,233 @@ +package jigsawstack + +import ( + "context" + "net/http" + "strconv" + "time" + + "github.com/conneroisu/groq-go/pkg/builders" +) + +const ( + promptCreateEndpoint Endpoint = "/v1/prompt_engine" +) + +type ( + // PromptRunResponse represents a response structure for prompt run API. + PromptRunResponse struct { + Success bool `json:"success"` + Result string `json:"result"` + } + // PromptCreateInput represents an entry in a prompt create request. + PromptCreateInput struct { + Key string `json:"key"` + Optional bool `json:"optional"` + InitialValue string `json:"initial_value"` + } + // PromptCreateRequest represents a request structure for prompt create API. + PromptCreateRequest struct { + Prompt string `json:"prompt"` + Inputs []PromptCreateInput `json:"inputs"` + ReturnPrompt string `json:"return_prompt"` + PromptGuard []string `json:"prompt_guard"` + Optimize bool `json:"optimize_prompt,omitempty"` + UseInternet bool `json:"use_internet,omitempty"` + } + // PromptResponse represents a response structure for prompt create API. + PromptResponse struct { + Success bool `json:"success"` + ID string `json:"prompt_engine_id"` + } + // PromptEngine represents a prompt engine. + PromptEngine struct { + ID string `json:"id"` + Prompt string `json:"prompt"` + Inputs any `json:"inputs"` + ReturnPrompt any `json:"return_prompt"` + CreatedAt time.Time `json:"created_at"` + } + // PromptListResponse represents a response structure for prompt list API. + PromptListResponse struct { + Success bool `json:"success"` + PromptEngines []PromptEngine `json:"prompt_engines"` + Page int `json:"page"` + Limit int `json:"limit"` + HasMore bool `json:"has_more"` + } +) + +// PromptGet gets a specific prompt. +// +// GET https://api.jigsawstack.com/v1/prompt_engine/{id} +// +// https://docs.jigsawstack.com/api-reference/prompt-engine/retrieve +func (j *JigsawStack) PromptGet( + ctx context.Context, + id string, +) (response PromptEngine, err error) { + uri := j.baseURL + string(promptCreateEndpoint) + "/" + id + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodGet, + uri, + ) + if err != nil { + return + } + var resp PromptEngine + err = j.sendRequest(req, &resp) + if err != nil { + return + } + return resp, nil +} + +// PromptList lists prompts. +// +// https://docs.jigsawstack.com/api-reference/prompt-engine/list +// +// GET https://api.jigsawstack.com/v1/prompt_engine +func (j *JigsawStack) PromptList( + ctx context.Context, + page int, + limit int, +) (response PromptListResponse, err error) { + uri := j.baseURL + string(promptCreateEndpoint) + "?page=" + strconv.Itoa(page) + "&limit=" + strconv.Itoa(limit) + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodGet, + uri, + ) + if err != nil { + return + } + var resp PromptListResponse + err = j.sendRequest(req, &resp) + if err != nil { + return + } + return resp, nil +} + +// PromptCreate creates a prompt. +// +// https://docs.jigsawstack.com/api-reference/prompt-engine/create +// +// POST https://api.jigsawstack.com/v1/prompt_engine +func (j *JigsawStack) PromptCreate( + ctx context.Context, + request PromptCreateRequest, +) (response PromptResponse, err error) { + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodPost, + j.baseURL+string(promptCreateEndpoint), + builders.WithBody(request), + ) + if err != nil { + return + } + var resp PromptResponse + err = j.sendRequest(req, &resp) + if err != nil { + return + } + return resp, nil +} + +// PromptDelete deletes a specific prompt. +// +// https://api.jigsawstack.com/v1/prompt_engine/{id} +// +// https://docs.jigsawstack.com/api-reference/prompt-engine/delete +func (j *JigsawStack) PromptDelete( + ctx context.Context, + id string, +) (response PromptResponse, err error) { + // TODO: may need to sanitize the id + uri := j.baseURL + string(promptCreateEndpoint) + "/" + id + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodDelete, + uri, + ) + if err != nil { + return + } + var resp PromptResponse + err = j.sendRequest(req, &resp) + if err != nil { + return + } + return resp, nil +} + +// PromptRun runs a specific prompt with the given inputs. +// +// https://api.jigsawstack.com/v1/prompt_engine/{id} +// +// https://docs.jigsawstack.com/api-reference/prompt-engine/run +func (j *JigsawStack) PromptRun( + ctx context.Context, + id string, + inputs map[string]any, +) (response PromptRunResponse, err error) { + // TODO: may need to sanitize the id + uri := j.baseURL + string(promptCreateEndpoint) + "/" + id + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodPost, + uri, + builders.WithBody(inputs), + ) + if err != nil { + return + } + var resp PromptRunResponse + err = j.sendRequest(req, &resp) + if err != nil { + return + } + return resp, nil +} + +// PromptRunDirect runs new prompt with the given inputs. +// +// https://docs.jigsawstack.com/api-reference/prompt-engine/run-direct +// +// https://api.jigsawstack.com/v1/prompt_engine/run +func (j *JigsawStack) PromptRunDirect( + ctx context.Context, + request PromptCreateRequest, + inputs map[string]any, +) (response PromptRunResponse, err error) { + type combinedRequest struct { + PromptCreateRequest + Inputs map[string]any `json:"inputs"` + } + var combinedReq combinedRequest + combinedReq.PromptCreateRequest = request + combinedReq.Inputs = inputs + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodPost, + j.baseURL+string(promptCreateEndpoint), + builders.WithBody(combinedReq), + ) + if err != nil { + return + } + var resp PromptRunResponse + err = j.sendRequest(req, &resp) + if err != nil { + return + } + return resp, nil +} diff --git a/extensions/jigsawstack/prompt_test.go b/extensions/jigsawstack/prompt_test.go new file mode 100644 index 0000000..ac55273 --- /dev/null +++ b/extensions/jigsawstack/prompt_test.go @@ -0,0 +1,58 @@ +package jigsawstack_test + +import ( + "context" + "testing" + + "github.com/conneroisu/groq-go/extensions/jigsawstack" + "github.com/conneroisu/groq-go/pkg/test" + "github.com/stretchr/testify/assert" +) + +func TestJigsawStack_PromptCreate(t *testing.T) { + if !test.IsIntegrationTest() { + t.Skip("Skipping integration test") + } + a := assert.New(t) + apiKey, err := test.GetAPIKey("JIGSAWSTACK_API_KEY") + a.NoError(err) + j, err := jigsawstack.NewJigsawStack(apiKey) + a.NoError(err) + resp, err := j.PromptCreate(context.Background(), jigsawstack.PromptCreateRequest{ + Prompt: ` +You are a helpful assistant that answers questions based on the provided context. +Your job is to provide code completions based on the provided context. + `, + Inputs: []jigsawstack.PromptCreateInput{ + { + Key: "context", + Optional: false, + InitialValue: ` + +def main(): + print("Hello, World!") + +if __name__ == "__main__": + main() + + `, + }, + }, + }) + a.NoError(err) + t.Logf("response: %v", resp) + t.Fail() +} +func TestJigsawStack_PromptGet(t *testing.T) { + if !test.IsIntegrationTest() { + t.Skip("Skipping integration test") + } + a := assert.New(t) + apiKey, err := test.GetAPIKey("JIGSAWSTACK_API_KEY") + a.NoError(err) + j, err := jigsawstack.NewJigsawStack(apiKey) + a.NoError(err) + resp, err := j.PromptGet(context.Background(), "test") + a.NoError(err) + a.NotEmpty(resp.Prompt) +} diff --git a/extensions/jigsawstack/sql.go b/extensions/jigsawstack/sql.go new file mode 100644 index 0000000..bc30fda --- /dev/null +++ b/extensions/jigsawstack/sql.go @@ -0,0 +1,53 @@ +package jigsawstack + +import ( + "context" + "net/http" + + "github.com/conneroisu/groq-go/pkg/builders" +) + +const ( + textToSQLEndpoint Endpoint = "/v1/ai/sql" +) + +type ( + // TextToSQLResponse represents a response structure for text to SQL API. + TextToSQLResponse struct { + Success bool `json:"success"` + SQL string `json:"sql"` + } +) + +// TextToSQL converts the text to SQL. +// +// Max text character is 5000. +func (j *JigsawStack) TextToSQL( + ctx context.Context, + prompt string, + sqlSchema string, +) (response TextToSQLResponse, err error) { + body := struct { + Prompt string `json:"prompt"` + SQLSchema string `json:"sql_schema"` + }{ + Prompt: prompt, + SQLSchema: sqlSchema, + } + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodPost, + j.baseURL+string(textToSQLEndpoint), + builders.WithBody(body), + ) + if err != nil { + return + } + var resp TextToSQLResponse + err = j.sendRequest(req, &resp) + if err != nil { + return + } + return resp, nil +} diff --git a/extensions/jigsawstack/sql_test.go b/extensions/jigsawstack/sql_test.go new file mode 100644 index 0000000..5ac671c --- /dev/null +++ b/extensions/jigsawstack/sql_test.go @@ -0,0 +1,33 @@ +package jigsawstack_test + +import ( + "context" + "testing" + + "github.com/conneroisu/groq-go/extensions/jigsawstack" + "github.com/conneroisu/groq-go/pkg/test" + "github.com/stretchr/testify/assert" +) + +// TestJigsawStack_TextToSQL tests the TextToSQL method of the JigsawStack client. +func TestJigsawStack_TextToSQL(t *testing.T) { + if !test.IsIntegrationTest() { + t.Skip("Skipping unit test") + } + a := assert.New(t) + ctx := context.Background() + apiKey, err := test.GetAPIKey("JIGSAWSTACK_API_KEY") + a.NoError(err) + j, err := jigsawstack.NewJigsawStack(apiKey) + a.NoError(err) + resp, err := j.TextToSQL(ctx, "select all users", ` +CREATE TABLE users ( + id INT PRIMARY KEY, + name VARCHAR(255), + email VARCHAR(255), + age INT +); +`) + a.NoError(err) + a.NotEmpty(resp.SQL) +} diff --git a/extensions/jigsawstack/storage.go b/extensions/jigsawstack/storage.go new file mode 100644 index 0000000..02e4ef5 --- /dev/null +++ b/extensions/jigsawstack/storage.go @@ -0,0 +1,210 @@ +package jigsawstack + +import ( + "context" + "fmt" + "net/http" + + "github.com/conneroisu/groq-go/pkg/builders" +) + +const ( + uploadEndpoint Endpoint = "/v1/store/file" + kvEndpoint Endpoint = "/v1/store/kv" +) + +type ( + // StorageResponse represents a response structure for file API. + StorageResponse struct { + Success bool `json:"success"` + Message string `json:"message,omitempty"` + URL string `json:"url"` + Key string `json:"key"` + Value string `json:"value,omitempty"` + } +) + +// https://docs.jigsawstack.com/api-reference/store/file/get +// Upload Retrieve Delete + +// FileAdd uploads a file to the Jigsaw Stack file store. +// +// https://docs.jigsawstack.com/api-reference/store/file/add +// +// POST https://api.jigsawstack.com/v1/store/file +func (j *JigsawStack) FileAdd( + ctx context.Context, + key string, + contentType string, + content string, +) (string, error) { + url := j.baseURL + string(uploadEndpoint) + "?key=" + key + var body = struct { + Blob string `json:"blob"` + }{ + Blob: content, + } + + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodPost, + url, + builders.WithBody(body), + builders.WithContentType(contentType), + ) + if err != nil { + return "", err + } + var resp StorageResponse + err = j.sendRequest(req, &resp) + if err != nil { + return "", err + } + if !resp.Success { + return "", fmt.Errorf("failed to upload file: %s", resp.Message) + } + return "", nil +} + +// FileGet retrieves a file from the Jigsaw Stack file store. +// +// https://docs.jigsawstack.com/api-reference/store/file/get +// +// GET https://api.jigsawstack.com/v1/store/file/{fileName} +func (j *JigsawStack) FileGet(ctx context.Context, fileName string) (string, error) { + // TODO: may need to santize the fileName + url := j.baseURL + string(uploadEndpoint) + "/" + fileName + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodGet, + url, + ) + if err != nil { + return "", err + } + var resp StorageResponse + err = j.sendRequest(req, &resp) + if err != nil { + return "", err + } + if !resp.Success { + return "", fmt.Errorf("failed to retrieve file: %s", resp.Message) + } + // TODO: may need to return the file content from url + return resp.Message, nil +} + +// FileDelete deletes a file from the Jigsaw Stack file store. +// +// https://docs.jigsawstack.com/api-reference/store/file/delete +// +// DELETE https://api.jigsawstack.com/v1/store/file/{fileName} +func (j *JigsawStack) FileDelete(fileName string) error { + // TODO: may need to santize the fileName + url := j.baseURL + string(uploadEndpoint) + "/" + fileName + req, err := builders.NewRequest( + context.Background(), + j.header, + http.MethodDelete, + url, + ) + if err != nil { + return err + } + var resp StorageResponse + err = j.sendRequest(req, &resp) + if err != nil { + return err + } + if !resp.Success { + return fmt.Errorf("failed to delete file: %s", resp.Message) + } + return nil +} + +// KVAdd adds a key value pair to the Jigsaw Stack key-value store. +// +// https://docs.jigsawstack.com/api-reference/store/kv/add +// +// POST https://api.jigsawstack.com/v1/store/kv +func (j *JigsawStack) KVAdd( + ctx context.Context, + key string, + value string, +) error { + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodPost, + j.baseURL+string(kvEndpoint), + builders.WithBody(map[string]string{ + "key": key, + "value": value, + }), + ) + if err != nil { + return err + } + var resp StorageResponse + err = j.sendRequest(req, &resp) + if err != nil { + return err + } + return nil +} + +// KVGet retrieves a key value pair from the Jigsaw Stack key-value store. +// +// https://docs.jigsawstack.com/api-reference/store/kv/get +// +// GET https://api.jigsawstack.com/v1/store/kv/{key} +func (j *JigsawStack) KVGet( + ctx context.Context, + key string, +) (string, error) { + url := j.baseURL + string(kvEndpoint) + "/" + key + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodGet, + url, + ) + if err != nil { + return "", err + } + var resp StorageResponse + err = j.sendRequest(req, &resp) + if err != nil { + return "", err + } + return "", nil +} + +// KVDelete deletes a key value pair from the Jigsaw Stack key-value store. +// +// https://docs.jigsawstack.com/api-reference/store/kv/delete +// +// DELETE https://api.jigsawstack.com/v1/store/kv/{key} +func (j *JigsawStack) KVDelete( + ctx context.Context, + key string, +) (string, error) { + url := j.baseURL + string(kvEndpoint) + "/" + key + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodDelete, + url, + ) + if err != nil { + return "", err + } + var resp StorageResponse + err = j.sendRequest(req, &resp) + if err != nil { + return "", err + } + return resp.Message, nil +} diff --git a/extensions/jigsawstack/storage_test.go b/extensions/jigsawstack/storage_test.go new file mode 100644 index 0000000..5728c21 --- /dev/null +++ b/extensions/jigsawstack/storage_test.go @@ -0,0 +1,26 @@ +package jigsawstack_test + +import ( + "context" + "testing" + + "github.com/conneroisu/groq-go/extensions/jigsawstack" + "github.com/conneroisu/groq-go/pkg/test" + "github.com/stretchr/testify/assert" +) + +// TestJigsawStack_FileAdd tests the FileAdd method of the JigsawStack client. +func TestJigsawStack_FileAdd(t *testing.T) { + if !test.IsIntegrationTest() { + t.Skip("Skipping unit test") + } + a := assert.New(t) + ctx := context.Background() + apiKey, err := test.GetAPIKey("JIGSAWSTACK_API_KEY") + a.NoError(err) + j, err := jigsawstack.NewJigsawStack(apiKey) + a.NoError(err) + resp, err := j.FileAdd(ctx, "test", "text/plain", "hello world") + a.NoError(err) + a.NotEmpty(resp) +} diff --git a/extensions/jigsawstack/tts.mp3 b/extensions/jigsawstack/tts.mp3 new file mode 100644 index 0000000..4e38023 Binary files /dev/null and b/extensions/jigsawstack/tts.mp3 differ diff --git a/extensions/jigsawstack/visual.go b/extensions/jigsawstack/visual.go new file mode 100644 index 0000000..4ee030b --- /dev/null +++ b/extensions/jigsawstack/visual.go @@ -0,0 +1,212 @@ +package jigsawstack + +import ( + "context" + "net/http" + + "github.com/conneroisu/groq-go/pkg/builders" +) + +const ( + vOCREndpoint Endpoint = "/v1/vocr" + vObjectEndpoint Endpoint = "/v1/ai/object_detection" + imageGenerationEndpoint Endpoint = "/v1/ai/image_generation" +) + +type ( + // model + // string + // default: "sdxl" + // + // The model to use for the generation. Default is sdxl + // + // sd1.5 - Stable Diffusion v1.5 + // sdxl - Stable Diffusion XL + // ead1.0 - Anime Diffusion + // rv1.3 - Realistic Vision v1.3 + // rv3 - Realistic Vision v3 + // rv5.1 - Realistic Vision v5.1 + // ar1.8 - AbsoluteReality v1.8.1 + + // ImageGenerationRequest represents a request structure for image + // generation API. + ImageGenerationRequest struct { + Prompt string `json:"prompt"` + Model string `json:"model,omitempty"` + Size string `json:"size"` + Width int `json:"width"` + Height int `json:"height"` + } + // ImageGenerationResponse represents a response structure for image + // generation API. + ImageGenerationResponse struct { + Success bool `json:"success"` + Image string `json:"image"` + } + // visionRequest represents a request structure for VOCR API. + visionRequest struct { + // Prompt is the prompt used in ocr. If the request is for + // object detection, this field is not required. + Prompt string `json:"prompt,omitempty"` + // URL is the url of the image to use as the image. + // + // Not required if the StoreKey is not provided. + URL string `json:"image_url"` + // Key is the key of the file to use as the image. + // + // Not required if the ImageURL is not provided. + Key string `json:"file_store_key"` + } + // VOCRResponse represents a response structure for VOCR API. + VOCRResponse struct { + // Success is a boolean indicating whether the request was + // successful. + Success bool `json:"success"` + // Context is the context of the image. + Context string `json:"context"` + // Width is the width of the image. + Width int `json:"width"` + // Height is the height of the image. + Height int `json:"height"` + // Tags is a list of tags detected in the image. + Tags []string `json:"tags"` + // HasText is a boolean indicating whether the image contains + // text. + HasText bool `json:"has_text"` + // Sections is a list of sections detected in the image. + Sections []any `json:"sections"` + } + // VisionObjectResponse represents a response structure for VOD API. + VisionObjectResponse struct { + // Success is a boolean indicating whether the request was + Success bool `json:"success"` + // Width is the width of the image. + Width int `json:"width"` + // Height is the height of the image. + Height int `json:"height"` + // Tags is a list of tags detected in the image. + Tags []string `json:"tags"` + // Objects is a list of objects detected in the image. + Objects []struct { + Name string `json:"name"` + Confidence float64 `json:"confidence"` + Bounds struct { + TopLeft struct { + X int `json:"x"` + Y int `json:"y"` + } `json:"top_left"` + TopRight struct { + X int `json:"x"` + Y int `json:"y"` + } `json:"top_right"` + BottomRight struct { + X int `json:"x"` + Y int `json:"y"` + } `json:"bottom_right"` + BottomLeft struct { + X int `json:"x"` + Y int `json:"y"` + } `json:"bottom_left"` + Width int `json:"width"` + Height int `json:"height"` + } `json:"bounds"` + } `json:"objects"` + } +) + +// VCOROption is the option for VOCR. +type VCOROption func(*visionRequest) + +// WithKey sets the key of the file to use as the image. +func WithKey(key string) VCOROption { + return func(params *visionRequest) { params.Key = key } +} + +// WithURL sets the URL of the image to use as the image. +func WithURL(url string) VCOROption { + return func(params *visionRequest) { params.URL = url } +} + +// VOCR performs a visual object recognition (VOCR) task on an image. +// +// POST https://api.jigsawstack.com/v1/vocr +// +// https://docs.jigsawstac.com/api-reference/ai/vision +func (j *JigsawStack) VOCR( + ctx context.Context, + prompt string, + opt VCOROption, +) (string, error) { + params := visionRequest{ + Prompt: prompt, + } + opt(¶ms) + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodPost, + j.baseURL+string(vOCREndpoint), + builders.WithBody(params), + ) + if err != nil { + return "", err + } + var resp VOCRResponse + err = j.sendRequest(req, &resp) + if err != nil { + return "", err + } + return "", nil +} + +// VisionObjectDetection performs a visual object detection (VOD) task on an +// image. +// +// POST https://api.jigsawstack.com/v1/ai/object_detection +// +// https://docs.jigsawstack.com/api-reference/ai/object-detection +func (j *JigsawStack) VisionObjectDetection( + ctx context.Context, + params visionRequest, +) (string, error) { + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodPost, + j.baseURL+string(vObjectEndpoint), + builders.WithBody(params), + ) + if err != nil { + return "", err + } + var resp VisionObjectResponse + err = j.sendRequest(req, &resp) + if err != nil { + return "", err + } + return "", nil + +} + +// ImageGeneration generates an image from a prompt and parameters. +func (j *JigsawStack) ImageGeneration( + ctx context.Context, + request ImageGenerationRequest, +) (response ImageGenerationResponse, err error) { + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodPost, + j.baseURL+string(imageGenerationEndpoint), + builders.WithBody(request), + ) + if err != nil { + return + } + var resp ImageGenerationResponse + err = j.sendRequest(req, &resp) + if err != nil { + return + } + return resp, nil +} diff --git a/extensions/jigsawstack/web.go b/extensions/jigsawstack/web.go new file mode 100644 index 0000000..e3034d1 --- /dev/null +++ b/extensions/jigsawstack/web.go @@ -0,0 +1,106 @@ +package jigsawstack + +import ( + "context" + "net/http" + "time" + + "github.com/conneroisu/groq-go/pkg/builders" +) + +const ( + webSearchEndpoint Endpoint = "/v1/web/search" + webSuggestEndpoint Endpoint = "/v1/web/search/suggest" +) + +type ( + // WebSearchSuggestions is the response for the web search suggestions + // api. + WebSearchSuggestions struct { + Success bool `json:"success"` + Suggestions []string `json:"suggestions"` + } + // WebSearchResponse is the response for the web search api. + WebSearchResponse struct { + Success bool `json:"success"` + Query string `json:"query"` + SpellFixed string `json:"spell_fixed"` + IsSafe bool `json:"is_safe"` + AiOverview string `json:"ai_overview"` + Results []struct { + Title string `json:"title"` + URL string `json:"url"` + Description string `json:"description"` + Content string `json:"content"` + SiteName string `json:"site_name"` + SiteLongName string `json:"site_long_name"` + Age time.Time `json:"age"` + Language string `json:"language"` + IsSafe bool `json:"is_safe"` + Favicon string `json:"favicon"` + Snippets []string `json:"snippets"` + RelatedIndex []struct { + Title string `json:"title"` + URL string `json:"url"` + Description string `json:"description"` + IsSafe bool `json:"is_safe"` + } `json:"related_index,omitempty"` + Thumbnail string `json:"thumbnail,omitempty"` + } `json:"results"` + } +) + +// WebSearch performs a web search api call over a query string. +// +// GET https://api.jigsawstack.com/v1/web/search +// +// https://docs.jigsawstack.com/api-reference/web/search +func (j *JigsawStack) WebSearch( + ctx context.Context, + query string, +) (response WebSearchResponse, err error) { + uri := j.baseURL + string(webSearchEndpoint) + "?query=" + query + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodGet, + uri, + ) + if err != nil { + return + } + var resp WebSearchResponse + err = j.sendRequest(req, &resp) + if err != nil { + return + } + return resp, nil +} + +// WebSearchSuggestions performs a web search suggestions api call over a query +// string. +// +// GET https://api.jigsawstack.com/v1/web/search/suggest +// +// https://docs.jigsawstack.com/api-reference/web/search +func (j *JigsawStack) WebSearchSuggestions( + ctx context.Context, + query string, +) (response WebSearchSuggestions, err error) { + uri := j.baseURL + string(webSuggestEndpoint) + "?query=" + query + req, err := builders.NewRequest( + ctx, + j.header, + http.MethodGet, + uri, + ) + if err != nil { + return + } + var resp WebSearchSuggestions + err = j.sendRequest(req, &resp) + if err != nil { + return + } + return resp, nil +} diff --git a/extensions/jigsawstack/web_test.go b/extensions/jigsawstack/web_test.go new file mode 100644 index 0000000..c26a858 --- /dev/null +++ b/extensions/jigsawstack/web_test.go @@ -0,0 +1,42 @@ +package jigsawstack_test + +import ( + "context" + "testing" + + "github.com/conneroisu/groq-go/extensions/jigsawstack" + "github.com/conneroisu/groq-go/pkg/test" + "github.com/stretchr/testify/assert" +) + +// TestJigsawStack_WebSearch tests the WebSearch method of the JigsawStack client. +func TestJigsawStack_WebSearch(t *testing.T) { + if !test.IsIntegrationTest() { + t.Skip("Skipping unit test") + } + a := assert.New(t) + ctx := context.Background() + apiKey, err := test.GetAPIKey("JIGSAWSTACK_API_KEY") + a.NoError(err) + j, err := jigsawstack.NewJigsawStack(apiKey) + a.NoError(err) + resp, err := j.WebSearch(ctx, "hello world golang") + a.NoError(err) + a.NotEmpty(resp.Results) +} + +// TestJigsawStack_WebSearchSuggestions tests the WebSearchSuggestions method of the JigsawStack client. +func TestJigsawStack_WebSearchSuggestions(t *testing.T) { + if !test.IsIntegrationTest() { + t.Skip("Skipping unit test") + } + a := assert.New(t) + ctx := context.Background() + apiKey, err := test.GetAPIKey("JIGSAWSTACK_API_KEY") + a.NoError(err) + j, err := jigsawstack.NewJigsawStack(apiKey) + a.NoError(err) + resp, err := j.WebSearchSuggestions(ctx, "hello") + a.NoError(err) + a.NotEmpty(resp.Suggestions) +} diff --git a/extensions/toolhouse/run.go b/extensions/toolhouse/execute.go similarity index 78% rename from extensions/toolhouse/run.go rename to extensions/toolhouse/execute.go index 469306c..d2242ce 100644 --- a/extensions/toolhouse/run.go +++ b/extensions/toolhouse/execute.go @@ -19,20 +19,6 @@ type ( } ) -// MustRun runs the extension on the given history. -// -// It panics if an error occurs. -func (e *Toolhouse) MustRun( - ctx context.Context, - response groq.ChatCompletionResponse, -) []groq.ChatCompletionMessage { - respH, err := e.Run(ctx, response) - if err != nil { - panic(err) - } - return respH -} - // Run runs the extension on the given history. func (e *Toolhouse) Run( ctx context.Context, @@ -41,8 +27,8 @@ func (e *Toolhouse) Run( var respH []groq.ChatCompletionMessage var toolCall tools.ToolCall e.logger.Debug("Running Toolhouse extension", "response", response) - if response.Choices[0].FinishReason != groq.FinishReasonFunctionCall && response.Choices[0].FinishReason != "tool_calls" { - return nil, fmt.Errorf("Not a function call") + if response.Choices[0].FinishReason != groq.ReasonFunctionCall && response.Choices[0].FinishReason != "tool_calls" { + return nil, fmt.Errorf("not a function call") } for _, toolCall = range response.Choices[0].Message.ToolCalls { req, err := builders.NewRequest( diff --git a/extensions/toolhouse/run_test.go b/extensions/toolhouse/execute_test.go similarity index 97% rename from extensions/toolhouse/run_test.go rename to extensions/toolhouse/execute_test.go index 0ff5386..ac207a0 100644 --- a/extensions/toolhouse/run_test.go +++ b/extensions/toolhouse/execute_test.go @@ -66,7 +66,7 @@ func TestRun(t *testing.T) { Choices: []groq.ChatCompletionChoice{ { Message: history[0], - FinishReason: groq.FinishReasonFunctionCall, + FinishReason: groq.ReasonFunctionCall, }, }, }) diff --git a/extensions/toolhouse/toolhouse_test.go b/extensions/toolhouse/toolhouse_test.go index 1100300..b4210d5 100644 --- a/extensions/toolhouse/toolhouse_test.go +++ b/extensions/toolhouse/toolhouse_test.go @@ -7,12 +7,13 @@ import ( "github.com/conneroisu/groq-go" "github.com/conneroisu/groq-go/extensions/toolhouse" + "github.com/conneroisu/groq-go/pkg/models" "github.com/conneroisu/groq-go/pkg/test" "github.com/stretchr/testify/assert" ) func TestUnitExtension(t *testing.T) { - if !test.IsUnitTest() { + if !test.IsIntegrationTest() { t.Skip("Skipping Toolhouse extension test") } a := assert.New(t) @@ -36,7 +37,7 @@ func TestUnitExtension(t *testing.T) { tooling, err := ext.GetTools(ctx) a.NoError(err) re, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ - Model: groq.ModelLlama3Groq70B8192ToolUsePreview, + Model: models.ModelLlama3Groq70B8192ToolUsePreview, Messages: history, Tools: tooling, ToolChoice: "required", @@ -47,7 +48,7 @@ func TestUnitExtension(t *testing.T) { a.NoError(err) history = append(history, r...) finalr, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ - Model: groq.ModelLlama3Groq70B8192ToolUsePreview, + Model: models.ModelLlama3Groq70B8192ToolUsePreview, Messages: history, MaxTokens: 2000, }) diff --git a/extensions/toolhouse/tools.go b/extensions/toolhouse/tools.go index cacaf3d..a6f342b 100644 --- a/extensions/toolhouse/tools.go +++ b/extensions/toolhouse/tools.go @@ -2,28 +2,12 @@ package toolhouse import ( "context" - "encoding/json" - "fmt" - "io" "net/http" "github.com/conneroisu/groq-go/pkg/builders" "github.com/conneroisu/groq-go/pkg/tools" ) -// MustGetTools returns a list of tools that the extension can use. -// -// It panics if an error occurs. -func (e *Toolhouse) MustGetTools( - ctx context.Context, -) []tools.Tool { - tools, err := e.GetTools(ctx) - if err != nil { - panic(err) - } - return tools -} - // GetTools returns a list of tools that the extension can use. func (e *Toolhouse) GetTools( ctx context.Context, @@ -45,20 +29,8 @@ func (e *Toolhouse) GetTools( if err != nil { return nil, err } - resp, err := e.client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("request failed: %s", resp.Status) - } - bdy, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w: %s", err, string(bdy)) - } var tooling []tools.Tool - err = json.Unmarshal(bdy, &tooling) + err = e.sendRequest(req, &tooling) if err != nil { return nil, err } diff --git a/moderation.go b/moderation.go index 0c83169..98ce0df 100644 --- a/moderation.go +++ b/moderation.go @@ -6,192 +6,17 @@ import ( "strings" "github.com/conneroisu/groq-go/pkg/builders" + "github.com/conneroisu/groq-go/pkg/models" + "github.com/conneroisu/groq-go/pkg/moderation" ) -// ModerationRequest represents a request structure for moderation API. type ( - // HarmfulCategory is a category of harmful content that can be used to - // categorize a chat history using Llama Guard 3. - // - // [Llama Guard 3](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama-guard-3/) - // - // Llama Guard 3 provides content safety support for the following languages: English, French, German, Hindi, Italian, Portuguese, Spanish, and Thai. - // - // string - HarmfulCategory string - ModerationRequest struct { - // Input string `json:"input,omitempty"` // Input is the input text to be moderated. - Messages []ChatCompletionMessage `json:"messages"` // Messages is the messages of the chat completion request. These act as the prompt for the model. - Model ModerationModel `json:"model,omitempty"` // Model is the model to use for the moderation. - } - // Moderation represents one of possible moderation results. + // Moderation represents the response of a moderation request. Moderation struct { - Categories []HarmfulCategory `json:"categories"` // Categories is the categories of the result. - Flagged bool `json:"flagged"` // Flagged is the flagged of the result. - } -) - -const ( - // CategoryViolentCrimes (S1) is the violent crimes category. - // - // Responses that enable, encourage, or endorse the commission of - // violent crimes, including: (1) unlawful violence toward people - // (ex: terrorism, genocide, murder, hate-motivated violent crimes, - // child abuse, assault, battery, kidnapping) and (2) unlawful violence - // toward animals (ex: animal abuse) - CategoryViolentCrimes HarmfulCategory = "violent_crimes" - // CategoryNonviolentCrimes (S2) is the non-violent crimes category. - CategoryNonviolentCrimes HarmfulCategory = "non_violent_crimes" - // CategorySexRelatedCrimes (S3) is the sex-related crimes category. - CategorySexRelatedCrimes HarmfulCategory = "sex_related_crimes" - // CategoryChildSexualExploitation (S4) is the child sexual exploitation - // category. - // - // Responses that contain, describe, enable, encourage, or endorse the - // sexual abuse of children. - CategoryChildSexualExploitation HarmfulCategory = "child_sexual_exploitation" - // CategoryDefamation (S5) is the defamation category. - // - // Signifies responses that contain, describe, enable, encourage, or - // endorse defamation. - CategoryDefamation HarmfulCategory = "defamation" - // CategorySpecializedAdvice (S6) is the specialized advice category. - // - // Signifies responses contain, describe, enable, encourage, or endorse specialized advice. - CategorySpecializedAdvice HarmfulCategory = "specialized_advice" - // CategoryPrivacy (S7) is the privacy category. - // - // Responses contain, describe, enable, encourage, or endorse privacy. - CategoryPrivacy HarmfulCategory = "privacy" - // CategoryIntellectualProperty (S8) is the intellectual property - // category. Responses that contain, describe, enable, encourage, or - // endorse intellectual property. - CategoryIntellectualProperty HarmfulCategory = "intellectual_property" - // CategoryIndiscriminateWeapons (S9) is the indiscriminate weapons - // category. - // - // Responses that contain, describe, enable, encourage, or endorse - // indiscriminate weapons. - CategoryIndiscriminateWeapons HarmfulCategory = "indiscriminate_weapons" - // CategoryHate (S10) is the hate category. - // - // Responses contain, describe, enable, encourage, or endorse hate. - CategoryHate HarmfulCategory = "hate" - // CategorySuicideAndSelfHarm (S11) is the suicide/self-harm category. - // - // Responses contain, describe, enable, encourage, or endorse suicide or self-harm. - CategorySuicideAndSelfHarm HarmfulCategory = "suicide_and_self_harm" - // CategorySexualContent (S12) is the sexual content category. - // - // Responses contain, describe, enable, encourage, or endorse - // sexual content. - CategorySexualContent HarmfulCategory = "sexual_content" - // CategoryElections (S13) is the elections category. - // - // Responses contain factually incorrect information about electoral - // systems and processes, including in the time, place, or manner of - // voting in civic elections. - CategoryElections HarmfulCategory = "elections" - // CategoryCodeInterpreterAbuse (S14) is the code interpreter abuse - // category. - // - // Responses that contain, describe, enable, encourage, or - // endorse code interpreter abuse. - CategoryCodeInterpreterAbuse HarmfulCategory = "code_interpreter_abuse" -) - -var ( - // sectionMap is a map of sections to categories. - // - // S1: Violent Crimes. - // - // AI models should not create content that enables, encourages, or excuses the commission of violent crimes. Examples of violent crimes include, but are not limited to: - // - Unlawful violence toward people (ex: terrorism, genocide, murder, hate crimes, child abuse, assault, battery, kidnapping) - // - Unlawful violence toward animals (ex: animal abuse) - // - // S2: Non-Violent Crimes. - // - // AI models should not create content that enables, encourages, or excuses the commission of non-violent crimes. Examples of non-violent crimes include, but are not limited to: - // - Financial crimes (ex: fraud, scams, money laundering) - // - Property crimes (ex: burglary, robbery, arson, vandalism) - // - Drug crimes (ex: creating or trafficking narcotics) - // - Weapons crimes (ex: producing unlicensed firearms) - // - Cyber crimes (ex: hacking, spyware, malware) - // - // S3: Sex Crimes. - // - // AI models should not create content that enables, encourages, or excuses the commission of sex crimes. Examples of sex crimes include, but are not limited to: - // - Human trafficking - // - Sexual assault (ex: rape) - // - Sexual harassment (ex: groping) - // - Lewd conduct - // - // S4: Child Exploitation. - // - // AI models should not create content that depicts child nudity or that enables, encourages, excuses, or depicts the sexual abuse of children. - // - // S5: Defamation. - // - // AI models should not create content about a real, living person that is verifiably false and likely to injure the person's reputation. - // - // S6: Specialized Advice. - // - // AI models should not create content that contains specialized financial, medical, or legal advice. - // - // S7: Privacy. - // - // AI models should not create content that contains sensitive, personal information about private individuals. - // - // S8: Intellectual Property. - // - // AI models should not create content that violates the intellectual property rights of third parties. - // - // S9: Indiscriminate Weapons. - // I models should not create content that enables the creation of weapons of mass destruction. Examples of such weapons include, but are not limited to: - // - Chemical weapons (ex: nerve gas) - // - Biological weapons (ex: anthrax) - // - Radiological weapons (ex: salted bombs) - // - Nuclear weapons (ex: atomic warheads) - // - High-yield explosive weapons (ex: cluster munitions) - // - // S10: Hate. - // - // AI models should not create content that is hateful toward people on the basis of their protected characteristics (race, color, ethnicity, national origin, disability, religious affiliation, caste, sexual orientation, sex, gender identity and serious disease). AI models should not create content that perpetuates negative stereotypes. - // - // S11: Self-Harm. - // - // AI models should not create content that enables, encourages, or excuses acts of intentional self-harm. Examples of such actions include, but are not limited to: - // - Suicide - // - Self-injury (ex: cutting) - // - Disordered eating - // S12: Sexual Content. - // - // AI models should not create depictions of nude adults or content that contains erotic descriptions or explicit depictions of sex acts. - // - // S13: Elections. - // - // AI models should not create content that contain factually incorrect information about electoral systems and processes, including in the time, place, or manner of voting in civic elections. - // - // S14: Code Interpreter Abuse. - // - // AI models should not create content that attempts to abuse code interpreters. Examples of code interpreter abuse include, but are not limited to: - // - Denial of service attacks - // - Container escapes or privilege escalation. - sectionMap = map[string]HarmfulCategory{ - "S1": CategoryViolentCrimes, - "S2": CategoryNonviolentCrimes, - "S3": CategorySexRelatedCrimes, - "S4": CategoryChildSexualExploitation, - "S5": CategoryDefamation, - "S6": CategorySpecializedAdvice, - "S7": CategoryPrivacy, - "S8": CategoryIntellectualProperty, - "S9": CategoryIndiscriminateWeapons, - "S10": CategoryHate, - "S11": CategorySuicideAndSelfHarm, - "S12": CategorySexualContent, - "S13": CategoryElections, - "S14": CategoryCodeInterpreterAbuse, + // Categories is the categories of the result. + Categories []moderation.HarmfulCategory `json:"categories"` + // Flagged is the flagged status of the result. + Flagged bool `json:"flagged"` } ) @@ -199,14 +24,21 @@ var ( // Input can be an array or slice but a string will reduce the complexity. func (c *Client) Moderate( ctx context.Context, - request ModerationRequest, + messages []ChatCompletionMessage, + model models.ModerationModel, ) (response Moderation, err error) { req, err := builders.NewRequest( ctx, c.header, http.MethodPost, - c.fullURL(chatCompletionsSuffix, withModel(model(request.Model))), - builders.WithBody(&request), + c.fullURL(chatCompletionsSuffix, withModel(model)), + builders.WithBody(&struct { + Messages []ChatCompletionMessage `json:"messages"` + Model models.ModerationModel `json:"model,omitempty"` + }{ + Messages: messages, + Model: model, + }), ) if err != nil { return @@ -225,7 +57,7 @@ func (c *Client) Moderate( for _, s := range split { response.Categories = append( response.Categories, - sectionMap[strings.TrimSpace(s)], + moderation.SectionMap[strings.TrimSpace(s)], ) } } diff --git a/moderation_test.go b/moderation_test.go new file mode 100644 index 0000000..df2cbe7 --- /dev/null +++ b/moderation_test.go @@ -0,0 +1,29 @@ +package groq_test + +import ( + "context" + "testing" + + groq "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/models" + "github.com/stretchr/testify/assert" +) + +func TestModeration(t *testing.T) { + a := assert.New(t) + ctx := context.Background() + client, server, teardown := setupGroqTestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleModerationEndpoint) + mod, err := client.Moderate(ctx, + []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: "I want to kill them.", + }, + }, + models.ModelLlamaGuard38B, + ) + a.NoError(err) + a.NotEmpty(mod.Categories) +} diff --git a/pkg/builders/options.go b/pkg/builders/options.go new file mode 100644 index 0000000..ca13d94 --- /dev/null +++ b/pkg/builders/options.go @@ -0,0 +1,33 @@ +package builders + +import "net/url" + +// WithBody sets the body for a request. +func WithBody(body any) RequestOption { + return func(args *requestOptions) { + args.body = body + } +} + +// WithContentType sets the content type for a request. +func WithContentType(contentType string) RequestOption { + return func(args *requestOptions) { + args.header.Set("Content-Type", contentType) + } +} + +type ( + // Querier is an interface for a request querier. + // + // It allows for modifying the URL before it is sent. + Querier interface { + URLQuery(url *url.URL) + } +) + +// WithQuerier sets the querier for a request. +func WithQuerier(querier Querier) RequestOption { + return func(args *requestOptions) { + args.querier = querier + } +} diff --git a/pkg/builders/requests.go b/pkg/builders/requests.go index 71f7244..62eecd8 100644 --- a/pkg/builders/requests.go +++ b/pkg/builders/requests.go @@ -26,8 +26,9 @@ type ( } defaultRequestBuilder struct{} requestOptions struct { - body any - header http.Header + body any + header http.Header + querier Querier } // RequestOption is an option for a request. RequestOption func(*requestOptions) @@ -74,20 +75,6 @@ func (b *defaultRequestBuilder) Build( return } -// WithBody sets the body for a request. -func WithBody(body any) RequestOption { - return func(args *requestOptions) { - args.body = body - } -} - -// WithContentType sets the content type for a request. -func WithContentType(contentType string) RequestOption { - return func(args *requestOptions) { - args.header.Set("Content-Type", contentType) - } -} - // NewRequest creates a new request. func NewRequest( ctx context.Context, @@ -113,5 +100,8 @@ func NewRequest( return nil, err } c.SetCommonHeaders(req) + if args.querier != nil { + args.querier.URLQuery(req.URL) + } return req, nil } diff --git a/pkg/builders/urls.go b/pkg/builders/urls.go new file mode 100644 index 0000000..88f6939 --- /dev/null +++ b/pkg/builders/urls.go @@ -0,0 +1,8 @@ +package builders + +type ( + // URLComputer computes URLs for a given client. + URLComputer interface { + ComputeURLs() + } +) diff --git a/pkg/groqerr/api.go b/pkg/groqerr/api.go new file mode 100644 index 0000000..2323f00 --- /dev/null +++ b/pkg/groqerr/api.go @@ -0,0 +1,46 @@ +package groqerr + +import "fmt" + +type ( + // ErrContentFieldsMisused is an error that occurs when both Content and + // MultiContent properties are set. + ErrContentFieldsMisused struct{} + // ErrToolNotFound is returned when a tool is not found. + ErrToolNotFound struct { + ToolName string + } +) + +// Error implements the error interface. +func (e ErrContentFieldsMisused) Error() string { + return fmt.Errorf("can't use both Content and MultiContent properties simultaneously"). + Error() +} + +type ( + // ErrRequest is a request error. + ErrRequest struct { + HTTPStatusCode int + Err error + } +) + +// Error implements the error interface. +func (e *ErrRequest) Error() string { + return fmt.Sprintf( + "error, status code: %d, message: %s", + e.HTTPStatusCode, + e.Err, + ) +} + +// Unwrap unwraps the error. +func (e *ErrRequest) Unwrap() error { + return e.Err +} + +// Error implements the error interface. +func (e ErrToolNotFound) Error() string { + return fmt.Sprintf("tool %s not found", e.ToolName) +} diff --git a/pkg/groqerr/doc.go b/pkg/groqerr/doc.go new file mode 100644 index 0000000..cb01d60 --- /dev/null +++ b/pkg/groqerr/doc.go @@ -0,0 +1,2 @@ +// Package groqerr provides error types for the groq-go library. +package groqerr diff --git a/pkg/groqerr/stream.go b/pkg/groqerr/stream.go new file mode 100644 index 0000000..e4f3b7e --- /dev/null +++ b/pkg/groqerr/stream.go @@ -0,0 +1,94 @@ +package groqerr + +import ( + "encoding/json" + "fmt" + "io" + "strings" +) + +type ( + // ErrTooManyEmptyStreamMessages is returned when the stream has sent + // too many empty messages. + ErrTooManyEmptyStreamMessages struct{} + + // ErrorResponse is the response returned by the Groq API. + ErrorResponse struct { + Error *APIError `json:"error,omitempty"` + } + + // APIError provides error information returned by the Groq API. + APIError struct { + // Code is the code of the error. + Code any `json:"code,omitempty"` + // Message is the message of the error. + Message string `json:"message"` + // Param is the param of the error. + Param *string `json:"param,omitempty"` + // Type is the type of the error. + Type string `json:"type"` + // HTTPStatusCode is the status code of the error. + HTTPStatusCode int `json:"-"` + } + + // ErrorBuffer is a buffer that allows for appending errors. + ErrorBuffer interface { + io.Writer + Len() int + Bytes() []byte + } +) + +// Error method implements the error interface on APIError. +func (e *APIError) Error() string { + if e.HTTPStatusCode > 0 { + return fmt.Sprintf( + "error, status code: %d, message: %s", + e.HTTPStatusCode, + e.Message, + ) + } + return e.Message +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (e *APIError) UnmarshalJSON(data []byte) (err error) { + var rawMap map[string]json.RawMessage + err = json.Unmarshal(data, &rawMap) + if err != nil { + return + } + err = json.Unmarshal(rawMap["message"], &e.Message) + if err != nil { + var messages []string + err = json.Unmarshal(rawMap["message"], &messages) + if err != nil { + return + } + e.Message = strings.Join(messages, ", ") + } + // optional fields + if _, ok := rawMap["param"]; ok { + err = json.Unmarshal(rawMap["param"], &e.Param) + if err != nil { + return + } + } + if _, ok := rawMap["code"]; !ok { + return nil + } + // if the api returned a number, we need to force an integer + // since the json package defaults to float64 + var intCode int + err = json.Unmarshal(rawMap["code"], &intCode) + if err == nil { + e.Code = intCode + return nil + } + return json.Unmarshal(rawMap["code"], &e.Code) +} + +// Error returns the error message. +func (e ErrTooManyEmptyStreamMessages) Error() string { + return "stream has sent too many empty messages" +} diff --git a/pkg/groqerr/stream_test.go b/pkg/groqerr/stream_test.go new file mode 100644 index 0000000..234de5e --- /dev/null +++ b/pkg/groqerr/stream_test.go @@ -0,0 +1 @@ +package groqerr_test diff --git a/models.go b/pkg/models/models.go similarity index 88% rename from models.go rename to pkg/models/models.go index 6d21863..c593350 100644 --- a/models.go +++ b/pkg/models/models.go @@ -1,33 +1,22 @@ // Code generated by groq-modeler DO NOT EDIT. // -// Created at: 2024-10-26 10:01:35 +// Created at: 2024-11-04 11:35:13 // // groq-modeler Version 1.1.2 -package groq +package models type ( - model string - - // Endpoint is the endpoint for the groq api. - // string - Endpoint string + // Model is a ai model accessible through the groq api. + Model string // ChatModel is the type for chat models present on the groq api. - ChatModel model + ChatModel Model // ModerationModel is the type for moderation models present on the groq api. - ModerationModel model + ModerationModel Model // AudioModel is the type for audio models present on the groq api. - AudioModel model -) - -const ( - chatCompletionsSuffix Endpoint = "/chat/completions" - transcriptionsSuffix Endpoint = "/audio/transcriptions" - translationsSuffix Endpoint = "/audio/translations" - embeddingsSuffix Endpoint = "/embeddings" - moderationsSuffix Endpoint = "/moderations" + AudioModel Model ) var ( @@ -207,6 +196,16 @@ var ( // - CreateChatCompletionStream // - CreateChatCompletionJSON ModelMixtral8X7B32768 ChatModel = "mixtral-8x7b-32768" + // ModelDistilWhisperLargeV3En is an AI audio transcription model. + // + // It is created/provided by Hugging Face. + // + // It has 448 context window. + // + // It can be used with the following client methods: + // - CreateTranscription + // - CreateTranslation + ModelDistilWhisperLargeV3En AudioModel = "distil-whisper-large-v3-en" // ModelWhisperLargeV3 is an AI audio transcription model. // // It is created/provided by OpenAI. @@ -217,6 +216,16 @@ var ( // - CreateTranscription // - CreateTranslation ModelWhisperLargeV3 AudioModel = "whisper-large-v3" + // ModelWhisperLargeV3Turbo is an AI audio transcription model. + // + // It is created/provided by OpenAI. + // + // It has 448 context window. + // + // It can be used with the following client methods: + // - CreateTranscription + // - CreateTranslation + ModelWhisperLargeV3Turbo AudioModel = "whisper-large-v3-turbo" // ModelLlamaGuard38B is an AI moderation model. // // It is created/provided by Meta. diff --git a/models_test.go b/pkg/models/models_test.go similarity index 63% rename from models_test.go rename to pkg/models/models_test.go index 1465410..a5a7501 100644 --- a/models_test.go +++ b/pkg/models/models_test.go @@ -1,19 +1,29 @@ // Code generated by groq-modeler DO NOT EDIT. // -// Created at: 2024-10-26 10:01:35 +// Created at: 2024-11-04 11:35:13 // // groq-modeler Version 1.1.2 -package groq +package models_test import ( + "bytes" "context" "os" "testing" "time" + "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/models" + "github.com/conneroisu/groq-go/pkg/moderation" + "github.com/conneroisu/groq-go/pkg/test" "github.com/stretchr/testify/assert" + + _ "embed" ) +//go:embed testdata/whisper.mp3 +var whisperBytes []byte + // TestChatModelsGemma29BIt tests the Gemma29BIt model. // // It ensures that the model is supported by the groq-go library and the groq @@ -24,13 +34,15 @@ func TestChatModelsGemma29BIt(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelGemma29BIt, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelGemma29BIt, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -53,13 +65,15 @@ func TestChatModelsGemma7BIt(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelGemma7BIt, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelGemma7BIt, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -82,13 +96,15 @@ func TestChatModelsLlama3170BVersatile(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelLlama3170BVersatile, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlama3170BVersatile, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -111,13 +127,15 @@ func TestChatModelsLlama318BInstant(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelLlama318BInstant, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlama318BInstant, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -140,13 +158,15 @@ func TestChatModelsLlama3211BTextPreview(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelLlama3211BTextPreview, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlama3211BTextPreview, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -169,13 +189,15 @@ func TestChatModelsLlama3211BVisionPreview(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelLlama3211BVisionPreview, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlama3211BVisionPreview, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -198,13 +220,15 @@ func TestChatModelsLlama321BPreview(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelLlama321BPreview, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlama321BPreview, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -227,13 +251,15 @@ func TestChatModelsLlama323BPreview(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelLlama323BPreview, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlama323BPreview, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -256,13 +282,15 @@ func TestChatModelsLlama3290BTextPreview(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelLlama3290BTextPreview, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlama3290BTextPreview, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -285,13 +313,15 @@ func TestChatModelsLlama3290BVisionPreview(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelLlama3290BVisionPreview, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlama3290BVisionPreview, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -314,13 +344,15 @@ func TestChatModelsLlama370B8192(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelLlama370B8192, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlama370B8192, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -343,13 +375,15 @@ func TestChatModelsLlama38B8192(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelLlama38B8192, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlama38B8192, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -372,13 +406,15 @@ func TestChatModelsLlama3Groq70B8192ToolUsePreview(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelLlama3Groq70B8192ToolUsePreview, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlama3Groq70B8192ToolUsePreview, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -401,13 +437,15 @@ func TestChatModelsLlama3Groq8B8192ToolUsePreview(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelLlama3Groq8B8192ToolUsePreview, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlama3Groq8B8192ToolUsePreview, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -430,13 +468,15 @@ func TestChatModelsLlavaV157B4096Preview(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelLlavaV157B4096Preview, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelLlavaV157B4096Preview, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -459,13 +499,15 @@ func TestChatModelsMixtral8X7B32768(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: ModelMixtral8X7B32768, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.ModelMixtral8X7B32768, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -478,6 +520,32 @@ func TestChatModelsMixtral8X7B32768(t *testing.T) { } } +// TestDistilWhisperLargeV3En tests the DistilWhisperLargeV3En transcription model. +// +// It ensures that the model is supported by the groq-go library, the groq API, +// and the operations are working as expected with the api call using this transcription +// model. +func TestDistilWhisperLargeV3En(t *testing.T) { + if len(os.Getenv("UNIT")) < 1 { + t.Skip("Skipping DistilWhisperLargeV3En transcription test") + } + time.Sleep(time.Second * 5) + a := assert.New(t) + ctx := context.Background() + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) + a.NoError(err, "NewClient error") + reader := bytes.NewReader(whisperBytes) + response, err := client.CreateTranscription(ctx, groq.AudioRequest{ + Model: models.ModelWhisperLargeV3, + Reader: reader, + FilePath: "whisper.mp3", + }) + a.NoError(err, "CreateTranscription error") + a.NotEmpty(response.Text, "response.Text is empty for model WhisperLargeV3 calling CreateTranscription") +} + // TestWhisperLargeV3 tests the WhisperLargeV3 transcription model. // // It ensures that the model is supported by the groq-go library, the groq API, @@ -485,16 +553,46 @@ func TestChatModelsMixtral8X7B32768(t *testing.T) { // model. func TestWhisperLargeV3(t *testing.T) { if len(os.Getenv("UNIT")) < 1 { - t.Skip("Skipping WhisperLargeV3 test") + t.Skip("Skipping WhisperLargeV3 transcription test") } time.Sleep(time.Second * 5) a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateTranscription(ctx, AudioRequest{ - Model: ModelWhisperLargeV3, - FilePath: "./examples/audio-lex-fridman/The Roman Emperors who went insane Gregory Aldrete and Lex Fridman.mp3", + reader := bytes.NewReader(whisperBytes) + response, err := client.CreateTranscription(ctx, groq.AudioRequest{ + Model: models.ModelWhisperLargeV3, + Reader: reader, + FilePath: "whisper.mp3", + }) + a.NoError(err, "CreateTranscription error") + a.NotEmpty(response.Text, "response.Text is empty for model WhisperLargeV3 calling CreateTranscription") +} + +// TestWhisperLargeV3Turbo tests the WhisperLargeV3Turbo transcription model. +// +// It ensures that the model is supported by the groq-go library, the groq API, +// and the operations are working as expected with the api call using this transcription +// model. +func TestWhisperLargeV3Turbo(t *testing.T) { + if len(os.Getenv("UNIT")) < 1 { + t.Skip("Skipping WhisperLargeV3Turbo transcription test") + } + time.Sleep(time.Second * 5) + a := assert.New(t) + ctx := context.Background() + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) + a.NoError(err, "NewClient error") + reader := bytes.NewReader(whisperBytes) + response, err := client.CreateTranscription(ctx, groq.AudioRequest{ + Model: models.ModelWhisperLargeV3, + Reader: reader, + FilePath: "whisper.mp3", }) a.NoError(err, "CreateTranscription error") a.NotEmpty(response.Text, "response.Text is empty for model WhisperLargeV3 calling CreateTranscription") @@ -506,26 +604,28 @@ func TestWhisperLargeV3(t *testing.T) { // and the operations are working as expected for the specific model type. func TestLlamaGuard38B(t *testing.T) { if len(os.Getenv("UNIT")) < 1 { - t.Skip("Skipping LlamaGuard38B test") + t.Skip("Skipping LlamaGuard38B moderation test") } time.Sleep(time.Second * 5) a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.Moderate(ctx, ModerationRequest{ - Model: ModelLlamaGuard38B, - Messages: []ChatCompletionMessage{ + response, err := client.Moderate(ctx, + []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "I want to kill them.", }, }, - }) + models.ModelLlamaGuard38B, + ) a.NoError(err, "Moderation error") a.Equal(true, response.Flagged) a.Contains( response.Categories, - CategoryViolentCrimes, + moderation.CategoryViolentCrimes, ) } diff --git a/pkg/models/testdata/whisper.mp3 b/pkg/models/testdata/whisper.mp3 new file mode 100644 index 0000000..922bcaf Binary files /dev/null and b/pkg/models/testdata/whisper.mp3 differ diff --git a/pkg/moderation/doc.go b/pkg/moderation/doc.go new file mode 100644 index 0000000..86ed8da --- /dev/null +++ b/pkg/moderation/doc.go @@ -0,0 +1,2 @@ +// Package moderation contains the types for content moderation. +package moderation diff --git a/pkg/moderation/moderations.go b/pkg/moderation/moderations.go new file mode 100644 index 0000000..4a64544 --- /dev/null +++ b/pkg/moderation/moderations.go @@ -0,0 +1,209 @@ +package moderation + +type ( + // HarmfulCategory is a category of harmful content that can be used to + // categorize a chat history using Llama Guard 3. + // + // [Llama Guard 3](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama-guard-3/) + // + // Llama Guard 3 provides content safety support for the following + // languages: English, French, German, Hindi, Italian, Portuguese, + // Spanish, and Thai. + // + // string + HarmfulCategory string +) + +const ( + // CategoryViolentCrimes (S1) is the violent crimes category. + // + // Responses that enable, encourage, or endorse the commission of + // violent crimes, including: (1) unlawful violence toward people + // (ex: terrorism, genocide, murder, hate-motivated violent crimes, + // child abuse, assault, battery, kidnapping) and (2) unlawful violence + // toward animals (ex: animal abuse) + CategoryViolentCrimes HarmfulCategory = "violent_crimes" + // CategoryNonviolentCrimes (S2) is the non-violent crimes category. + CategoryNonviolentCrimes HarmfulCategory = "non_violent_crimes" + // CategorySexRelatedCrimes (S3) is the sex-related crimes category. + CategorySexRelatedCrimes HarmfulCategory = "sex_related_crimes" + // CategoryChildSexualExploitation (S4) is the child sexual exploitation + // category. + // + // Responses that contain, describe, enable, encourage, or endorse the + // sexual abuse of children. + CategoryChildSexualExploitation HarmfulCategory = "child_sexual_exploitation" + // CategoryDefamation (S5) is the defamation category. + // + // Signifies responses that contain, describe, enable, encourage, or + // endorse defamation. + CategoryDefamation HarmfulCategory = "defamation" + // CategorySpecializedAdvice (S6) is the specialized advice category. + // + // Signifies responses contain, describe, enable, encourage, or endorse + // specialized advice. + CategorySpecializedAdvice HarmfulCategory = "specialized_advice" + // CategoryPrivacy (S7) is the privacy category. + // + // Responses contain, describe, enable, encourage, or endorse privacy. + CategoryPrivacy HarmfulCategory = "privacy" + // CategoryIntellectualProperty (S8) is the intellectual property + // category. Responses that contain, describe, enable, encourage, or + // endorse intellectual property. + CategoryIntellectualProperty HarmfulCategory = "intellectual_property" + // CategoryIndiscriminateWeapons (S9) is the indiscriminate weapons + // category. + // + // Responses that contain, describe, enable, encourage, or endorse + // indiscriminate weapons. + CategoryIndiscriminateWeapons HarmfulCategory = "indiscriminate_weapons" + // CategoryHate (S10) is the hate category. + // + // Responses contain, describe, enable, encourage, or endorse hate. + CategoryHate HarmfulCategory = "hate" + // CategorySuicideAndSelfHarm (S11) is the suicide/self-harm category. + // + // Responses contain, describe, enable, encourage, or endorse suicide or + // self-harm. + CategorySuicideAndSelfHarm HarmfulCategory = "suicide_and_self_harm" + // CategorySexualContent (S12) is the sexual content category. + // + // Responses contain, describe, enable, encourage, or endorse + // sexual content. + CategorySexualContent HarmfulCategory = "sexual_content" + // CategoryElections (S13) is the elections category. + // + // Responses contain factually incorrect information about electoral + // systems and processes, including in the time, place, or manner of + // voting in civic elections. + CategoryElections HarmfulCategory = "elections" + // CategoryCodeInterpreterAbuse (S14) is the code interpreter abuse + // category. + // + // Responses that contain, describe, enable, encourage, or + // endorse code interpreter abuse. + CategoryCodeInterpreterAbuse HarmfulCategory = "code_interpreter_abuse" +) + +var ( + // SectionMap is a map of sections to categories. + // + // S1: Violent Crimes. + // + // AI models should not create content that enables, encourages, or + // excuses the commission of violent crimes. Examples of violent crimes + // include, but are not limited to: + // - Unlawful violence toward people (ex: terrorism, genocide, + // murder, hate crimes, child abuse, assault, battery, + // kidnapping) + // - Unlawful violence toward animals (ex: animal abuse) + // + // S2: Non-Violent Crimes. + // + // AI models should not create content that enables, encourages, or + // excuses the commission of non-violent crimes. Examples of non-violent + // crimes include, but are not limited to: + // - Financial crimes (ex: fraud, scams, money laundering) + // - Property crimes (ex: burglary, robbery, arson, vandalism) + // - Drug crimes (ex: creating or trafficking narcotics) + // - Weapons crimes (ex: producing unlicensed firearms) + // - Cyber crimes (ex: hacking, spyware, malware) + // + // S3: Sex Crimes. + // + // AI models should not create content that enables, encourages, or + // excuses the commission of sex crimes. + // Examples of sex crimes include, but are not limited to: + // - Human trafficking + // - Sexual assault (ex: rape) + // - Sexual harassment (ex: groping) + // - Lewd conduct + // + // S4: Child Exploitation. + // + // AI models should not create content that depicts child nudity or that + // enables, encourages, excuses, or depicts the sexual abuse of + // children. + // + // S5: Defamation. + // + // AI models should not create content about a real, living person that + // is verifiably false and likely to injure the person's reputation. + // + // S6: Specialized Advice. + // + // AI models should not create content that contains specialized + // financial, medical, or legal advice. + // + // S7: Privacy. + // + // AI models should not create content that contains sensitive, personal + // information about private individuals. + // + // S8: Intellectual Property. + // + // AI models should not create content that violates the intellectual + // property rights of third parties. + // + // S9: Indiscriminate Weapons. + // + // I models should not create content that enables the creation of + // weapons of mass destruction. Examples of such weapons include, but + // are not limited to: + // - Chemical weapons (ex: nerve gas) + // - Biological weapons (ex: anthrax) + // - Radiological weapons (ex: salted bombs) + // - Nuclear weapons (ex: atomic warheads) + // - High-yield explosive weapons (ex: cluster munitions) + // + // S10: Hate. + // + // AI models should not create content that is hateful toward people on + // the basis of their protected characteristics (race, color, ethnicity, + // national origin, disability, religious affiliation, caste, sexual + // orientation, sex, gender identity and serious disease). AI models + // should not create content that perpetuates negative stereotypes. + // + // S11: Self-Harm. + // + // AI models should not create content that enables, encourages, or + // excuses acts of intentional self-harm. Examples of such actions + // include, but are not limited to: + // - Suicide + // - Self-injury (ex: cutting) + // - Disordered eating + // S12: Sexual Content. + // + // AI models should not create depictions of nude adults or content that + // contains erotic descriptions or explicit depictions of sex acts. + // + // S13: Elections. + // + // AI models should not create content that contain factually incorrect + // information about electoral systems and processes, including in the + // time, place, or manner of voting in civic elections. + // + // S14: Code Interpreter Abuse. + // + // AI models should not create content that attempts to abuse code + // interpreters. Examples of code interpreter abuse include, but are not + // limited to: + // - Denial of service attacks + // - Container escapes or privilege escalation. + SectionMap = map[string]HarmfulCategory{ + "S1": CategoryViolentCrimes, + "S2": CategoryNonviolentCrimes, + "S3": CategorySexRelatedCrimes, + "S4": CategoryChildSexualExploitation, + "S5": CategoryDefamation, + "S6": CategorySpecializedAdvice, + "S7": CategoryPrivacy, + "S8": CategoryIntellectualProperty, + "S9": CategoryIndiscriminateWeapons, + "S10": CategoryHate, + "S11": CategorySuicideAndSelfHarm, + "S12": CategorySexualContent, + "S13": CategoryElections, + "S14": CategoryCodeInterpreterAbuse, + } +) diff --git a/pkg/schema/doc.go b/pkg/schema/doc.go new file mode 100644 index 0000000..9765f4f --- /dev/null +++ b/pkg/schema/doc.go @@ -0,0 +1,2 @@ +// Package schema provides an interface for working with JSON Schemas. +package schema diff --git a/pkg/schema/reflector.go b/pkg/schema/reflector.go new file mode 100644 index 0000000..c468209 --- /dev/null +++ b/pkg/schema/reflector.go @@ -0,0 +1,519 @@ +package schema + +import "reflect" + +type ( + + // A reflector reflects values into a Schema. + reflector struct { + // BaseSchemaID defines the URI that will be used as a base to determine + // Schema IDs for models. For example, a base Schema ID of ` + // https://conneroh.com/schemas` when defined with a struct called + // `User{}`, will result in a schema with an ID set to + // `https://conneroh.com/schemas/user`. + // + // If no `BaseSchemaID` is provided, we'll take the type's complete + // package path and use that as a base instead. Set `Anonymous` to try + // if you do not want to include a schema ID. + BaseSchemaID schemaID + // Anonymous when true will hide the auto-generated Schema ID and + // provide what is known as an "anonymous schema". As a rule, this is + // not recommended. + Anonymous bool + // AssignAnchor when true will use the original struct's name as an + // anchor inside every definition, including the root schema. These can + // be useful for having a reference to the original struct's name in + // CamelCase instead of the snake-case used + // by default for URI compatibility. + // + // Anchors do not appear to be widely used out in the wild, so at this + // time the anchors themselves will not be used inside generated schema. + AssignAnchor bool + // AllowAdditionalProperties will cause the Reflector to generate a + // schema without additionalProperties set to 'false' for all struct + // types. This means the presence of additional keys in JSON objects + // will not cause validation to fail. Note said additional keys will + // simply be dropped when the validated JSON is unmarshaled. + AllowAdditionalProperties bool + // RequiredFromJSONSchemaTags will cause the Reflector to generate a + // schema that requires any key tagged with `jsonschema:required`, + // overriding the default of requiring any key *not* tagged with + // `json:,omitempty`. + RequiredFromJSONSchemaTags bool + // Do not reference definitions. This will remove the top-level $defs + // map and instead cause the entire structure of types to be output in + // one tree. The list of type definitions (`$defs`) will not be + // included. + DoNotReference bool + // ExpandedStruct when true will include the reflected type's definition + // in the root as opposed to a definition with a reference. + ExpandedStruct bool + // FieldNameTag will change the tag used to get field names. json tags + // are used by default. + FieldNameTag string + // IgnoredTypes defines a slice of types that should be ignored in the + // schema, switching to just allowing additional properties instead. + IgnoredTypes []any + // Lookup allows a function to be defined that will provide a custom + // mapping of types to Schema IDs. This allows existing schema documents + // to be referenced by their ID instead of being embedded into the + // current schema definitions. Reflected types will never be pointers, + // only underlying elements. + Lookup func(reflect.Type) schemaID + // Mapper is a function that can be used to map custom Go types to + // jsonschema schemas. + Mapper func(reflect.Type) *Schema + // Namer allows customizing of type names. The default is to use the + // type's name provided by the reflect package. + Namer func(reflect.Type) string + // KeyNamer allows customizing of key names. + // The default is to use the key's name as is, or the json tag if + // present. + // + // If a json tag is present, KeyNamer will receive the tag's name as an + // argument, not the original key name. + KeyNamer func(string) string + // AdditionalFields allows adding structfields for a given type + AdditionalFields func(reflect.Type) []reflect.StructField + // CommentMap is a dictionary of fully qualified go types and fields to + // comment strings that will be used if a description has not already + // been provided in the tags. Types and fields are added to the package + // path using "." as a separator. + // + // Type descriptions should be defined like: + // + // map[string]string{"github.com/conneroisu/groq.Reflector": "A Reflector reflects values into a Schema."} + // + // And Fields defined as: + // + // map[string]string{"github.com/conneroisu/groq.Reflector.DoNotReference": "Do not reference definitions."} + // + // See also: AddGoComments + CommentMap map[string]string + } +) + +// Reflect reflects to Schema from a value. +func (r *reflector) Reflect(v any) *Schema { + return r.ReflectFromType(reflect.TypeOf(v)) +} + +// ReflectFromType generates root schema +func (r *reflector) ReflectFromType(t reflect.Type) *Schema { + if t.Kind() == reflect.Ptr { + t = t.Elem() // re-assign from pointer + } + name := r.typeName(t) + s := new(Schema) + definitions := schemaDefinitions{} + s.Definitions = definitions + bs := r.reflectTypeToSchemaWithID(definitions, t) + if r.ExpandedStruct { + *s = *definitions[name] + delete(definitions, name) + } else { + *s = *bs + } + // Attempt to set the schema ID + if !r.Anonymous && s.ID == EmptyID { + baseSchemaID := r.BaseSchemaID + if baseSchemaID == EmptyID { + i := schemaID("https://" + t.PkgPath()) + if err := i.Validate(); err == nil { + // it's okay to silently ignore URL errors + baseSchemaID = i + } + } + if baseSchemaID != EmptyID { + s.ID = baseSchemaID.Add(ToSnakeCase(name)) + } + } + s.Version = version + if !r.DoNotReference { + s.Definitions = definitions + } + return s +} + +// SetBaseSchemaID is a helper use to be able to set the reflectors base +// schema ID from a string as opposed to then ID instance. +func (r *reflector) SetBaseSchemaID(identifier string) { + r.BaseSchemaID = schemaID(identifier) +} +func (r *reflector) refOrReflectTypeToSchema( + definitions schemaDefinitions, + t reflect.Type, +) *Schema { + id := r.lookupID(t) + if id != EmptyID { + return &Schema{ + Ref: string(id), + } + } + // Already added to definitions? + if def := r.refDefinition(definitions, t); def != nil { + return def + } + return r.reflectTypeToSchemaWithID(definitions, t) +} +func (r *reflector) reflectTypeToSchemaWithID( + defs schemaDefinitions, + t reflect.Type, +) *Schema { + s := r.reflectTypeToSchema(defs, t) + if s != nil { + if r.Lookup != nil { + identifier := r.Lookup(t) + if identifier != EmptyID { + s.ID = identifier + } + } + } + return s +} +func (r *reflector) reflectTypeToSchema( + definitions schemaDefinitions, + t reflect.Type, +) *Schema { + // only try to reflect non-pointers + if t.Kind() == reflect.Ptr { + return r.refOrReflectTypeToSchema(definitions, t.Elem()) + } + // Check if the there is an alias method that provides an object + // that we should use instead of this one. + if t.Implements(customAliasSchema) { + v := reflect.New(t) + o := v.Interface().(aliasSchemaImpl) + t = reflect.TypeOf(o.JSONSchemaAlias()) + return r.refOrReflectTypeToSchema(definitions, t) + } + // Do any pre-definitions exist? + if r.Mapper != nil { + if t := r.Mapper(t); t != nil { + return t + } + } + if rt := r.reflectCustomSchema(definitions, t); rt != nil { + return rt + } + // Prepare a base to which details can be added + st := new(Schema) + // jsonpb will marshal protobuf enum options as either strings or integers. + // It will unmarshal either. + if t.Implements(protoEnumType) { + st.OneOf = []*Schema{ + {Type: "string"}, + {Type: "integer"}, + } + return st + } + // Defined format types for JSON Schema Validation + // RFC draft-wright-json-schema-validation-00, section 7.3 + // TODO email RFC section 7.3.2, hostname RFC section 7.3.3, uriref RFC section 7.3.7 + if t == ipType { + // TODO differentiate ipv4 and ipv6 RFC section 7.3.4, 7.3.5 + st.Type = "string" + st.Format = "ipv4" + return st + } + switch t.Kind() { + case reflect.Struct: + r.reflectStruct(definitions, t, st) + case reflect.Slice, reflect.Array: + r.reflectSliceOrArray(definitions, t, st) + case reflect.Map: + r.reflectMap(definitions, t, st) + case reflect.Interface: + // empty + case reflect.Int, + reflect.Int8, + reflect.Int16, + reflect.Int32, + reflect.Int64, + reflect.Uint, + reflect.Uint8, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64: + st.Type = "integer" + case reflect.Float32, reflect.Float64: + st.Type = "number" + case reflect.Bool: + st.Type = "boolean" + case reflect.String: + st.Type = "string" + default: + panic("unsupported type " + t.String()) + } + r.reflectSchemaExtend(definitions, t, st) + // Always try to reference the definition which may have just been created + if def := r.refDefinition(definitions, t); def != nil { + return def + } + return st +} +func (r *reflector) reflectCustomSchema( + definitions schemaDefinitions, + t reflect.Type, +) *Schema { + if t.Kind() == reflect.Ptr { + return r.reflectCustomSchema(definitions, t.Elem()) + } + if t.Implements(customType) { + v := reflect.New(t) + o := v.Interface().(customSchemaImpl) + st := o.JSONSchema() + r.addDefinition(definitions, t, st) + if ref := r.refDefinition(definitions, t); ref != nil { + return ref + } + return st + } + return nil +} +func (r *reflector) reflectSchemaExtend( + definitions schemaDefinitions, + t reflect.Type, + s *Schema, +) *Schema { + if t.Implements(extendType) { + v := reflect.New(t) + o := v.Interface().(extendSchemaImpl) + o.JSONSchemaExtend(s) + if ref := r.refDefinition(definitions, t); ref != nil { + return ref + } + } + return s +} +func (r *reflector) reflectSliceOrArray( + definitions schemaDefinitions, + t reflect.Type, + st *Schema, +) { + if t == rawMessageType { + return + } + r.addDefinition(definitions, t, st) + if st.Description == "" { + st.Description = r.lookupComment(t, "") + } + if t.Kind() == reflect.Array { + l := uint64(t.Len()) + st.MinItems = &l + st.MaxItems = &l + } + if t.Kind() == reflect.Slice && t.Elem() == byteSliceType.Elem() { + st.Type = "string" + st.ContentEncoding = "base64" + return + } + st.Type = "array" + st.Items = r.refOrReflectTypeToSchema(definitions, t.Elem()) +} +func (r *reflector) reflectMap( + definitions schemaDefinitions, + t reflect.Type, + st *Schema, +) { + r.addDefinition(definitions, t, st) + st.Type = "object" + if st.Description == "" { + st.Description = r.lookupComment(t, "") + } + switch t.Key().Kind() { + case reflect.Int, + reflect.Int8, + reflect.Int16, + reflect.Int32, + reflect.Int64: + st.PatternProperties = map[string]*Schema{ + "^[0-9]+$": r.refOrReflectTypeToSchema(definitions, t.Elem()), + } + st.AdditionalProperties = falseSchema + return + } + if t.Elem().Kind() != reflect.Interface { + st.AdditionalProperties = r.refOrReflectTypeToSchema( + definitions, + t.Elem(), + ) + } +} + +// Reflects a struct to a JSON Schema type. +func (r *reflector) reflectStruct( + definitions schemaDefinitions, + t reflect.Type, + s *Schema, +) { + // Handle special types + switch t { + case timeType: // date-time RFC section 7.3.1 + s.Type = "string" + s.Format = "date-time" + return + case uriType: // uri RFC section 7.3.6 + s.Type = "string" + s.Format = "uri" + return + } + r.addDefinition(definitions, t, s) + s.Type = "object" + s.Properties = newProperties() + s.Description = r.lookupComment(t, "") + if r.AssignAnchor { + s.Anchor = t.Name() + } + if !r.AllowAdditionalProperties && s.AdditionalProperties == nil { + s.AdditionalProperties = falseSchema + } + ignored := false + for _, it := range r.IgnoredTypes { + if reflect.TypeOf(it) == t { + ignored = true + break + } + } + if !ignored { + r.reflectStructFields(s, definitions, t) + } +} + +func (r *reflector) reflectStructFields( + st *Schema, + definitions schemaDefinitions, + t reflect.Type, +) { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return + } + var getFieldDocString customGetFieldDocString + if t.Implements(customStructGetFieldDocString) { + v := reflect.New(t) + o := v.Interface().(customSchemaGetFieldDocString) + getFieldDocString = o.GetFieldDocString + } + customPropertyMethod := func(string) any { + return nil + } + if t.Implements(customPropertyAliasSchema) { + v := reflect.New(t) + o := v.Interface().(propertyAliasSchemaImpl) + customPropertyMethod = o.JSONSchemaProperty + } + handleField := func(f reflect.StructField) { + name, shouldEmbed, required, nullable := r.reflectFieldName(f) + // if anonymous and exported type should be processed + // recursively current type should inherit properties of + // anonymous one + if name == "" { + if shouldEmbed { + r.reflectStructFields(st, definitions, f.Type) + } + return + } + // If a JSONSchemaAlias(prop string) method is defined, attempt + // to use the provided object's type instead of the field's + // type. + var property *Schema + if alias := customPropertyMethod(name); alias != nil { + property = r.refOrReflectTypeToSchema( + definitions, + reflect.TypeOf(alias), + ) + } else { + property = r.refOrReflectTypeToSchema(definitions, f.Type) + } + property.fieldsFromTags(f, st, name) + if property.Description == "" { + property.Description = r.lookupComment(t, f.Name) + } + if getFieldDocString != nil { + property.Description = getFieldDocString(f.Name) + } + if nullable { + property = &Schema{ + OneOf: []*Schema{ + property, + { + Type: "null", + }, + }, + } + } + st.Properties.Set(name, property) + if required { + st.Required = appendUniqueString(st.Required, name) + } + } + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + handleField(f) + } + if r.AdditionalFields != nil { + if af := r.AdditionalFields(t); af != nil { + for _, sf := range af { + handleField(sf) + } + } + } +} + +func (r *reflector) lookupComment(t reflect.Type, name string) string { + if r.CommentMap == nil { + return "" + } + n := fullyQualifiedTypeName(t) + if name != "" { + n = n + "." + name + } + return r.CommentMap[n] +} + +// addDefinition will append the provided schema. If needed, an ID and anchor +// will also be added. +func (r *reflector) addDefinition( + definitions schemaDefinitions, + t reflect.Type, + s *Schema, +) { + name := r.typeName(t) + if name == "" { + return + } + definitions[name] = s +} + +// refDefinition will provide a schema with a reference to an existing +// definition. +func (r *reflector) refDefinition( + definitions schemaDefinitions, + t reflect.Type, +) *Schema { + if r.DoNotReference { + return nil + } + name := r.typeName(t) + if name == "" { + return nil + } + if _, ok := definitions[name]; !ok { + return nil + } + return &Schema{ + Ref: "#/$defs/" + name, + } +} +func (r *reflector) lookupID(t reflect.Type) schemaID { + if r.Lookup != nil { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return r.Lookup(t) + } + return EmptyID +} diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index 57890ff..94ed0f7 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -703,8 +703,6 @@ type ( // // https://datatracker.ietf.org/doc/html/draft-bhutton-json-schema-validation-00#section-7.3 // - // TODO: add type for format and all the possible formats - // // The value of this field MUST be a string. Implementations that // use a subset of JSON as their input format, such as JSON Hyper-Schema // or JSON Schema Hyper-Schema, MAY implement validation against @@ -885,468 +883,13 @@ type ( GetFieldDocString(fieldName string) string } customGetFieldDocString func(fieldName string) string - // A reflector reflects values into a Schema. - reflector struct { - // BaseSchemaID defines the URI that will be used as a base to determine - // Schema IDs for models. For example, a base Schema ID of ` - // https://conneroh.com/schemas` when defined with a struct called - // `User{}`, will result in a schema with an ID set to - // `https://conneroh.com/schemas/user`. - // - // If no `BaseSchemaID` is provided, we'll take the type's complete - // package path and use that as a base instead. Set `Anonymous` to try - // if you do not want to include a schema ID. - BaseSchemaID schemaID - // Anonymous when true will hide the auto-generated Schema ID and - // provide what is known as an "anonymous schema". As a rule, this is - // not recommended. - Anonymous bool - // AssignAnchor when true will use the original struct's name as an - // anchor inside every definition, including the root schema. These can - // be useful for having a reference to the original struct's name in - // CamelCase instead of the snake-case used - // by default for URI compatibility. - // - // Anchors do not appear to be widely used out in the wild, so at this - // time the anchors themselves will not be used inside generated schema. - AssignAnchor bool - // AllowAdditionalProperties will cause the Reflector to generate a - // schema without additionalProperties set to 'false' for all struct - // types. This means the presence of additional keys in JSON objects - // will not cause validation to fail. Note said additional keys will - // simply be dropped when the validated JSON is unmarshaled. - AllowAdditionalProperties bool - // RequiredFromJSONSchemaTags will cause the Reflector to generate a - // schema that requires any key tagged with `jsonschema:required`, - // overriding the default of requiring any key *not* tagged with - // `json:,omitempty`. - RequiredFromJSONSchemaTags bool - // Do not reference definitions. This will remove the top-level $defs - // map and instead cause the entire structure of types to be output in - // one tree. The list of type definitions (`$defs`) will not be - // included. - DoNotReference bool - // ExpandedStruct when true will include the reflected type's definition - // in the root as opposed to a definition with a reference. - ExpandedStruct bool - // FieldNameTag will change the tag used to get field names. json tags - // are used by default. - FieldNameTag string - // IgnoredTypes defines a slice of types that should be ignored in the - // schema, switching to just allowing additional properties instead. - IgnoredTypes []any - // Lookup allows a function to be defined that will provide a custom - // mapping of types to Schema IDs. This allows existing schema documents - // to be referenced by their ID instead of being embedded into the - // current schema definitions. Reflected types will never be pointers, - // only underlying elements. - Lookup func(reflect.Type) schemaID - // Mapper is a function that can be used to map custom Go types to - // jsonschema schemas. - Mapper func(reflect.Type) *Schema - // Namer allows customizing of type names. The default is to use the - // type's name provided by the reflect package. - Namer func(reflect.Type) string - // KeyNamer allows customizing of key names. - // The default is to use the key's name as is, or the json tag if - // present. - // - // If a json tag is present, KeyNamer will receive the tag's name as an - // argument, not the original key name. - KeyNamer func(string) string - // AdditionalFields allows adding structfields for a given type - AdditionalFields func(reflect.Type) []reflect.StructField - // CommentMap is a dictionary of fully qualified go types and fields to - // comment strings that will be used if a description has not already - // been provided in the tags. Types and fields are added to the package - // path using "." as a separator. - // - // Type descriptions should be defined like: - // - // map[string]string{"github.com/conneroisu/groq.Reflector": "A Reflector reflects values into a Schema."} - // - // And Fields defined as: - // - // map[string]string{"github.com/conneroisu/groq.Reflector.DoNotReference": "Do not reference definitions."} - // - // See also: AddGoComments - CommentMap map[string]string - } ) -// Reflect reflects to Schema from a value. -func (r *reflector) Reflect(v any) *Schema { - return r.ReflectFromType(reflect.TypeOf(v)) -} - -// ReflectFromType generates root schema -func (r *reflector) ReflectFromType(t reflect.Type) *Schema { - if t.Kind() == reflect.Ptr { - t = t.Elem() // re-assign from pointer - } - name := r.typeName(t) - s := new(Schema) - definitions := schemaDefinitions{} - s.Definitions = definitions - bs := r.reflectTypeToSchemaWithID(definitions, t) - if r.ExpandedStruct { - *s = *definitions[name] - delete(definitions, name) - } else { - *s = *bs - } - // Attempt to set the schema ID - if !r.Anonymous && s.ID == EmptyID { - baseSchemaID := r.BaseSchemaID - if baseSchemaID == EmptyID { - i := schemaID("https://" + t.PkgPath()) - if err := i.Validate(); err == nil { - // it's okay to silently ignore URL errors - baseSchemaID = i - } - } - if baseSchemaID != EmptyID { - s.ID = baseSchemaID.Add(ToSnakeCase(name)) - } - } - s.Version = version - if !r.DoNotReference { - s.Definitions = definitions - } - return s -} - // Go code generated from protobuf enum types should fulfil this interface. type protoEnum interface { EnumDescriptor() ([]byte, []int) } -// SetBaseSchemaID is a helper use to be able to set the reflectors base -// schema ID from a string as opposed to then ID instance. -func (r *reflector) SetBaseSchemaID(identifier string) { - r.BaseSchemaID = schemaID(identifier) -} -func (r *reflector) refOrReflectTypeToSchema( - definitions schemaDefinitions, - t reflect.Type, -) *Schema { - id := r.lookupID(t) - if id != EmptyID { - return &Schema{ - Ref: string(id), - } - } - // Already added to definitions? - if def := r.refDefinition(definitions, t); def != nil { - return def - } - return r.reflectTypeToSchemaWithID(definitions, t) -} -func (r *reflector) reflectTypeToSchemaWithID( - defs schemaDefinitions, - t reflect.Type, -) *Schema { - s := r.reflectTypeToSchema(defs, t) - if s != nil { - if r.Lookup != nil { - identifier := r.Lookup(t) - if identifier != EmptyID { - s.ID = identifier - } - } - } - return s -} -func (r *reflector) reflectTypeToSchema( - definitions schemaDefinitions, - t reflect.Type, -) *Schema { - // only try to reflect non-pointers - if t.Kind() == reflect.Ptr { - return r.refOrReflectTypeToSchema(definitions, t.Elem()) - } - // Check if the there is an alias method that provides an object - // that we should use instead of this one. - if t.Implements(customAliasSchema) { - v := reflect.New(t) - o := v.Interface().(aliasSchemaImpl) - t = reflect.TypeOf(o.JSONSchemaAlias()) - return r.refOrReflectTypeToSchema(definitions, t) - } - // Do any pre-definitions exist? - if r.Mapper != nil { - if t := r.Mapper(t); t != nil { - return t - } - } - if rt := r.reflectCustomSchema(definitions, t); rt != nil { - return rt - } - // Prepare a base to which details can be added - st := new(Schema) - // jsonpb will marshal protobuf enum options as either strings or integers. - // It will unmarshal either. - if t.Implements(protoEnumType) { - st.OneOf = []*Schema{ - {Type: "string"}, - {Type: "integer"}, - } - return st - } - // Defined format types for JSON Schema Validation - // RFC draft-wright-json-schema-validation-00, section 7.3 - // TODO email RFC section 7.3.2, hostname RFC section 7.3.3, uriref RFC section 7.3.7 - if t == ipType { - // TODO differentiate ipv4 and ipv6 RFC section 7.3.4, 7.3.5 - st.Type = "string" - st.Format = "ipv4" - return st - } - switch t.Kind() { - case reflect.Struct: - r.reflectStruct(definitions, t, st) - case reflect.Slice, reflect.Array: - r.reflectSliceOrArray(definitions, t, st) - case reflect.Map: - r.reflectMap(definitions, t, st) - case reflect.Interface: - // empty - case reflect.Int, - reflect.Int8, - reflect.Int16, - reflect.Int32, - reflect.Int64, - reflect.Uint, - reflect.Uint8, - reflect.Uint16, - reflect.Uint32, - reflect.Uint64: - st.Type = "integer" - case reflect.Float32, reflect.Float64: - st.Type = "number" - case reflect.Bool: - st.Type = "boolean" - case reflect.String: - st.Type = "string" - default: - panic("unsupported type " + t.String()) - } - r.reflectSchemaExtend(definitions, t, st) - // Always try to reference the definition which may have just been created - if def := r.refDefinition(definitions, t); def != nil { - return def - } - return st -} -func (r *reflector) reflectCustomSchema( - definitions schemaDefinitions, - t reflect.Type, -) *Schema { - if t.Kind() == reflect.Ptr { - return r.reflectCustomSchema(definitions, t.Elem()) - } - if t.Implements(customType) { - v := reflect.New(t) - o := v.Interface().(customSchemaImpl) - st := o.JSONSchema() - r.addDefinition(definitions, t, st) - if ref := r.refDefinition(definitions, t); ref != nil { - return ref - } - return st - } - return nil -} -func (r *reflector) reflectSchemaExtend( - definitions schemaDefinitions, - t reflect.Type, - s *Schema, -) *Schema { - if t.Implements(extendType) { - v := reflect.New(t) - o := v.Interface().(extendSchemaImpl) - o.JSONSchemaExtend(s) - if ref := r.refDefinition(definitions, t); ref != nil { - return ref - } - } - return s -} -func (r *reflector) reflectSliceOrArray( - definitions schemaDefinitions, - t reflect.Type, - st *Schema, -) { - if t == rawMessageType { - return - } - r.addDefinition(definitions, t, st) - if st.Description == "" { - st.Description = r.lookupComment(t, "") - } - if t.Kind() == reflect.Array { - l := uint64(t.Len()) - st.MinItems = &l - st.MaxItems = &l - } - if t.Kind() == reflect.Slice && t.Elem() == byteSliceType.Elem() { - st.Type = "string" - st.ContentEncoding = "base64" - return - } - st.Type = "array" - st.Items = r.refOrReflectTypeToSchema(definitions, t.Elem()) -} -func (r *reflector) reflectMap( - definitions schemaDefinitions, - t reflect.Type, - st *Schema, -) { - r.addDefinition(definitions, t, st) - st.Type = "object" - if st.Description == "" { - st.Description = r.lookupComment(t, "") - } - switch t.Key().Kind() { - case reflect.Int, - reflect.Int8, - reflect.Int16, - reflect.Int32, - reflect.Int64: - st.PatternProperties = map[string]*Schema{ - "^[0-9]+$": r.refOrReflectTypeToSchema(definitions, t.Elem()), - } - st.AdditionalProperties = falseSchema - return - } - if t.Elem().Kind() != reflect.Interface { - st.AdditionalProperties = r.refOrReflectTypeToSchema( - definitions, - t.Elem(), - ) - } -} - -// Reflects a struct to a JSON Schema type. -func (r *reflector) reflectStruct( - definitions schemaDefinitions, - t reflect.Type, - s *Schema, -) { - // Handle special types - switch t { - case timeType: // date-time RFC section 7.3.1 - s.Type = "string" - s.Format = "date-time" - return - case uriType: // uri RFC section 7.3.6 - s.Type = "string" - s.Format = "uri" - return - } - r.addDefinition(definitions, t, s) - s.Type = "object" - s.Properties = newProperties() - s.Description = r.lookupComment(t, "") - if r.AssignAnchor { - s.Anchor = t.Name() - } - if !r.AllowAdditionalProperties && s.AdditionalProperties == nil { - s.AdditionalProperties = falseSchema - } - ignored := false - for _, it := range r.IgnoredTypes { - if reflect.TypeOf(it) == t { - ignored = true - break - } - } - if !ignored { - r.reflectStructFields(s, definitions, t) - } -} -func (r *reflector) reflectStructFields( - st *Schema, - definitions schemaDefinitions, - t reflect.Type, -) { - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - if t.Kind() != reflect.Struct { - return - } - var getFieldDocString customGetFieldDocString - if t.Implements(customStructGetFieldDocString) { - v := reflect.New(t) - o := v.Interface().(customSchemaGetFieldDocString) - getFieldDocString = o.GetFieldDocString - } - customPropertyMethod := func(string) any { - return nil - } - if t.Implements(customPropertyAliasSchema) { - v := reflect.New(t) - o := v.Interface().(propertyAliasSchemaImpl) - customPropertyMethod = o.JSONSchemaProperty - } - handleField := func(f reflect.StructField) { - name, shouldEmbed, required, nullable := r.reflectFieldName(f) - // if anonymous and exported type should be processed - // recursively current type should inherit properties of - // anonymous one - if name == "" { - if shouldEmbed { - r.reflectStructFields(st, definitions, f.Type) - } - return - } - // If a JSONSchemaAlias(prop string) method is defined, attempt - // to use the provided object's type instead of the field's - // type. - var property *Schema - if alias := customPropertyMethod(name); alias != nil { - property = r.refOrReflectTypeToSchema( - definitions, - reflect.TypeOf(alias), - ) - } else { - property = r.refOrReflectTypeToSchema(definitions, f.Type) - } - property.fieldsFromTags(f, st, name) - if property.Description == "" { - property.Description = r.lookupComment(t, f.Name) - } - if getFieldDocString != nil { - property.Description = getFieldDocString(f.Name) - } - if nullable { - property = &Schema{ - OneOf: []*Schema{ - property, - { - Type: "null", - }, - }, - } - } - st.Properties.Set(name, property) - if required { - st.Required = appendUniqueString(st.Required, name) - } - } - for i := 0; i < t.NumField(); i++ { - f := t.Field(i) - handleField(f) - } - if r.AdditionalFields != nil { - if af := r.AdditionalFields(t); af != nil { - for _, sf := range af { - handleField(sf) - } - } - } -} func appendUniqueString(base []string, value string) []string { for _, v := range base { if v == value { @@ -1355,60 +898,6 @@ func appendUniqueString(base []string, value string) []string { } return append(base, value) } -func (r *reflector) lookupComment(t reflect.Type, name string) string { - if r.CommentMap == nil { - return "" - } - n := fullyQualifiedTypeName(t) - if name != "" { - n = n + "." + name - } - return r.CommentMap[n] -} - -// addDefinition will append the provided schema. If needed, an ID and anchor -// will also be added. -func (r *reflector) addDefinition( - definitions schemaDefinitions, - t reflect.Type, - s *Schema, -) { - name := r.typeName(t) - if name == "" { - return - } - definitions[name] = s -} - -// refDefinition will provide a schema with a reference to an existing -// definition. -func (r *reflector) refDefinition( - definitions schemaDefinitions, - t reflect.Type, -) *Schema { - if r.DoNotReference { - return nil - } - name := r.typeName(t) - if name == "" { - return nil - } - if _, ok := definitions[name]; !ok { - return nil - } - return &Schema{ - Ref: "#/$defs/" + name, - } -} -func (r *reflector) lookupID(t reflect.Type) schemaID { - if r.Lookup != nil { - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - return r.Lookup(t) - } - return EmptyID -} func (t *Schema) fieldsFromTags( f reflect.StructField, parent *Schema, diff --git a/pkg/streams/doc.go b/pkg/streams/doc.go new file mode 100644 index 0000000..3d90a82 --- /dev/null +++ b/pkg/streams/doc.go @@ -0,0 +1,2 @@ +// Package streams contains the interfaces for groq-go streamed responses. +package streams diff --git a/pkg/streams/stream.go b/pkg/streams/stream.go new file mode 100644 index 0000000..fb9ce8d --- /dev/null +++ b/pkg/streams/stream.go @@ -0,0 +1,160 @@ +package streams + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/conneroisu/groq-go/pkg/groqerr" +) + +type ( + // Streamer is an interface for a Streamer. + Streamer[T any] interface { + *T + } + // DefaultErrorAccumulator is a default implementation of ErrorAccumulator + DefaultErrorAccumulator struct { + Buffer groqerr.ErrorBuffer + } + // StreamReader is a stream reader. + StreamReader[T any] struct { + emptyMessagesLimit uint + isFinished bool + Reader *bufio.Reader + readCloser io.ReadCloser + ErrAccumulator ErrorAccumulator + Header http.Header // Header is the header of the response. + } + // ErrorAccumulator is an interface for a unit that accumulates errors. + ErrorAccumulator interface { + // Write method writes bytes to the error accumulator + // + // It implements the io.Writer interface. + Write(p []byte) error + // Bytes method returns the bytes of the error accumulator. + Bytes() []byte + } +) + +// Recv receives a response from the stream. +func (stream *StreamReader[T]) Recv() (response T, err error) { + if stream.isFinished { + err = io.EOF + return response, err + } + return stream.processLines() +} + +// processLines processes the lines of the current response in the stream. +func (stream *StreamReader[T]) processLines() (T, error) { + var ( + headerData = []byte("data: ") + errorPrefix = []byte(`data: {"error":`) + emptyMessagesCount uint + hasErrorPrefix bool + ) + for { + rawLine, err := stream.Reader.ReadBytes('\n') + if err != nil || hasErrorPrefix { + respErr := stream.UnmarshalError() + if respErr != nil { + return *new(T), + fmt.Errorf("error, %w", respErr.Error) + } + return *new(T), err + } + noSpaceLine := bytes.TrimSpace(rawLine) + if bytes.HasPrefix(noSpaceLine, errorPrefix) { + hasErrorPrefix = true + } + if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix { + if hasErrorPrefix { + noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData) + } + err := stream.ErrAccumulator.Write(noSpaceLine) + if err != nil { + return *new(T), err + } + emptyMessagesCount++ + if emptyMessagesCount > stream.emptyMessagesLimit { + return *new(T), groqerr.ErrTooManyEmptyStreamMessages{} + } + continue + } + noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData) + if string(noPrefixLine) == "[DONE]" { + stream.isFinished = true + return *new(T), io.EOF + } + var response T + unmarshalErr := json.Unmarshal(noPrefixLine, &response) + if unmarshalErr != nil { + return *new(T), unmarshalErr + } + return response, nil + } +} + +// UnmarshalError unmarshals the error response. +func (stream *StreamReader[T]) UnmarshalError() (errResp *groqerr.ErrorResponse) { + errBytes := stream.ErrAccumulator.Bytes() + if len(errBytes) == 0 { + return + } + err := json.Unmarshal(errBytes, &errResp) + if err != nil { + errResp = nil + } + return +} + +// Close closes the stream. +func (stream *StreamReader[T]) Close() error { + return stream.readCloser.Close() +} + +// NewErrorAccumulator creates a new error accumulator +func NewErrorAccumulator() ErrorAccumulator { + return &DefaultErrorAccumulator{ + Buffer: &bytes.Buffer{}, + } +} + +// Write method writes bytes to the error accumulator. +func (e *DefaultErrorAccumulator) Write(p []byte) error { + _, err := e.Buffer.Write(p) + if err != nil { + return fmt.Errorf("error accumulator write error, %w", err) + } + return nil +} + +// Bytes method returns the bytes of the error accumulator. +func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) { + if e.Buffer.Len() == 0 { + return + } + errBytes = e.Buffer.Bytes() + return +} + +// NewStreamReader creates a new stream reader. +func NewStreamReader[Q any, T Streamer[Q]]( + readCloser io.ReadCloser, + header map[string][]string, + emptyMessagesLimit uint, +) *StreamReader[T] { + stream := &StreamReader[T]{ + emptyMessagesLimit: emptyMessagesLimit, + isFinished: false, + Header: header, + Reader: bufio.NewReader(readCloser), + readCloser: readCloser, + ErrAccumulator: NewErrorAccumulator(), + } + return stream +} diff --git a/pkg/streams/stream_test.go b/pkg/streams/stream_test.go new file mode 100644 index 0000000..fb83d37 --- /dev/null +++ b/pkg/streams/stream_test.go @@ -0,0 +1,151 @@ +package streams_test + +import ( + "bytes" + "errors" + "io" + "net/http" + "testing" + + "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/groqerr" + "github.com/conneroisu/groq-go/pkg/streams" + "github.com/conneroisu/groq-go/pkg/test" + "github.com/stretchr/testify/assert" +) + +// TestStreamReaderReturnsUnmarshalerErrors tests the stream reader returns an unmarshaler error. +func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) { + stream := &streams.StreamReader[groq.ChatCompletionStreamResponse]{ + ErrAccumulator: streams.NewErrorAccumulator(), + } + + respErr := stream.UnmarshalError() + if respErr != nil { + t.Fatalf("Did not return nil with empty buffer: %v", respErr) + } + + err := stream.ErrAccumulator.Write([]byte("{")) + if err != nil { + t.Fatalf("%+v", err) + } + + respErr = stream.UnmarshalError() + if respErr != nil { + t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr) + } +} + +// TestStreamReaderReturnsErrTooManyEmptyStreamMessages tests the stream reader returns an error when the stream has too many empty messages. +func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) { + a := assert.New(t) + reader := &http.Response{ + Body: io.NopCloser(bytes.NewReader([]byte("\n\n\n\n"))), + } + stream := streams.NewStreamReader[groq.ChatCompletionStreamResponse]( + reader.Body, + map[string][]string{ + "Content-Type": {"text/event-stream"}, + }, + 3, + ) + _, err := stream.Recv() + a.ErrorIs( + err, + groqerr.ErrTooManyEmptyStreamMessages{}, + "Did not return error when recv failed", + err.Error(), + ) +} + +// TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed tests the stream reader returns an error when the error accumulator fails to write. +func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) { + a := assert.New(t) + reader := &http.Response{ + Body: io.NopCloser(bytes.NewReader([]byte("\n"))), + } + stream := streams.NewStreamReader[groq.ChatCompletionStreamResponse]( + reader.Body, + map[string][]string{ + "Content-Type": {"text/event-stream"}, + }, + 0, + ) + _, err := stream.Recv() + a.ErrorIs( + err, + groqerr.ErrTooManyEmptyStreamMessages{}, + "Did not return error when write failed", + err.Error(), + ) +} + +// Test the `Recv` method with multiple empty messages triggering an error +func TestStreamReader_TooManyEmptyMessages(t *testing.T) { + data := "\n\n\n\n\n\n" + resp := &http.Response{ + Body: io.NopCloser(bytes.NewBufferString(data)), + } + stream := streams.NewStreamReader[*groq.ChatCompletionStreamResponse]( + resp.Body, + map[string][]string{ + "Content-Type": {"text/event-stream"}, + }, + 5, + ) + + _, err := stream.Recv() + assert.ErrorIs(t, err, groqerr.ErrTooManyEmptyStreamMessages{}) +} + +// Test the `Close` method +func TestStreamReader_Close(t *testing.T) { + resp := &http.Response{ + Body: io.NopCloser(bytes.NewBufferString("")), + } + stream := streams.NewStreamReader[groq.ChatCompletionStreamResponse]( + resp.Body, + map[string][]string{ + "Content-Type": {"text/event-stream"}, + }, + 5, + ) + + err := stream.Close() + assert.NoError(t, err) + +} + +func TestErrorAccumulatorBytes(t *testing.T) { + accumulator := &streams.DefaultErrorAccumulator{ + Buffer: &bytes.Buffer{}, + } + + errBytes := accumulator.Bytes() + if len(errBytes) != 0 { + t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes)) + } + + err := accumulator.Write([]byte("{}")) + if err != nil { + t.Fatalf("%+v", err) + } + + errBytes = accumulator.Bytes() + if len(errBytes) == 0 { + t.Fatalf( + "Did not return error bytes when has error: %s", + string(errBytes), + ) + } +} + +func TestErrorByteWriteErrors(t *testing.T) { + accumulator := &streams.DefaultErrorAccumulator{ + Buffer: &test.FailingErrorBuffer{}, + } + err := accumulator.Write([]byte("{")) + if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed{}) { + t.Fatalf("Did not return error when write failed: %v", err) + } +} diff --git a/pkg/test/encoding.go b/pkg/test/encoding.go deleted file mode 100644 index a0bf80f..0000000 --- a/pkg/test/encoding.go +++ /dev/null @@ -1,27 +0,0 @@ -package test - -import ( - "encoding/json" - "io" -) - -func encode(v any) []byte { - res, err := json.Marshal(v) - if err != nil { - panic(err) - } - return res -} -func decode[T any](r io.Reader) T { - // make a new instance of the type - v := new(T) - // read the JSON data from the request body - bod, err := io.ReadAll(r) - if err != nil { - panic(err) - } - if err := json.Unmarshal(bod, &v); err != nil { - panic(err) - } - return *v -} diff --git a/pkg/test/helpers.go b/pkg/test/helpers.go index 290c07f..fc798d2 100644 --- a/pkg/test/helpers.go +++ b/pkg/test/helpers.go @@ -59,8 +59,8 @@ func (t *TokenRoundTripper) RoundTrip( return t.Fallback.RoundTrip(req) } -// IsUnitTest returns true if the unit test environment variable is set. -func IsUnitTest() bool { +// IsIntegrationTest returns true if the unit test environment variable is set. +func IsIntegrationTest() bool { return os.Getenv("UNIT") != "" } @@ -77,7 +77,7 @@ func GetAPIKey(key string) (string, error) { var DefaultLogger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ AddSource: true, Level: slog.LevelDebug, - ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr { if a.Key == "time" { return slog.Attr{} } diff --git a/pkg/test/mod-jigsawstack.go b/pkg/test/mod-jigsawstack.go new file mode 100644 index 0000000..e4aa7f9 --- /dev/null +++ b/pkg/test/mod-jigsawstack.go @@ -0,0 +1,44 @@ +package test + +import ( + "log" + "net/http" + "net/http/httptest" + "regexp" +) + +// JigsawStackTestServer Creates a mocked JigsawStack server which can pretend +// to handle requests during testing. +func (ts *ServerTest) JigsawStackTestServer() *httptest.Server { + return httptest.NewUnstartedServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf( + "received a %s request at path %q\n", + r.Method, + r.URL.Path, + ) + + // check auth + if r.Header.Get("x-api-key") != GetTestToken() { + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Handle /path/* routes. + // Note: the * is converted to a .* in register handler for proper regex handling + for route, handler := range ts.handlers { + // Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered + pattern, _ := regexp.Compile("^" + route + "$") + if pattern.MatchString(r.URL.Path) { + handler(w, r) + return + } + } + http.Error( + w, + "the resource path doesn't exist", + http.StatusNotFound, + ) + }), + ) +} diff --git a/pkg/tools/tools.go b/pkg/tools/tools.go index 127a1a7..ac183a6 100644 --- a/pkg/tools/tools.go +++ b/pkg/tools/tools.go @@ -1,14 +1,17 @@ package tools const ( - ToolTypeFunction ToolType = "function" // ToolTypeFunction is the function tool type. + // ToolTypeFunction is the function tool type. + ToolTypeFunction ToolType = "function" ) type ( // Tool represents the tool. Tool struct { - Type ToolType `json:"type"` // Type is the type of the tool. - Function FunctionDefinition `json:"function,omitempty"` // Function is the tool's functional definition. + // Type is the type of the tool. + Type ToolType `json:"type"` + // Function is the tool's functional definition. + Function FunctionDefinition `json:"function,omitempty"` } // ToolType is the tool type. // @@ -16,12 +19,15 @@ type ( ToolType string // ToolChoice represents the tool choice. ToolChoice struct { - Type ToolType `json:"type"` // Type is the type of the tool choice. - Function ToolFunction `json:"function,omitempty"` // Function is the function of the tool choice. + // Type is the type of the tool choice. + Type ToolType `json:"type"` + // Function is the function of the tool choice. + Function ToolFunction `json:"function,omitempty"` } // ToolFunction represents the tool function. ToolFunction struct { - Name string `json:"name"` // Name is the name of the tool function. + // Name is the name of the tool function. + Name string `json:"name"` } // FunctionDefinition represents the function definition. FunctionDefinition struct { @@ -44,14 +50,19 @@ type ( // ToolCall represents a tool call. ToolCall struct { // Index is not nil only in chat completion chunk object - Index *int `json:"index,omitempty"` // Index is the index of the tool call. - ID string `json:"id"` // ID is the id of the tool call. - Type string `json:"type"` // Type is the type of the tool call. - Function FunctionCall `json:"function"` // Function is the function of the tool call. + Index *int `json:"index,omitempty"` + // ID is the id of the tool call. + ID string `json:"id"` + // Type is the type of the tool call. + Type string `json:"type"` + // Function is the function of the tool call. + Function FunctionCall `json:"function"` } // FunctionCall represents a function call. FunctionCall struct { - Name string `json:"name,omitempty"` // Name is the name of the function call. - Arguments string `json:"arguments,omitempty"` // Arguments is the arguments of the function call in JSON format. + // Name is the name of the function call. + Name string `json:"name,omitempty"` + // Arguments is the arguments of the function call in JSON format. + Arguments string `json:"arguments,omitempty"` } ) diff --git a/scripts/generate-copyright-header/README.md b/scripts/generate-copyright-header/README.md new file mode 100644 index 0000000..301a101 --- /dev/null +++ b/scripts/generate-copyright-header/README.md @@ -0,0 +1,10 @@ +# generate-copyright-header + +This script will generate a copyright header for all golang files in the current directory. + +## Usage + +Run from the root of the project: +```bash +go run ./scripts/generate-copyright-header +``` diff --git a/scripts/generate-copyright-header/go.mod b/scripts/generate-copyright-header/go.mod new file mode 100644 index 0000000..0912048 --- /dev/null +++ b/scripts/generate-copyright-header/go.mod @@ -0,0 +1,3 @@ +module github.com/conneroisu/groq-go/generate-copyright-header + +go 1.23.2 diff --git a/scripts/generate-copyright-header/main.go b/scripts/generate-copyright-header/main.go new file mode 100644 index 0000000..f85b707 --- /dev/null +++ b/scripts/generate-copyright-header/main.go @@ -0,0 +1,8 @@ +// Package main is a script to generate the copyright header for each golang +// file in the project. +package main + +func main() { + // TODO: Implement copyright header generation + println("Hello, World!") +} diff --git a/scripts/generate-jigsaw-accents/README.md b/scripts/generate-jigsaw-accents/README.md new file mode 100644 index 0000000..f6558a6 --- /dev/null +++ b/scripts/generate-jigsaw-accents/README.md @@ -0,0 +1 @@ +# generate-jigsaw-accents diff --git a/scripts/generate-jigsaw-accents/accents.go.tmpl b/scripts/generate-jigsaw-accents/accents.go.tmpl new file mode 100644 index 0000000..dc0d3f7 --- /dev/null +++ b/scripts/generate-jigsaw-accents/accents.go.tmpl @@ -0,0 +1,10 @@ +package jigsawstack + +type Accent string + +const ( + {{ range $accent := .Accents }} + // Accent{{ $accent.GoName }} is the {{$accent.Gender}} accent from {{ $accent.LocaleName }}. + Accent{{ $accent.GoName }} Accent = "{{ $accent.Accent }}" + {{ end }} +) diff --git a/scripts/generate-jigsaw-accents/go.mod b/scripts/generate-jigsaw-accents/go.mod new file mode 100644 index 0000000..d9a6860 --- /dev/null +++ b/scripts/generate-jigsaw-accents/go.mod @@ -0,0 +1,3 @@ +module github.com/conneroisu/groq-go/generate-jigsaw-accents + +go 1.23.2 diff --git a/scripts/generate-jigsaw-accents/main.go b/scripts/generate-jigsaw-accents/main.go new file mode 100644 index 0000000..82b2fdb --- /dev/null +++ b/scripts/generate-jigsaw-accents/main.go @@ -0,0 +1,229 @@ +// Package main is a script to generate the jigsaw accents for the project. +package main + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "regexp" + "strings" + "text/template" + "unicode" + + _ "embed" + + "github.com/conneroisu/groq-go/pkg/builders" + "github.com/conneroisu/groq-go/pkg/test" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +const ( + defaultBaseURL = "https://api.jigsawstack.com" + accentsEndpoint = "/v1/ai/tts" +) + +type ( + // Client is a JigsawStack extension. + Client struct { + baseURL string + apiKey string + client *http.Client + logger *slog.Logger + header builders.Header + } + // SpeakerVoiceAccent represents a speaker voice accent. + SpeakerVoiceAccent struct { + Success bool `json:"success"` + Message string `json:"message"` + Accents []struct { + GoName string `json:"go_name"` + Accent string `json:"accent"` + LocaleName string `json:"locale_name"` + Gender string `json:"gender"` + } `json:"accents"` + } +) + +// AudioGetSpeakerVoiceAccents gets the speaker voice accents. +// +// GET https://api.jigsawstack.com/v1/ai/tts +// +// https://docs.jigsawstack.com/api-reference/audio/speaker-voice-accents +func (c *Client) AudioGetSpeakerVoiceAccents( + ctx context.Context, +) (response SpeakerVoiceAccent, err error) { + uri := c.baseURL + accentsEndpoint + req, err := builders.NewRequest( + ctx, + c.header, + http.MethodGet, + uri, + ) + if err != nil { + return + } + var resp SpeakerVoiceAccent + err = c.sendRequest(req, &resp) + if err != nil { + return + } + if !resp.Success { + return resp, fmt.Errorf("failed to get accents: %v", resp.Message) + } + for i := range resp.Accents { + accent := &resp.Accents[i] + accent.GoName = PascalCase(strings.ReplaceAll(accent.Accent, "-", "")) + } + return resp, nil +} + +func main() { + ctx := context.Background() + if err := run(ctx); err != nil { + fmt.Println(err) + fmt.Println(err.Error()) + os.Exit(1) + } +} + +func newClient(apiKey string) *Client { + return &Client{ + apiKey: apiKey, + header: builders.Header{SetCommonHeaders: func(r *http.Request) { + r.Header.Set("x-api-key", apiKey) + }}, + client: http.DefaultClient, + baseURL: defaultBaseURL, + } +} + +func run(ctx context.Context) error { + println("generating accents") + key, err := test.GetAPIKey("JIGSAWSTACK_API_KEY") + if err != nil { + return err + } + client := newClient(key) + accents, err := client.AudioGetSpeakerVoiceAccents(ctx) + if err != nil { + return err + } + output := FillAccents(accents) + println(output) + return nil +} + +func (c *Client) sendRequest(req *http.Request, v any) error { + req.Header.Set("Accept", "application/json") + contentType := req.Header.Get("Content-Type") + if contentType == "" { + req.Header.Set("Content-Type", "application/json") + } + res, err := c.client.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode < http.StatusOK || + res.StatusCode >= http.StatusBadRequest { + return nil + } + if v == nil { + return nil + } + switch o := v.(type) { + case *string: + b, err := io.ReadAll(res.Body) + if err != nil { + return err + } + *o = string(b) + return nil + default: + err = json.NewDecoder(res.Body).Decode(v) + if err != nil { + read, err := io.ReadAll(res.Body) + if err != nil { + return err + } + c.logger.Debug("failed to decode response", "response", string(read)) + return fmt.Errorf("failed to decode response: %w\nbody: %s", err, string(read)) + } + return nil + } +} + +var ( + // LowerCaseLettersCharset is a set of lower case letters. + LowerCaseLettersCharset = []rune("abcdefghijklmnopqrstuvwxyz") + // UpperCaseLettersCharset is a set of upper case letters. + UpperCaseLettersCharset = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + // LettersCharset is a set of letters. + LettersCharset = append(LowerCaseLettersCharset, UpperCaseLettersCharset...) + // NumbersCharset is a set of numbers. + NumbersCharset = []rune("0123456789") + // AlphanumericCharset is a set of alphanumeric characters. + AlphanumericCharset = append(LettersCharset, NumbersCharset...) + // SpecialCharset is a set of special characters. + SpecialCharset = []rune("!@#$%^&*()_+-=[]{}|;':\",./<>?") + // AllCharset is a set of all characters. + AllCharset = append(AlphanumericCharset, SpecialCharset...) + + // bearer:disable go_lang_permissive_regex_validation + splitWordReg = regexp.MustCompile(`([a-z])([A-Z0-9])|([a-zA-Z])([0-9])|([0-9])([a-zA-Z])|([A-Z])([A-Z])([a-z])`) + // bearer:disable go_lang_permissive_regex_validation + splitNumberLetterReg = regexp.MustCompile(`([0-9])([a-zA-Z])`) +) + +// Words splits string into an array of its words. +func Words(str string) []string { + str = splitWordReg.ReplaceAllString(str, `$1$3$5$7 $2$4$6$8$9`) + // example: Int8Value => Int 8Value => Int 8 Value + str = splitNumberLetterReg.ReplaceAllString(str, "$1 $2") + var result strings.Builder + for _, r := range str { + if unicode.IsLetter(r) || unicode.IsDigit(r) { + result.WriteRune(r) + } else { + result.WriteRune(' ') + } + } + return strings.Fields(result.String()) +} + +// Capitalize converts the first character of string to upper case and the remaining to lower case. +func Capitalize(str string) string { + return cases.Title(language.English).String(str) +} + +// PascalCase converts string to pascal case. +func PascalCase(str string) string { + items := Words(str) + for i := range items { + items[i] = Capitalize(items[i]) + } + return strings.Join(items, "") +} + +//go:embed accents.go.tmpl +var accentsTemplate string + +var ( + textTemplate = template.Must(template.New("accents").Parse(accentsTemplate)) +) + +// FillAccents fills the accents template with the given accents +func FillAccents(accents SpeakerVoiceAccent) string { + var buf bytes.Buffer + err := textTemplate.Execute(&buf, accents) + if err != nil { + panic(err) + } + return buf.String() +} diff --git a/scripts/generate-models/go.mod b/scripts/generate-models/go.mod index 004f126..0df1709 100644 --- a/scripts/generate-models/go.mod +++ b/scripts/generate-models/go.mod @@ -2,6 +2,4 @@ module github.com/conneroisu/groq-go/cmd/models go 1.23.2 -require github.com/samber/lo v1.47.0 - -require golang.org/x/text v0.18.0 // indirect +require golang.org/x/text v0.18.0 diff --git a/scripts/generate-models/go.sum b/scripts/generate-models/go.sum index 95a07ff..d4877b0 100644 --- a/scripts/generate-models/go.sum +++ b/scripts/generate-models/go.sum @@ -1,4 +1,2 @@ -github.com/samber/lo v1.47.0 h1:z7RynLwP5nbyRscyvcD043DWYoOcYRv3mV8lBeqOCLc= -github.com/samber/lo v1.47.0/go.mod h1:RmDH9Ct32Qy3gduHQuKJ3gW1fMHAnE/fAzQuf6He5cU= golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= diff --git a/scripts/generate-models/main.go b/scripts/generate-models/main.go index 3aefabd..fe03827 100644 --- a/scripts/generate-models/main.go +++ b/scripts/generate-models/main.go @@ -14,16 +14,20 @@ import ( "log" "net/http" "os" + "regexp" "sort" + "strings" "text/template" "time" + "unicode" - "github.com/samber/lo" + "golang.org/x/text/cases" + "golang.org/x/text/language" ) const ( - modelFileName = "models.go" - modelTestFileName = "models_test.go" + modelFileName = "./pkg/models/models.go" + modelTestFileName = "./pkg/models/models_test.go" ) var ( @@ -50,6 +54,81 @@ func main() { } } +// run runs the main function. +func run(ctx context.Context) error { + client := &http.Client{} + req, err := http.NewRequestWithContext( + ctx, + "GET", + "https://api.groq.com/openai/v1/models", + nil, + ) + if err != nil { + return err + } + key := os.Getenv("GROQ_KEY") + if key == "" { + return fmt.Errorf("GROQ_KEY is not set") + } + req.Header.Set("Authorization", "Bearer "+key) + req.Header.Set("Content-Type", "application/json") + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + bodyText, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + var response Response + err = json.Unmarshal(bodyText, &response) + if err != nil { + return err + } + buf := new(bytes.Buffer) + ms, err := response.Categorize() + if err != nil { + return err + } + err = fillModelsTemplate(buf, ms) + if err != nil { + return err + } + formatted, err := cleanFile(buf) + if err != nil { + return err + } + f, err := os.Create(modelFileName) + if err != nil { + return err + } + defer f.Close() + _, err = f.Write(formatted) + if err != nil { + return err + } + buf.Reset() + err = fillTestTemplate(buf, ms) + if err != nil { + return err + } + formatted, err = cleanFile(buf) + if err != nil { + return err + } + f, err = os.Create(modelTestFileName) + if err != nil { + return err + } + defer f.Close() + _, err = f.Write(formatted) + if err != nil { + return err + } + return nil +} + // Response is a response from the models endpoint. type Response struct { Object string `json:"object"` @@ -102,7 +181,7 @@ func (r *Response) Categorize() (CategorizedModels, error) { return models, nil } -func isMultiModalModel(model ResponseModel) bool { +func isMultiModalModel(_ ResponseModel) bool { return false } @@ -123,81 +202,7 @@ func isTranslationModel(model ResponseModel) bool { } func isTranscriptionModel(model ResponseModel) bool { - return model.ID == "whisper-large-v3" -} - -// run runs the main function. -func run(_ context.Context) error { - client := &http.Client{} - req, err := http.NewRequest( - "GET", - "https://api.groq.com/openai/v1/models", - nil, - ) - if err != nil { - return err - } - key := os.Getenv("GROQ_KEY") - if key == "" { - return fmt.Errorf("GROQ_KEY is not set") - } - req.Header.Set("Authorization", "Bearer "+key) - req.Header.Set("Content-Type", "application/json") - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - bodyText, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - var response Response - err = json.Unmarshal(bodyText, &response) - if err != nil { - return err - } - buf := new(bytes.Buffer) - ms, err := response.Categorize() - if err != nil { - return err - } - err = fillModelsTemplate(buf, ms) - if err != nil { - return err - } - formatted, err := cleanFile(buf) - if err != nil { - return err - } - f, err := os.Create(modelFileName) - if err != nil { - return err - } - defer f.Close() - _, err = f.Write(formatted) - if err != nil { - return err - } - buf.Reset() - err = fillTestTemplate(buf, ms) - if err != nil { - return err - } - formatted, err = cleanFile(buf) - if err != nil { - return err - } - f, err = os.Create(modelTestFileName) - if err != nil { - return err - } - defer f.Close() - _, err = f.Write(formatted) - if err != nil { - return err - } - return nil + return strings.Contains(model.ID, "whisper") } func cleanFile(r io.Reader) ([]byte, error) { @@ -216,22 +221,62 @@ func cleanFile(r io.Reader) ([]byte, error) { return formatted, nil } -func fillModelsTemplate(w io.Writer, models CategorizedModels) (err error) { - modelTemplate, err = modelTemplate.Parse(modelFileTemplate) - if err != nil { - return err +var ( + // LowerCaseLettersCharset is a set of lower case letters. + LowerCaseLettersCharset = []rune("abcdefghijklmnopqrstuvwxyz") + // UpperCaseLettersCharset is a set of upper case letters. + UpperCaseLettersCharset = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + // LettersCharset is a set of letters. + LettersCharset = append(LowerCaseLettersCharset, UpperCaseLettersCharset...) + // NumbersCharset is a set of numbers. + NumbersCharset = []rune("0123456789") + // AlphanumericCharset is a set of alphanumeric characters. + AlphanumericCharset = append(LettersCharset, NumbersCharset...) + // SpecialCharset is a set of special characters. + SpecialCharset = []rune("!@#$%^&*()_+-=[]{}|;':\",./<>?") + // AllCharset is a set of all characters. + AllCharset = append(AlphanumericCharset, SpecialCharset...) + + // bearer:disable go_lang_permissive_regex_validation + splitWordReg = regexp.MustCompile(`([a-z])([A-Z0-9])|([a-zA-Z])([0-9])|([0-9])([a-zA-Z])|([A-Z])([A-Z])([a-z])`) + // bearer:disable go_lang_permissive_regex_validation + splitNumberLetterReg = regexp.MustCompile(`([0-9])([a-zA-Z])`) +) + +// Words splits string into an array of its words. +func Words(str string) []string { + str = splitWordReg.ReplaceAllString(str, `$1$3$5$7 $2$4$6$8$9`) + // example: Int8Value => Int 8Value => Int 8 Value + str = splitNumberLetterReg.ReplaceAllString(str, "$1 $2") + var result strings.Builder + for _, r := range str { + if unicode.IsLetter(r) || unicode.IsDigit(r) { + result.WriteRune(r) + } else { + result.WriteRune(' ') + } } - err = modelTemplate.Execute(w, models) - if err != nil { - return err + return strings.Fields(result.String()) +} + +// Capitalize converts the first character of string to upper case and the remaining to lower case. +func Capitalize(str string) string { + return cases.Title(language.English).String(str) +} + +// PascalCase converts string to pascal case. +func PascalCase(str string) string { + items := Words(str) + for i := range items { + items[i] = Capitalize(items[i]) } - return nil + return strings.Join(items, "") } func nameModels(models []ResponseModel) { for i := range models { if (models)[i].Name == "" { - models[i].Name = lo.PascalCase(models[i].ID) + models[i].Name = PascalCase(models[i].ID) } } // sort models by name alphabetically @@ -240,6 +285,17 @@ func nameModels(models []ResponseModel) { }) } +func fillModelsTemplate(w io.Writer, models CategorizedModels) (err error) { + modelTemplate, err = modelTemplate.Parse(modelFileTemplate) + if err != nil { + return err + } + err = modelTemplate.Execute(w, models) + if err != nil { + return err + } + return nil +} func fillTestTemplate(w io.Writer, models CategorizedModels) (err error) { testTemplate, err = testTemplate.Parse(testFileTemplate) if err != nil { diff --git a/scripts/generate-models/models.go.tmpl b/scripts/generate-models/models.go.tmpl index 63f69f9..700a923 100644 --- a/scripts/generate-models/models.go.tmpl +++ b/scripts/generate-models/models.go.tmpl @@ -3,31 +3,20 @@ // Created at: {{ getCurrentDate }} // // groq-modeler Version 1.1.2 -package groq +package models type ( - model string - - // Endpoint is the endpoint for the groq api. - // string - Endpoint string + // Model is a ai model accessible through the groq api. + Model string // ChatModel is the type for chat models present on the groq api. - ChatModel model + ChatModel Model // ModerationModel is the type for moderation models present on the groq api. - ModerationModel model + ModerationModel Model // AudioModel is the type for audio models present on the groq api. - AudioModel model -) - -const ( - chatCompletionsSuffix Endpoint = "/chat/completions" - transcriptionsSuffix Endpoint = "/audio/transcriptions" - translationsSuffix Endpoint = "/audio/translations" - embeddingsSuffix Endpoint = "/embeddings" - moderationsSuffix Endpoint = "/moderations" + AudioModel Model ) var ( diff --git a/scripts/generate-models/models_test.go.tmpl b/scripts/generate-models/models_test.go.tmpl index 4c0b437..8d3c5d7 100644 --- a/scripts/generate-models/models_test.go.tmpl +++ b/scripts/generate-models/models_test.go.tmpl @@ -3,17 +3,27 @@ // Created at: {{ getCurrentDate }} // // groq-modeler Version 1.1.2 -package groq +package models_test import ( + "bytes" "context" "os" "testing" "time" + "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/models" + "github.com/conneroisu/groq-go/pkg/moderation" + "github.com/conneroisu/groq-go/pkg/test" "github.com/stretchr/testify/assert" + + _ "embed" ) +//go:embed testdata/whisper.mp3 +var whisperBytes []byte + {{- range $model := .ChatModels }} // TestChatModels{{ $model.Name }} tests the {{ $model.Name }} model. // @@ -25,13 +35,15 @@ func TestChatModels{{ $model.Name }}(t *testing.T) { } a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{ - Model: Model{{ $model.Name }}, - Messages: []ChatCompletionMessage{ + response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ + Model: models.Model{{ $model.Name }}, + Messages: []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "What is a proface display?", }, }, @@ -53,16 +65,20 @@ func TestChatModels{{ $model.Name }}(t *testing.T) { // model. func Test{{ $model.Name }}(t *testing.T) { if len(os.Getenv("UNIT")) < 1 { - t.Skip("Skipping {{ $model.Name }} test") + t.Skip("Skipping {{ $model.Name }} transcription test") } time.Sleep(time.Second * 5) a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.CreateTranscription(ctx, AudioRequest{ - Model: ModelWhisperLargeV3, - FilePath: "./examples/audio-lex-fridman/The Roman Emperors who went insane Gregory Aldrete and Lex Fridman.mp3", + reader := bytes.NewReader(whisperBytes) + response, err := client.CreateTranscription(ctx, groq.AudioRequest{ + Model: models.ModelWhisperLargeV3, + Reader: reader, + FilePath: "whisper.mp3", }) a.NoError(err, "CreateTranscription error") a.NotEmpty(response.Text, "response.Text is empty for model WhisperLargeV3 calling CreateTranscription") @@ -76,27 +92,29 @@ func Test{{ $model.Name }}(t *testing.T) { // and the operations are working as expected for the specific model type. func Test{{ $model.Name }}(t *testing.T) { if len(os.Getenv("UNIT")) < 1 { - t.Skip("Skipping {{ $model.Name }} test") + t.Skip("Skipping {{ $model.Name }} moderation test") } time.Sleep(time.Second * 5) a := assert.New(t) ctx := context.Background() - client, err := NewClient(os.Getenv("GROQ_KEY")) + apiKey, err := test.GetAPIKey("GROQ_KEY") + a.NoError(err, "GetAPIKey error") + client, err := groq.NewClient(apiKey) a.NoError(err, "NewClient error") - response, err := client.Moderate(ctx, ModerationRequest{ - Model: Model{{ $model.Name }}, - Messages: []ChatCompletionMessage{ + response, err := client.Moderate(ctx, + []groq.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: groq.ChatMessageRoleUser, Content: "I want to kill them.", }, }, - }) + models.Model{{ $model.Name }}, + ) a.NoError(err, "Moderation error") a.Equal(true, response.Flagged) a.Contains( response.Categories, - CategoryViolentCrimes, + moderation.CategoryViolentCrimes, ) } {{- end }} diff --git a/testdata/.gitignore b/testdata/.gitignore deleted file mode 100644 index 5d3b7b7..0000000 --- a/testdata/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.out.json \ No newline at end of file diff --git a/testdata/allow_additional_props.json b/testdata/allow_additional_props.json deleted file mode 100644 index 32046bb..0000000 --- a/testdata/allow_additional_props.json +++ /dev/null @@ -1,229 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/test-user", - "$ref": "#/$defs/TestUser", - "$defs": { - "Bytes": { - "type": "string", - "contentEncoding": "base64" - }, - "GrandfatherType": { - "properties": { - "family_name": { - "type": "string" - } - }, - "type": "object", - "required": [ - "family_name" - ] - }, - "MapType": { - "type": "object" - }, - "TestUser": { - "properties": { - "id": { - "type": "integer" - }, - "some_base_property": { - "type": "integer" - }, - "grand": { - "$ref": "#/$defs/GrandfatherType" - }, - "SomeUntaggedBaseProperty": { - "type": "boolean" - }, - "PublicNonExported": { - "type": "integer" - }, - "MapType": { - "$ref": "#/$defs/MapType" - }, - "name": { - "type": "string", - "maxLength": 20, - "minLength": 1, - "pattern": ".*", - "title": "the name", - "description": "this is a property", - "default": "alex", - "readOnly": true, - "examples": [ - "joe", - "lucy" - ] - }, - "password": { - "type": "string", - "writeOnly": true - }, - "friends": { - "items": { - "type": "integer" - }, - "type": "array", - "description": "list of IDs, omitted when empty" - }, - "tags": { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - "options": { - "type": "object" - }, - "TestFlag": { - "type": "boolean" - }, - "TestFlagFalse": { - "type": "boolean", - "default": false - }, - "TestFlagTrue": { - "type": "boolean", - "default": true - }, - "birth_date": { - "type": "string", - "format": "date-time" - }, - "website": { - "type": "string", - "format": "uri" - }, - "network_address": { - "type": "string", - "format": "ipv4" - }, - "photo": { - "type": "string", - "contentEncoding": "base64" - }, - "photo2": { - "$ref": "#/$defs/Bytes" - }, - "feeling": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - } - ] - }, - "age": { - "type": "integer", - "maximum": 120, - "exclusiveMaximum": 121, - "minimum": 18, - "exclusiveMinimum": 17 - }, - "email": { - "type": "string", - "format": "email" - }, - "uuid": { - "type": "string", - "format": "uuid" - }, - "Baz": { - "type": "string", - "foo": [ - "bar", - "bar1" - ], - "hello": "world" - }, - "bool_extra": { - "type": "string", - "isFalse": false, - "isTrue": true - }, - "color": { - "type": "string", - "enum": [ - "red", - "green", - "blue" - ] - }, - "rank": { - "type": "integer", - "enum": [ - 1, - 2, - 3 - ] - }, - "mult": { - "type": "number", - "enum": [ - 1.0, - 1.5, - 2.0 - ] - }, - "roles": { - "items": { - "type": "string", - "enum": [ - "admin", - "moderator", - "user" - ] - }, - "type": "array" - }, - "priorities": { - "items": { - "type": "integer", - "enum": [ - -1, - 0, - 1 - ] - }, - "type": "array" - }, - "offsets": { - "items": { - "type": "number", - "enum": [ - 1.570796, - 3.141592, - 6.283185 - ] - }, - "type": "array" - }, - "anything": true, - "raw": true - }, - "type": "object", - "required": [ - "id", - "some_base_property", - "grand", - "SomeUntaggedBaseProperty", - "PublicNonExported", - "MapType", - "name", - "password", - "TestFlag", - "photo", - "photo2", - "age", - "email", - "uuid", - "Baz", - "color", - "roles", - "raw" - ] - } - } -} diff --git a/testdata/anyof.json b/testdata/anyof.json deleted file mode 100644 index f4f0f38..0000000 --- a/testdata/anyof.json +++ /dev/null @@ -1,91 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/root-any-of", - "$ref": "#/$defs/RootAnyOf", - "$defs": { - "ChildAnyOf": { - "anyOf": [ - { - "required": [ - "child1", - "child4" - ], - "title": "group1" - }, - { - "required": [ - "child2", - "child3" - ], - "title": "group2" - } - ], - "properties": { - "child1": { - "type": "string" - }, - "child2": { - "type": "string" - }, - "child3": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "array" - } - ] - }, - "child4": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object" - }, - "RootAnyOf": { - "anyOf": [ - { - "required": [ - "field1", - "field4" - ], - "title": "group1" - }, - { - "required": [ - "field2" - ], - "title": "group2" - } - ], - "properties": { - "field1": { - "type": "string" - }, - "field2": { - "type": "string" - }, - "field3": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "array" - } - ] - }, - "field4": { - "type": "string" - }, - "child": { - "$ref": "#/$defs/ChildAnyOf" - } - }, - "additionalProperties": false, - "type": "object" - } - } -} \ No newline at end of file diff --git a/testdata/array_handling.json b/testdata/array_handling.json deleted file mode 100644 index 6cc546a..0000000 --- a/testdata/array_handling.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/array-handler", - "$ref": "#/$defs/ArrayHandler", - "$defs": { - "ArrayHandler": { - "properties": { - "min_len": { - "items": { - "type": "string", - "minLength": 2 - }, - "type": "array", - "default": [ - "qwerty" - ] - }, - "min_val": { - "items": { - "type": "number", - "minimum": 2.5 - }, - "type": "array" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "min_len", - "min_val" - ] - } - } -} \ No newline at end of file diff --git a/testdata/array_type.json b/testdata/array_type.json deleted file mode 100644 index c1262f8..0000000 --- a/testdata/array_type.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/array-type", - "$ref": "#/$defs/ArrayType", - "$defs": { - "ArrayType": { - "items": { - "type": "string" - }, - "type": "array" - } - } -} \ No newline at end of file diff --git a/testdata/base_schema_id.json b/testdata/base_schema_id.json deleted file mode 100644 index 7c52f31..0000000 --- a/testdata/base_schema_id.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://example.com/schemas/lookup-user", - "$ref": "#/$defs/LookupUser", - "$defs": { - "LookupName": { - "properties": { - "first": { - "type": "string" - }, - "surname": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "first", - "surname" - ] - }, - "LookupUser": { - "properties": { - "name": { - "$ref": "#/$defs/LookupName" - }, - "alias": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "name" - ] - } - } -} \ No newline at end of file diff --git a/testdata/commas_in_pattern.json b/testdata/commas_in_pattern.json deleted file mode 100644 index 42b4385..0000000 --- a/testdata/commas_in_pattern.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/pattern-test", - "$ref": "#/$defs/PatternTest", - "$defs": { - "PatternTest": { - "properties": { - "with_pattern": { - "type": "string", - "maxLength": 50, - "minLength": 1, - "pattern": "[0-9]{1,4}" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "with_pattern" - ] - } - } -} \ No newline at end of file diff --git a/testdata/compact_date.json b/testdata/compact_date.json deleted file mode 100644 index 48dcece..0000000 --- a/testdata/compact_date.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/compact-date", - "$ref": "#/$defs/CompactDate", - "$defs": { - "CompactDate": { - "type": "string", - "pattern": "^[0-9]{4}-[0-1][0-9]$", - "title": "Compact Date", - "description": "Short date that only includes year and month" - } - } -} \ No newline at end of file diff --git a/testdata/custom_additional.json b/testdata/custom_additional.json deleted file mode 100644 index becfcd6..0000000 --- a/testdata/custom_additional.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/grandfather-type", - "$ref": "#/$defs/GrandfatherType", - "$defs": { - "GrandfatherType": { - "properties": { - "family_name": { - "type": "string" - }, - "ip_addr": { - "type": "string", - "format": "ipv4" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "family_name", - "ip_addr" - ] - } - } -} \ No newline at end of file diff --git a/testdata/custom_base_schema_id.json b/testdata/custom_base_schema_id.json deleted file mode 100644 index 479abc6..0000000 --- a/testdata/custom_base_schema_id.json +++ /dev/null @@ -1,207 +0,0 @@ -{ - "$schema": "http://json-schema.org/draft/2020-12/schema", - "$id": "http://example.com/schema/TestUser", - "$ref": "#/$defs/TestUser", - "$defs": { - "GrandfatherType": { - "properties": { - "family_name": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "family_name" - ] - }, - "TestUser": { - "properties": { - "some_base_property": { - "type": "integer" - }, - "some_base_property_yaml": { - "type": "integer" - }, - "grand": { - "$ref": "#/$defs/GrandfatherType" - }, - "SomeUntaggedBaseProperty": { - "type": "boolean" - }, - "PublicNonExported": { - "type": "integer" - }, - "id": { - "type": "integer" - }, - "name": { - "type": "string", - "maxLength": 20, - "minLength": 1, - "pattern": ".*", - "title": "the name", - "description": "this is a property", - "default": "alex", - "readOnly": true, - "examples": [ - "joe", - "lucy" - ] - }, - "password": { - "type": "string", - "writeOnly": true - }, - "friends": { - "items": { - "type": "integer" - }, - "type": "array", - "description": "list of IDs, omitted when empty" - }, - "tags": { - "patternProperties": { - ".*": { - "type": "string" - } - }, - "type": "object" - }, - "options": { - "type": "object" - }, - "TestFlag": { - "type": "boolean" - }, - "birth_date": { - "type": "string", - "format": "date-time" - }, - "website": { - "type": "string", - "format": "uri" - }, - "network_address": { - "type": "string", - "format": "ipv4" - }, - "photo": { - "type": "string", - "contentEncoding": "base64" - }, - "photo2": { - "type": "string", - "contentEncoding": "base64" - }, - "feeling": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - } - ] - }, - "age": { - "type": "integer", - "maximum": 120, - "exclusiveMaximum": 121, - "minimum": 18, - "exclusiveMinimum": 17 - }, - "email": { - "type": "string", - "format": "email" - }, - "Baz": { - "type": "string", - "foo": [ - "bar", - "bar1" - ], - "hello": "world" - }, - "color": { - "type": "string", - "enum": [ - "red", - "green", - "blue" - ] - }, - "rank": { - "type": "integer", - "enum": [ - 1, - 2, - 3 - ] - }, - "mult": { - "type": "number", - "enum": [ - 1, - 1.5, - 2 - ] - }, - "roles": { - "items": { - "type": "string", - "enum": [ - "admin", - "moderator", - "user" - ] - }, - "type": "array" - }, - "priorities": { - "items": { - "type": "integer", - "enum": [ - -1, - 0, - 1 - ] - }, - "type": "array" - }, - "offsets": { - "items": { - "type": "number", - "enum": [ - 1.570796, - 3.141592, - 6.283185 - ] - }, - "type": "array" - }, - "anything": true, - "raw": true - }, - "additionalProperties": false, - "type": "object", - "required": [ - "some_base_property", - "some_base_property_yaml", - "grand", - "SomeUntaggedBaseProperty", - "PublicNonExported", - "id", - "name", - "password", - "TestFlag", - "age", - "email", - "Baz", - "color", - "roles", - "raw" - ] - } - } -} \ No newline at end of file diff --git a/testdata/custom_map_type.json b/testdata/custom_map_type.json deleted file mode 100644 index be28ec7..0000000 --- a/testdata/custom_map_type.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/custom-map-outer", - "$ref": "#/$defs/CustomMapOuter", - "$defs": { - "CustomMapOuter": { - "properties": { - "my_map": { - "$ref": "#/$defs/CustomMapType" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "my_map" - ] - }, - "CustomMapType": { - "items": { - "properties": { - "key": { - "type": "string" - }, - "value": { - "type": "string" - } - }, - "type": "object", - "required": [ - "key", - "value" - ] - }, - "type": "array" - } - } -} \ No newline at end of file diff --git a/testdata/custom_slice_type.json b/testdata/custom_slice_type.json deleted file mode 100644 index 580aae9..0000000 --- a/testdata/custom_slice_type.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/custom-slice-outer", - "$ref": "#/$defs/CustomSliceOuter", - "$defs": { - "CustomSliceOuter": { - "properties": { - "slice": { - "$ref": "#/$defs/CustomSliceType" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "slice" - ] - }, - "CustomSliceType": { - "oneOf": [ - { - "type": "string" - }, - { - "items": { - "type": "string" - }, - "type": "array" - } - ] - } - } -} \ No newline at end of file diff --git a/testdata/custom_type.json b/testdata/custom_type.json deleted file mode 100644 index 2eab5a8..0000000 --- a/testdata/custom_type.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/custom-type-field", - "$ref": "#/$defs/CustomTypeField", - "$defs": { - "CustomTypeField": { - "properties": { - "CreatedAt": { - "type": "string", - "format": "date-time" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "CreatedAt" - ] - } - } -} \ No newline at end of file diff --git a/testdata/custom_type_extend.json b/testdata/custom_type_extend.json deleted file mode 100644 index 5ddd6fe..0000000 --- a/testdata/custom_type_extend.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/schema-extend-test", - "$ref": "#/$defs/SchemaExtendTest", - "$defs": { - "SchemaExtendTest": { - "properties": { - "LastName": { - "type": "string", - "description": "some extra words" - }, - "middle_name": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "LastName" - ] - } - } -} \ No newline at end of file diff --git a/testdata/custom_type_with_interface.json b/testdata/custom_type_with_interface.json deleted file mode 100644 index 79fa620..0000000 --- a/testdata/custom_type_with_interface.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/custom-type-field-with-interface", - "$ref": "#/$defs/CustomTypeFieldWithInterface", - "$defs": { - "CustomTimeWithInterface": { - "type": "string", - "format": "date-time" - }, - "CustomTypeFieldWithInterface": { - "properties": { - "CreatedAt": { - "$ref": "#/$defs/CustomTimeWithInterface" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "CreatedAt" - ] - } - } -} \ No newline at end of file diff --git a/testdata/defaults_expanded_toplevel.json b/testdata/defaults_expanded_toplevel.json deleted file mode 100644 index 0cdce35..0000000 --- a/testdata/defaults_expanded_toplevel.json +++ /dev/null @@ -1,228 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/test-user", - "$defs": { - "Bytes": { - "type": "string", - "contentEncoding": "base64" - }, - "GrandfatherType": { - "properties": { - "family_name": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "family_name" - ] - }, - "MapType": { - "type": "object" - } - }, - "properties": { - "id": { - "type": "integer" - }, - "some_base_property": { - "type": "integer" - }, - "grand": { - "$ref": "#/$defs/GrandfatherType" - }, - "SomeUntaggedBaseProperty": { - "type": "boolean" - }, - "PublicNonExported": { - "type": "integer" - }, - "MapType": { - "$ref": "#/$defs/MapType" - }, - "name": { - "type": "string", - "maxLength": 20, - "minLength": 1, - "pattern": ".*", - "title": "the name", - "description": "this is a property", - "default": "alex", - "readOnly": true, - "examples": [ - "joe", - "lucy" - ] - }, - "password": { - "type": "string", - "writeOnly": true - }, - "friends": { - "items": { - "type": "integer" - }, - "type": "array", - "description": "list of IDs, omitted when empty" - }, - "tags": { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - "options": { - "type": "object" - }, - "TestFlag": { - "type": "boolean" - }, - "TestFlagFalse": { - "type": "boolean", - "default": false - }, - "TestFlagTrue": { - "type": "boolean", - "default": true - }, - "birth_date": { - "type": "string", - "format": "date-time" - }, - "website": { - "type": "string", - "format": "uri" - }, - "network_address": { - "type": "string", - "format": "ipv4" - }, - "photo": { - "type": "string", - "contentEncoding": "base64" - }, - "photo2": { - "$ref": "#/$defs/Bytes" - }, - "feeling": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - } - ] - }, - "age": { - "type": "integer", - "maximum": 120, - "exclusiveMaximum": 121, - "minimum": 18, - "exclusiveMinimum": 17 - }, - "email": { - "type": "string", - "format": "email" - }, - "uuid": { - "type": "string", - "format": "uuid" - }, - "Baz": { - "type": "string", - "foo": [ - "bar", - "bar1" - ], - "hello": "world" - }, - "bool_extra": { - "type": "string", - "isFalse": false, - "isTrue": true - }, - "color": { - "type": "string", - "enum": [ - "red", - "green", - "blue" - ] - }, - "rank": { - "type": "integer", - "enum": [ - 1, - 2, - 3 - ] - }, - "mult": { - "type": "number", - "enum": [ - 1.0, - 1.5, - 2.0 - ] - }, - "roles": { - "items": { - "type": "string", - "enum": [ - "admin", - "moderator", - "user" - ] - }, - "type": "array" - }, - "priorities": { - "items": { - "type": "integer", - "enum": [ - -1, - 0, - 1 - ] - }, - "type": "array" - }, - "offsets": { - "items": { - "type": "number", - "enum": [ - 1.570796, - 3.141592, - 6.283185 - ] - }, - "type": "array" - }, - "anything": true, - "raw": true - }, - "additionalProperties": false, - "type": "object", - "required": [ - "id", - "some_base_property", - "grand", - "SomeUntaggedBaseProperty", - "PublicNonExported", - "MapType", - "name", - "password", - "TestFlag", - "photo", - "photo2", - "age", - "email", - "uuid", - "Baz", - "color", - "roles", - "raw" - ] -} \ No newline at end of file diff --git a/testdata/equals_in_pattern.json b/testdata/equals_in_pattern.json deleted file mode 100644 index 26fe10b..0000000 --- a/testdata/equals_in_pattern.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/pattern-equals-test", - "$ref": "#/$defs/PatternEqualsTest", - "$defs": { - "PatternEqualsTest": { - "properties": { - "WithEquals": { - "type": "string", - "pattern": "foo=bar" - }, - "WithEqualsAndCommas": { - "type": "string", - "pattern": "foo,=bar" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "WithEquals", - "WithEqualsAndCommas" - ] - } - } -} \ No newline at end of file diff --git a/testdata/go_comments.json b/testdata/go_comments.json deleted file mode 100644 index 8b1dcc9..0000000 --- a/testdata/go_comments.json +++ /dev/null @@ -1,113 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/examples/user", - "$ref": "#/$defs/User", - "$defs": { - "NamedPets": { - "additionalProperties": { - "$ref": "#/$defs/Pet" - }, - "type": "object", - "description": "NamedPets is a map of animal names to pets." - }, - "Pet": { - "properties": { - "name": { - "type": "string", - "title": "Name", - "description": "Name of the animal." - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "name" - ], - "description": "Pet defines the user's fury friend." - }, - "Pets": { - "items": { - "$ref": "#/$defs/Pet" - }, - "type": "array", - "description": "Pets is a collection of Pet objects." - }, - "Plant": { - "properties": { - "variant": { - "type": "string", - "title": "Variant", - "description": "This comment will be used" - }, - "multicellular": { - "type": "boolean", - "title": "Multicellular", - "description": "Multicellular is true if the plant is multicellular" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "variant" - ], - "description": "Plant represents the plants the user might have and serves as a test of structs inside a `type` set." - }, - "User": { - "properties": { - "id": { - "type": "integer", - "description": "Unique sequential identifier." - }, - "name": { - "type": "string", - "maxLength": 20, - "minLength": 1, - "pattern": ".*", - "title": "the name", - "description": "this is a property", - "default": "alex", - "examples": [ - "joe", - "lucy" - ] - }, - "friends": { - "items": { - "type": "integer" - }, - "type": "array", - "description": "list of IDs, omitted when empty" - }, - "tags": { - "type": "object" - }, - "pets": { - "$ref": "#/$defs/Pets", - "description": "An array of pets the user cares for." - }, - "named_pets": { - "$ref": "#/$defs/NamedPets", - "description": "Set of animal names to pets" - }, - "plants": { - "items": { - "$ref": "#/$defs/Plant" - }, - "type": "array", - "title": "Plants", - "description": "Set of plants that the user likes" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "id", - "name", - "pets", - "named_pets", - "plants" - ], - "description": "User is used as a base to provide tests for comments." - } - } -} \ No newline at end of file diff --git a/testdata/ignore_type.json b/testdata/ignore_type.json deleted file mode 100644 index d310707..0000000 --- a/testdata/ignore_type.json +++ /dev/null @@ -1,224 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/test-user", - "$ref": "#/$defs/TestUser", - "$defs": { - "Bytes": { - "type": "string", - "contentEncoding": "base64" - }, - "GrandfatherType": { - "properties": {}, - "additionalProperties": false, - "type": "object" - }, - "MapType": { - "type": "object" - }, - "TestUser": { - "properties": { - "id": { - "type": "integer" - }, - "some_base_property": { - "type": "integer" - }, - "grand": { - "$ref": "#/$defs/GrandfatherType" - }, - "SomeUntaggedBaseProperty": { - "type": "boolean" - }, - "PublicNonExported": { - "type": "integer" - }, - "MapType": { - "$ref": "#/$defs/MapType" - }, - "name": { - "type": "string", - "maxLength": 20, - "minLength": 1, - "pattern": ".*", - "title": "the name", - "description": "this is a property", - "default": "alex", - "readOnly": true, - "examples": [ - "joe", - "lucy" - ] - }, - "password": { - "type": "string", - "writeOnly": true - }, - "friends": { - "items": { - "type": "integer" - }, - "type": "array", - "description": "list of IDs, omitted when empty" - }, - "tags": { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - "options": { - "type": "object" - }, - "TestFlag": { - "type": "boolean" - }, - "TestFlagFalse": { - "type": "boolean", - "default": false - }, - "TestFlagTrue": { - "type": "boolean", - "default": true - }, - "birth_date": { - "type": "string", - "format": "date-time" - }, - "website": { - "type": "string", - "format": "uri" - }, - "network_address": { - "type": "string", - "format": "ipv4" - }, - "photo": { - "type": "string", - "contentEncoding": "base64" - }, - "photo2": { - "$ref": "#/$defs/Bytes" - }, - "feeling": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - } - ] - }, - "age": { - "type": "integer", - "maximum": 120, - "exclusiveMaximum": 121, - "minimum": 18, - "exclusiveMinimum": 17 - }, - "email": { - "type": "string", - "format": "email" - }, - "uuid": { - "type": "string", - "format": "uuid" - }, - "Baz": { - "type": "string", - "foo": [ - "bar", - "bar1" - ], - "hello": "world" - }, - "bool_extra": { - "type": "string", - "isFalse": false, - "isTrue": true - }, - "color": { - "type": "string", - "enum": [ - "red", - "green", - "blue" - ] - }, - "rank": { - "type": "integer", - "enum": [ - 1, - 2, - 3 - ] - }, - "mult": { - "type": "number", - "enum": [ - 1.0, - 1.5, - 2.0 - ] - }, - "roles": { - "items": { - "type": "string", - "enum": [ - "admin", - "moderator", - "user" - ] - }, - "type": "array" - }, - "priorities": { - "items": { - "type": "integer", - "enum": [ - -1, - 0, - 1 - ] - }, - "type": "array" - }, - "offsets": { - "items": { - "type": "number", - "enum": [ - 1.570796, - 3.141592, - 6.283185 - ] - }, - "type": "array" - }, - "anything": true, - "raw": true - }, - "additionalProperties": false, - "type": "object", - "required": [ - "id", - "some_base_property", - "grand", - "SomeUntaggedBaseProperty", - "PublicNonExported", - "MapType", - "name", - "password", - "TestFlag", - "photo", - "photo2", - "age", - "email", - "uuid", - "Baz", - "color", - "roles", - "raw" - ] - } - } -} \ No newline at end of file diff --git a/testdata/inlining_embedded.json b/testdata/inlining_embedded.json deleted file mode 100644 index e4acbf6..0000000 --- a/testdata/inlining_embedded.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/outer-named", - "$defs": { - "Inner": { - "properties": { - "Foo": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "Foo" - ] - } - }, - "properties": { - "text": { - "type": "string" - }, - "inner": { - "$ref": "#/$defs/Inner" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "inner" - ] -} \ No newline at end of file diff --git a/testdata/inlining_embedded_anchored.json b/testdata/inlining_embedded_anchored.json deleted file mode 100644 index 20abb6c..0000000 --- a/testdata/inlining_embedded_anchored.json +++ /dev/null @@ -1,33 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/outer-named", - "$anchor": "OuterNamed", - "$defs": { - "Inner": { - "$anchor": "Inner", - "properties": { - "Foo": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "Foo" - ] - } - }, - "properties": { - "text": { - "type": "string" - }, - "inner": { - "$ref": "#/$defs/Inner" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "inner" - ] -} \ No newline at end of file diff --git a/testdata/inlining_inheritance.json b/testdata/inlining_inheritance.json deleted file mode 100644 index 043bf82..0000000 --- a/testdata/inlining_inheritance.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/outer", - "properties": { - "TextNamed": { - "type": "string" - }, - "Text": { - "type": "string" - }, - "Foo": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "TextNamed", - "Foo" - ] -} \ No newline at end of file diff --git a/testdata/inlining_ptr.json b/testdata/inlining_ptr.json deleted file mode 100644 index 793142d..0000000 --- a/testdata/inlining_ptr.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/outer-ptr", - "properties": { - "Foo": { - "type": "string" - }, - "Text": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "Foo" - ] -} \ No newline at end of file diff --git a/testdata/inlining_tag.json b/testdata/inlining_tag.json deleted file mode 100644 index 97ab676..0000000 --- a/testdata/inlining_tag.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/outer-inlined", - "properties": { - "text": { - "type": "string" - }, - "Foo": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "Foo" - ] -} \ No newline at end of file diff --git a/testdata/keynamed.json b/testdata/keynamed.json deleted file mode 100644 index 4f5ffb3..0000000 --- a/testdata/keynamed.json +++ /dev/null @@ -1,52 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/key-named", - "$ref": "#/$defs/KeyNamed", - "$defs": { - "KeyNamed": { - "properties": { - "ThisWasLeftAsIs": { - "type": "string" - }, - "coming_from_json_tag": { - "type": "boolean" - }, - "nested_not_renamed": { - "$ref": "#/$defs/KeyNamedNested" - }, - "✨unicode✨ s̸̥͝h̷̳͒e̴̜̽n̸̡̿a̷̘̔n̷̘͐i̶̫̐ǵ̶̯a̵̘͒n̷̮̾s̸̟̓": { - "type": "string" - }, - "20.01": { - "type": "integer", - "description": "Description was preserved" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "ThisWasLeftAsIs", - "coming_from_json_tag", - "nested_not_renamed", - "✨unicode✨ s̸̥͝h̷̳͒e̴̜̽n̸̡̿a̷̘̔n̷̘͐i̶̫̐ǵ̶̯a̵̘͒n̷̮̾s̸̟̓", - "20.01" - ] - }, - "KeyNamedNested": { - "properties": { - "nested-renamed-property": { - "type": "string" - }, - "NotRenamed": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "nested-renamed-property", - "NotRenamed" - ] - } - } -} \ No newline at end of file diff --git a/testdata/lookup.json b/testdata/lookup.json deleted file mode 100644 index 1d52012..0000000 --- a/testdata/lookup.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://example.com/schemas/lookup-user", - "$ref": "#/$defs/LookupUser", - "$defs": { - "LookupUser": { - "properties": { - "name": { - "$ref": "https://example.com/schemas/lookup-name" - }, - "alias": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "name" - ] - } - } -} \ No newline at end of file diff --git a/testdata/lookup_expanded.json b/testdata/lookup_expanded.json deleted file mode 100644 index a9013da..0000000 --- a/testdata/lookup_expanded.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://example.com/schemas/lookup-user", - "$anchor": "LookupUser", - "properties": { - "name": { - "$ref": "https://example.com/schemas/lookup-name" - }, - "alias": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "name" - ] -} \ No newline at end of file diff --git a/testdata/map_type.json b/testdata/map_type.json deleted file mode 100644 index 2ce2e5f..0000000 --- a/testdata/map_type.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/map-type", - "$ref": "#/$defs/MapType", - "$defs": { - "MapType": { - "type": "object" - } - } -} \ No newline at end of file diff --git a/testdata/no_reference.json b/testdata/no_reference.json deleted file mode 100644 index 81e1236..0000000 --- a/testdata/no_reference.json +++ /dev/null @@ -1,217 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/test-user", - "properties": { - "id": { - "type": "integer" - }, - "some_base_property": { - "type": "integer" - }, - "grand": { - "properties": { - "family_name": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "family_name" - ] - }, - "SomeUntaggedBaseProperty": { - "type": "boolean" - }, - "PublicNonExported": { - "type": "integer" - }, - "MapType": { - "type": "object" - }, - "name": { - "type": "string", - "maxLength": 20, - "minLength": 1, - "pattern": ".*", - "title": "the name", - "description": "this is a property", - "default": "alex", - "readOnly": true, - "examples": [ - "joe", - "lucy" - ] - }, - "password": { - "type": "string", - "writeOnly": true - }, - "friends": { - "items": { - "type": "integer" - }, - "type": "array", - "description": "list of IDs, omitted when empty" - }, - "tags": { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - "options": { - "type": "object" - }, - "TestFlag": { - "type": "boolean" - }, - "TestFlagFalse": { - "type": "boolean", - "default": false - }, - "TestFlagTrue": { - "type": "boolean", - "default": true - }, - "birth_date": { - "type": "string", - "format": "date-time" - }, - "website": { - "type": "string", - "format": "uri" - }, - "network_address": { - "type": "string", - "format": "ipv4" - }, - "photo": { - "type": "string", - "contentEncoding": "base64" - }, - "photo2": { - "type": "string", - "contentEncoding": "base64" - }, - "feeling": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - } - ] - }, - "age": { - "type": "integer", - "maximum": 120, - "exclusiveMaximum": 121, - "minimum": 18, - "exclusiveMinimum": 17 - }, - "email": { - "type": "string", - "format": "email" - }, - "uuid": { - "type": "string", - "format": "uuid" - }, - "Baz": { - "type": "string", - "foo": [ - "bar", - "bar1" - ], - "hello": "world" - }, - "bool_extra": { - "type": "string", - "isFalse": false, - "isTrue": true - }, - "color": { - "type": "string", - "enum": [ - "red", - "green", - "blue" - ] - }, - "rank": { - "type": "integer", - "enum": [ - 1, - 2, - 3 - ] - }, - "mult": { - "type": "number", - "enum": [ - 1.0, - 1.5, - 2.0 - ] - }, - "roles": { - "items": { - "type": "string", - "enum": [ - "admin", - "moderator", - "user" - ] - }, - "type": "array" - }, - "priorities": { - "items": { - "type": "integer", - "enum": [ - -1, - 0, - 1 - ] - }, - "type": "array" - }, - "offsets": { - "items": { - "type": "number", - "enum": [ - 1.570796, - 3.141592, - 6.283185 - ] - }, - "type": "array" - }, - "anything": true, - "raw": true - }, - "additionalProperties": false, - "type": "object", - "required": [ - "id", - "some_base_property", - "grand", - "SomeUntaggedBaseProperty", - "PublicNonExported", - "MapType", - "name", - "password", - "TestFlag", - "photo", - "photo2", - "age", - "email", - "uuid", - "Baz", - "color", - "roles", - "raw" - ] -} \ No newline at end of file diff --git a/testdata/no_reference_anchor.json b/testdata/no_reference_anchor.json deleted file mode 100644 index 454e1c4..0000000 --- a/testdata/no_reference_anchor.json +++ /dev/null @@ -1,219 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/test-user", - "$anchor": "TestUser", - "properties": { - "id": { - "type": "integer" - }, - "some_base_property": { - "type": "integer" - }, - "grand": { - "$anchor": "GrandfatherType", - "properties": { - "family_name": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "family_name" - ] - }, - "SomeUntaggedBaseProperty": { - "type": "boolean" - }, - "PublicNonExported": { - "type": "integer" - }, - "MapType": { - "type": "object" - }, - "name": { - "type": "string", - "maxLength": 20, - "minLength": 1, - "pattern": ".*", - "title": "the name", - "description": "this is a property", - "default": "alex", - "readOnly": true, - "examples": [ - "joe", - "lucy" - ] - }, - "password": { - "type": "string", - "writeOnly": true - }, - "friends": { - "items": { - "type": "integer" - }, - "type": "array", - "description": "list of IDs, omitted when empty" - }, - "tags": { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - "options": { - "type": "object" - }, - "TestFlag": { - "type": "boolean" - }, - "TestFlagFalse": { - "type": "boolean", - "default": false - }, - "TestFlagTrue": { - "type": "boolean", - "default": true - }, - "birth_date": { - "type": "string", - "format": "date-time" - }, - "website": { - "type": "string", - "format": "uri" - }, - "network_address": { - "type": "string", - "format": "ipv4" - }, - "photo": { - "type": "string", - "contentEncoding": "base64" - }, - "photo2": { - "type": "string", - "contentEncoding": "base64" - }, - "feeling": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - } - ] - }, - "age": { - "type": "integer", - "maximum": 120, - "exclusiveMaximum": 121, - "minimum": 18, - "exclusiveMinimum": 17 - }, - "email": { - "type": "string", - "format": "email" - }, - "uuid": { - "type": "string", - "format": "uuid" - }, - "Baz": { - "type": "string", - "foo": [ - "bar", - "bar1" - ], - "hello": "world" - }, - "bool_extra": { - "type": "string", - "isFalse": false, - "isTrue": true - }, - "color": { - "type": "string", - "enum": [ - "red", - "green", - "blue" - ] - }, - "rank": { - "type": "integer", - "enum": [ - 1, - 2, - 3 - ] - }, - "mult": { - "type": "number", - "enum": [ - 1.0, - 1.5, - 2.0 - ] - }, - "roles": { - "items": { - "type": "string", - "enum": [ - "admin", - "moderator", - "user" - ] - }, - "type": "array" - }, - "priorities": { - "items": { - "type": "integer", - "enum": [ - -1, - 0, - 1 - ] - }, - "type": "array" - }, - "offsets": { - "items": { - "type": "number", - "enum": [ - 1.570796, - 3.141592, - 6.283185 - ] - }, - "type": "array" - }, - "anything": true, - "raw": true - }, - "additionalProperties": false, - "type": "object", - "required": [ - "id", - "some_base_property", - "grand", - "SomeUntaggedBaseProperty", - "PublicNonExported", - "MapType", - "name", - "password", - "TestFlag", - "photo", - "photo2", - "age", - "email", - "uuid", - "Baz", - "color", - "roles", - "raw" - ] -} \ No newline at end of file diff --git a/testdata/nullable.json b/testdata/nullable.json deleted file mode 100644 index ecc0bf5..0000000 --- a/testdata/nullable.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/test-nullable", - "$ref": "#/$defs/TestNullable", - "$defs": { - "TestNullable": { - "properties": { - "child1": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ] - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "child1" - ] - } - } -} \ No newline at end of file diff --git a/testdata/number_handling.json b/testdata/number_handling.json deleted file mode 100644 index 4c1bfef..0000000 --- a/testdata/number_handling.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/number-handler", - "$ref": "#/$defs/NumberHandler", - "$defs": { - "NumberHandler": { - "properties": { - "int64": { - "type": "integer", - "default": 12 - }, - "float32": { - "type": "number", - "default": 12.5 - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "int64", - "float32" - ] - } - } -} \ No newline at end of file diff --git a/testdata/oneof.json b/testdata/oneof.json deleted file mode 100644 index 4523fe4..0000000 --- a/testdata/oneof.json +++ /dev/null @@ -1,104 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/root-one-of", - "$ref": "#/$defs/RootOneOf", - "$defs": { - "ChildOneOf": { - "oneOf": [ - { - "required": [ - "child1", - "child4" - ], - "title": "group1" - }, - { - "required": [ - "child2", - "child3" - ], - "title": "group2" - } - ], - "properties": { - "child1": { - "type": "string" - }, - "child2": { - "type": "string" - }, - "child3": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "array" - } - ] - }, - "child4": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object" - }, - "RootOneOf": { - "oneOf": [ - { - "required": [ - "field1", - "field4" - ], - "title": "group1" - }, - { - "required": [ - "field2" - ], - "title": "group2" - } - ], - "properties": { - "field1": { - "type": "string" - }, - "field2": { - "type": "string" - }, - "field3": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "array" - } - ] - }, - "field4": { - "type": "string" - }, - "child": { - "$ref": "#/$defs/ChildOneOf" - }, - "field6": { - "oneOf": [ - { - "$ref": "Outer" - }, - { - "$ref": "OuterNamed" - }, - { - "$ref": "OuterPtr" - } - ] - } - }, - "additionalProperties": false, - "type": "object" - } - } -} \ No newline at end of file diff --git a/testdata/oneof_ref.json b/testdata/oneof_ref.json deleted file mode 100644 index 84f98db..0000000 --- a/testdata/oneof_ref.json +++ /dev/null @@ -1,59 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/server", - "$ref": "#/$defs/Server", - "$defs": { - "Server": { - "properties": { - "ip_address": { - "oneOf": [ - { - "$ref": "#/$defs/ipv4" - }, - { - "$ref": "#/$defs/ipv6" - } - ] - }, - "ip_addresses": { - "items": { - "oneOf": [ - { - "$ref": "#/$defs/ipv4" - }, - { - "$ref": "#/$defs/ipv6" - } - ] - }, - "type": "array" - }, - "ip_address_any": { - "anyOf": [ - { - "$ref": "#/$defs/ipv4" - }, - { - "$ref": "#/$defs/ipv6" - } - ] - }, - "ip_addresses_any": { - "items": { - "anyOf": [ - { - "$ref": "#/$defs/ipv4" - }, - { - "$ref": "#/$defs/ipv6" - } - ] - }, - "type": "array" - } - }, - "additionalProperties": false, - "type": "object" - } - } -} \ No newline at end of file diff --git a/testdata/recursive.json b/testdata/recursive.json deleted file mode 100644 index d5574e6..0000000 --- a/testdata/recursive.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/recursive-example", - "$ref": "#/$defs/RecursiveExample", - "$defs": { - "RecursiveExample": { - "properties": { - "text": { - "type": "string" - }, - "children": { - "items": { - "$ref": "#/$defs/RecursiveExample" - }, - "type": "array" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "text" - ] - } - } -} \ No newline at end of file diff --git a/testdata/required_from_jsontags.json b/testdata/required_from_jsontags.json deleted file mode 100644 index 709a582..0000000 --- a/testdata/required_from_jsontags.json +++ /dev/null @@ -1,218 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/test-user", - "$ref": "#/$defs/TestUser", - "$defs": { - "Bytes": { - "type": "string", - "contentEncoding": "base64" - }, - "GrandfatherType": { - "properties": { - "family_name": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "family_name" - ] - }, - "MapType": { - "type": "object" - }, - "TestUser": { - "properties": { - "id": { - "type": "integer" - }, - "some_base_property": { - "type": "integer" - }, - "grand": { - "$ref": "#/$defs/GrandfatherType" - }, - "SomeUntaggedBaseProperty": { - "type": "boolean" - }, - "PublicNonExported": { - "type": "integer" - }, - "MapType": { - "$ref": "#/$defs/MapType" - }, - "name": { - "type": "string", - "maxLength": 20, - "minLength": 1, - "pattern": ".*", - "title": "the name", - "description": "this is a property", - "default": "alex", - "readOnly": true, - "examples": [ - "joe", - "lucy" - ] - }, - "password": { - "type": "string", - "writeOnly": true - }, - "friends": { - "items": { - "type": "integer" - }, - "type": "array", - "description": "list of IDs, omitted when empty" - }, - "tags": { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - "options": { - "type": "object" - }, - "TestFlag": { - "type": "boolean" - }, - "TestFlagFalse": { - "type": "boolean", - "default": false - }, - "TestFlagTrue": { - "type": "boolean", - "default": true - }, - "birth_date": { - "type": "string", - "format": "date-time" - }, - "website": { - "type": "string", - "format": "uri" - }, - "network_address": { - "type": "string", - "format": "ipv4" - }, - "photo": { - "type": "string", - "contentEncoding": "base64" - }, - "photo2": { - "$ref": "#/$defs/Bytes" - }, - "feeling": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - } - ] - }, - "age": { - "type": "integer", - "maximum": 120, - "exclusiveMaximum": 121, - "minimum": 18, - "exclusiveMinimum": 17 - }, - "email": { - "type": "string", - "format": "email" - }, - "uuid": { - "type": "string", - "format": "uuid" - }, - "Baz": { - "type": "string", - "foo": [ - "bar", - "bar1" - ], - "hello": "world" - }, - "bool_extra": { - "type": "string", - "isFalse": false, - "isTrue": true - }, - "color": { - "type": "string", - "enum": [ - "red", - "green", - "blue" - ] - }, - "rank": { - "type": "integer", - "enum": [ - 1, - 2, - 3 - ] - }, - "mult": { - "type": "number", - "enum": [ - 1.0, - 1.5, - 2.0 - ] - }, - "roles": { - "items": { - "type": "string", - "enum": [ - "admin", - "moderator", - "user" - ] - }, - "type": "array" - }, - "priorities": { - "items": { - "type": "integer", - "enum": [ - -1, - 0, - 1 - ] - }, - "type": "array" - }, - "offsets": { - "items": { - "type": "number", - "enum": [ - 1.570796, - 3.141592, - 6.283185 - ] - }, - "type": "array" - }, - "anything": true, - "raw": true - }, - "additionalProperties": false, - "type": "object", - "required": [ - "SomeUntaggedBaseProperty", - "id", - "name", - "photo", - "photo2" - ] - } - } -} \ No newline at end of file diff --git a/testdata/schema_alias.json b/testdata/schema_alias.json deleted file mode 100644 index 75d11a7..0000000 --- a/testdata/schema_alias.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/alias-object-b", - "$ref": "#/$defs/AliasObjectA", - "$defs": { - "AliasObjectA": { - "properties": { - "prop_a": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "prop_a" - ] - } - } -} \ No newline at end of file diff --git a/testdata/schema_alias_2.json b/testdata/schema_alias_2.json deleted file mode 100644 index f7f90ba..0000000 --- a/testdata/schema_alias_2.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/alias-object-c", - "$ref": "#/$defs/AliasObjectC", - "$defs": { - "AliasObjectA": { - "properties": { - "prop_a": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "prop_a" - ] - }, - "AliasObjectC": { - "properties": { - "obj_b": { - "$ref": "#/$defs/AliasObjectA" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "obj_b" - ] - } - } -} \ No newline at end of file diff --git a/testdata/schema_property_alias.json b/testdata/schema_property_alias.json deleted file mode 100644 index 287ac94..0000000 --- a/testdata/schema_property_alias.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/alias-property-object-base", - "$ref": "#/$defs/AliasPropertyObjectBase", - "$defs": { - "AliasObjectA": { - "properties": { - "prop_a": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "prop_a" - ] - }, - "AliasPropertyObjectBase": { - "properties": { - "object": { - "$ref": "#/$defs/AliasObjectA" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "object" - ] - } - } -} \ No newline at end of file diff --git a/testdata/schema_with_expression.json b/testdata/schema_with_expression.json deleted file mode 100644 index d729e6f..0000000 --- a/testdata/schema_with_expression.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/expression", - "$ref": "#/$defs/Expression", - "$defs": { - "Expression": { - "properties": { - "value": { - "type": "integer", - "foo": "bar=='baz'" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "value" - ] - } - } -} \ No newline at end of file diff --git a/testdata/schema_with_minimum.json b/testdata/schema_with_minimum.json deleted file mode 100644 index 42e6bad..0000000 --- a/testdata/schema_with_minimum.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/min-value", - "$ref": "#/$defs/MinValue", - "$defs": { - "MinValue": { - "properties": { - "value4": { - "type": "integer", - "minimum": 0 - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "value4" - ] - } - } -} \ No newline at end of file diff --git a/testdata/test_config.json b/testdata/test_config.json deleted file mode 100644 index 7cfb67f..0000000 --- a/testdata/test_config.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/config", - "$ref": "#/$defs/Config", - "$defs": { - "Config": { - "properties": { - "name": { - "type": "string" - }, - "count": { - "type": "integer" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "name", - "count" - ] - } - } -} \ No newline at end of file diff --git a/testdata/test_description_override.json b/testdata/test_description_override.json deleted file mode 100644 index 2f7d780..0000000 --- a/testdata/test_description_override.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/test-description-override", - "$ref": "#/$defs/TestDescriptionOverride", - "$defs": { - "TestDescriptionOverride": { - "properties": { - "FirstName": { - "type": "string", - "description": "test2" - }, - "LastName": { - "type": "string", - "description": "test3" - }, - "age": { - "type": "integer", - "description": "test4" - }, - "middle_name": { - "type": "string", - "description": "test5" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "FirstName", - "LastName", - "age" - ] - } - } -} \ No newline at end of file diff --git a/testdata/test_user.json b/testdata/test_user.json deleted file mode 100644 index 3ddd64f..0000000 --- a/testdata/test_user.json +++ /dev/null @@ -1,231 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/test-user", - "$ref": "#/$defs/TestUser", - "$defs": { - "Bytes": { - "type": "string", - "contentEncoding": "base64" - }, - "GrandfatherType": { - "properties": { - "family_name": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "family_name" - ] - }, - "MapType": { - "type": "object" - }, - "TestUser": { - "properties": { - "id": { - "type": "integer" - }, - "some_base_property": { - "type": "integer" - }, - "grand": { - "$ref": "#/$defs/GrandfatherType" - }, - "SomeUntaggedBaseProperty": { - "type": "boolean" - }, - "PublicNonExported": { - "type": "integer" - }, - "MapType": { - "$ref": "#/$defs/MapType" - }, - "name": { - "type": "string", - "maxLength": 20, - "minLength": 1, - "pattern": ".*", - "title": "the name", - "description": "this is a property", - "default": "alex", - "readOnly": true, - "examples": [ - "joe", - "lucy" - ] - }, - "password": { - "type": "string", - "writeOnly": true - }, - "friends": { - "items": { - "type": "integer" - }, - "type": "array", - "description": "list of IDs, omitted when empty" - }, - "tags": { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - "options": { - "type": "object" - }, - "TestFlag": { - "type": "boolean" - }, - "TestFlagFalse": { - "type": "boolean", - "default": false - }, - "TestFlagTrue": { - "type": "boolean", - "default": true - }, - "birth_date": { - "type": "string", - "format": "date-time" - }, - "website": { - "type": "string", - "format": "uri" - }, - "network_address": { - "type": "string", - "format": "ipv4" - }, - "photo": { - "type": "string", - "contentEncoding": "base64" - }, - "photo2": { - "$ref": "#/$defs/Bytes" - }, - "feeling": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - } - ] - }, - "age": { - "type": "integer", - "maximum": 120, - "exclusiveMaximum": 121, - "minimum": 18, - "exclusiveMinimum": 17 - }, - "email": { - "type": "string", - "format": "email" - }, - "uuid": { - "type": "string", - "format": "uuid" - }, - "Baz": { - "type": "string", - "foo": [ - "bar", - "bar1" - ], - "hello": "world" - }, - "bool_extra": { - "type": "string", - "isFalse": false, - "isTrue": true - }, - "color": { - "type": "string", - "enum": [ - "red", - "green", - "blue" - ] - }, - "rank": { - "type": "integer", - "enum": [ - 1, - 2, - 3 - ] - }, - "mult": { - "type": "number", - "enum": [ - 1.0, - 1.5, - 2.0 - ] - }, - "roles": { - "items": { - "type": "string", - "enum": [ - "admin", - "moderator", - "user" - ] - }, - "type": "array" - }, - "priorities": { - "items": { - "type": "integer", - "enum": [ - -1, - 0, - 1 - ] - }, - "type": "array" - }, - "offsets": { - "items": { - "type": "number", - "enum": [ - 1.570796, - 3.141592, - 6.283185 - ] - }, - "type": "array" - }, - "anything": true, - "raw": true - }, - "additionalProperties": false, - "type": "object", - "required": [ - "id", - "some_base_property", - "grand", - "SomeUntaggedBaseProperty", - "PublicNonExported", - "MapType", - "name", - "password", - "TestFlag", - "photo", - "photo2", - "age", - "email", - "uuid", - "Baz", - "color", - "roles", - "raw" - ] - } - } -} \ No newline at end of file diff --git a/testdata/test_user_assign_anchor.json b/testdata/test_user_assign_anchor.json deleted file mode 100644 index 31bc3c4..0000000 --- a/testdata/test_user_assign_anchor.json +++ /dev/null @@ -1,233 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/test-user", - "$ref": "#/$defs/TestUser", - "$defs": { - "Bytes": { - "type": "string", - "contentEncoding": "base64" - }, - "GrandfatherType": { - "$anchor": "GrandfatherType", - "properties": { - "family_name": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "family_name" - ] - }, - "MapType": { - "type": "object" - }, - "TestUser": { - "$anchor": "TestUser", - "properties": { - "id": { - "type": "integer" - }, - "some_base_property": { - "type": "integer" - }, - "grand": { - "$ref": "#/$defs/GrandfatherType" - }, - "SomeUntaggedBaseProperty": { - "type": "boolean" - }, - "PublicNonExported": { - "type": "integer" - }, - "MapType": { - "$ref": "#/$defs/MapType" - }, - "name": { - "type": "string", - "maxLength": 20, - "minLength": 1, - "pattern": ".*", - "title": "the name", - "description": "this is a property", - "default": "alex", - "readOnly": true, - "examples": [ - "joe", - "lucy" - ] - }, - "password": { - "type": "string", - "writeOnly": true - }, - "friends": { - "items": { - "type": "integer" - }, - "type": "array", - "description": "list of IDs, omitted when empty" - }, - "tags": { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - "options": { - "type": "object" - }, - "TestFlag": { - "type": "boolean" - }, - "TestFlagFalse": { - "type": "boolean", - "default": false - }, - "TestFlagTrue": { - "type": "boolean", - "default": true - }, - "birth_date": { - "type": "string", - "format": "date-time" - }, - "website": { - "type": "string", - "format": "uri" - }, - "network_address": { - "type": "string", - "format": "ipv4" - }, - "photo": { - "type": "string", - "contentEncoding": "base64" - }, - "photo2": { - "$ref": "#/$defs/Bytes" - }, - "feeling": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - } - ] - }, - "age": { - "type": "integer", - "maximum": 120, - "exclusiveMaximum": 121, - "minimum": 18, - "exclusiveMinimum": 17 - }, - "email": { - "type": "string", - "format": "email" - }, - "uuid": { - "type": "string", - "format": "uuid" - }, - "Baz": { - "type": "string", - "foo": [ - "bar", - "bar1" - ], - "hello": "world" - }, - "bool_extra": { - "type": "string", - "isFalse": false, - "isTrue": true - }, - "color": { - "type": "string", - "enum": [ - "red", - "green", - "blue" - ] - }, - "rank": { - "type": "integer", - "enum": [ - 1, - 2, - 3 - ] - }, - "mult": { - "type": "number", - "enum": [ - 1.0, - 1.5, - 2.0 - ] - }, - "roles": { - "items": { - "type": "string", - "enum": [ - "admin", - "moderator", - "user" - ] - }, - "type": "array" - }, - "priorities": { - "items": { - "type": "integer", - "enum": [ - -1, - 0, - 1 - ] - }, - "type": "array" - }, - "offsets": { - "items": { - "type": "number", - "enum": [ - 1.570796, - 3.141592, - 6.283185 - ] - }, - "type": "array" - }, - "anything": true, - "raw": true - }, - "additionalProperties": false, - "type": "object", - "required": [ - "id", - "some_base_property", - "grand", - "SomeUntaggedBaseProperty", - "PublicNonExported", - "MapType", - "name", - "password", - "TestFlag", - "photo", - "photo2", - "age", - "email", - "uuid", - "Baz", - "color", - "roles", - "raw" - ] - } - } -} \ No newline at end of file diff --git a/testdata/test_yaml_and_json_prefer_yaml.json b/testdata/test_yaml_and_json_prefer_yaml.json deleted file mode 100644 index a12b344..0000000 --- a/testdata/test_yaml_and_json_prefer_yaml.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "$schema": "http://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/test-yaml-and-json", - "$ref": "#/$defs/TestYamlAndJson", - "$defs": { - "TestYamlAndJson": { - "properties": { - "first_name": { - "type": "string" - }, - "LastName": { - "type": "string" - }, - "age": { - "type": "integer" - }, - "middle_name": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "first_name", - "LastName", - "age" - ] - } - } -} \ No newline at end of file diff --git a/testdata/unsigned_int_handling.json b/testdata/unsigned_int_handling.json deleted file mode 100644 index 5acb706..0000000 --- a/testdata/unsigned_int_handling.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/unsigned-int-handler", - "$ref": "#/$defs/UnsignedIntHandler", - "$defs": { - "UnsignedIntHandler": { - "properties": { - "min_len": { - "items": { - "type": "string", - "minLength": 0 - }, - "type": "array" - }, - "max_len": { - "items": { - "type": "string", - "maxLength": 0 - }, - "type": "array" - }, - "min_items": { - "items": { - "type": "string" - }, - "type": "array", - "minItems": 0 - }, - "max_items": { - "items": { - "type": "string" - }, - "type": "array", - "maxItems": 0 - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "min_len", - "max_len", - "min_items", - "max_items" - ] - } - } -} \ No newline at end of file diff --git a/testdata/user_with_anchor.json b/testdata/user_with_anchor.json deleted file mode 100644 index 5663e66..0000000 --- a/testdata/user_with_anchor.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/user-with-anchor", - "$ref": "#/$defs/UserWithAnchor", - "$defs": { - "UserWithAnchor": { - "properties": { - "name": { - "$anchor": "Name", - "type": "string" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "name" - ] - } - } -} \ No newline at end of file diff --git a/testdata/with_custom_format.json b/testdata/with_custom_format.json deleted file mode 100644 index 557791e..0000000 --- a/testdata/with_custom_format.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/conneroisu/groq-go/with-custom-format", - "$ref": "#/$defs/WithCustomFormat", - "$defs": { - "WithCustomFormat": { - "properties": { - "dates": { - "items": { - "type": "string", - "format": "date" - }, - "type": "array" - }, - "odds": { - "items": { - "type": "string", - "format": "odd" - }, - "type": "array" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "dates", - "odds" - ] - } - } -} \ No newline at end of file diff --git a/testdata/yaml_inline.json b/testdata/yaml_inline.json deleted file mode 100644 index 7b396a9..0000000 --- a/testdata/yaml_inline.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "$schema": "http://json-schema.org/draft/2020-12/schema", - "$ref": "#/$defs/TestYamlInline", - "definitions": { - "Inner": { - "required": ["foo"], - "properties": { - "foo": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object" - }, - "TestYamlInline": { - "required": [ - "Inlined" - ], - "properties": { - "Inlined": { - "$schema": "http://json-schema.org/draft-04/schema#", - "$ref": "#/definitions/Inner" - } - }, - "additionalProperties": false, - "type": "object" - } - } -} diff --git a/unit_test.go b/unit_test.go index 77f90f8..19e295b 100644 --- a/unit_test.go +++ b/unit_test.go @@ -19,12 +19,15 @@ import ( "testing" "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/pkg/groqerr" + "github.com/conneroisu/groq-go/pkg/models" + "github.com/conneroisu/groq-go/pkg/moderation" "github.com/conneroisu/groq-go/pkg/test" "github.com/stretchr/testify/assert" ) func TestTestServer(t *testing.T) { - if !test.IsUnitTest() { + if !test.IsIntegrationTest() { t.Skip() } num := rand.Intn(100) @@ -35,7 +38,7 @@ func TestTestServer(t *testing.T) { strm, err := client.CreateChatCompletionStream( ctx, groq.ChatCompletionRequest{ - Model: groq.ModelLlama38B8192, + Model: models.ModelLlama38B8192, Messages: []groq.ChatCompletionMessage{ { Role: groq.ChatMessageRoleUser, @@ -55,7 +58,6 @@ You have a six-sided die that you roll once. Let $R{i}$ denote the event that th }, ) a.NoError(err, "CreateCompletionStream error") - i := 0 for { i++ @@ -75,21 +77,21 @@ func TestModerate(t *testing.T) { "/v1/chat/completions", handleModerationEndpoint, ) - mod, err := client.Moderate(context.Background(), groq.ModerationRequest{ - Model: groq.ModelLlamaGuard38B, - Messages: []groq.ChatCompletionMessage{ + mod, err := client.Moderate(context.Background(), + []groq.ChatCompletionMessage{ { Role: groq.ChatMessageRoleUser, Content: "I want to kill them.", }, }, - }) + models.ModelLlamaGuard38B, + ) a := assert.New(t) a.NoError(err, "Moderation error") a.Equal(true, mod.Flagged) a.Contains( mod.Categories, - groq.CategoryViolentCrimes, + moderation.CategoryViolentCrimes, ) } @@ -99,7 +101,7 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { ID: "chatcmpl-123", Object: "chat.completion", Created: 1693721698, - Model: groq.ChatModel(groq.ModelLlamaGuard38B), + Model: models.ChatModel(models.ModelLlamaGuard38B), Choices: []groq.ChatCompletionChoice{ { Message: groq.ChatCompletionMessage{ @@ -130,7 +132,6 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { return } } - func setupGroqTestServer() ( client *groq.Client, server *test.ServerTest, @@ -149,7 +150,6 @@ func setupGroqTestServer() ( } return } - func TestEmptyKeyClientCreation(t *testing.T) { client, err := groq.NewClient("") a := assert.New(t) @@ -166,30 +166,25 @@ func TestCreateChatCompletionStream(t *testing.T) { "/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") - // Send test responses dataBytes := []byte{} dataBytes = append(dataBytes, []byte("event: message\n")...) data := `{"id":"1","object":"completion","created":1598069254,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - dataBytes = append(dataBytes, []byte("event: message\n")...) data = `{"id":"2","object":"completion","created":1598069255,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - dataBytes = append(dataBytes, []byte("event: done\n")...) dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) - _, err := w.Write(dataBytes) a.NoError(err, "Write error") }, ) - stream, err := client.CreateChatCompletionStream( context.Background(), groq.ChatCompletionRequest{ MaxTokens: 5, - Model: groq.ModelLlama38B8192, + Model: models.ModelLlama38B8192, Messages: []groq.ChatCompletionMessage{ { Role: groq.ChatMessageRoleUser, @@ -201,13 +196,12 @@ func TestCreateChatCompletionStream(t *testing.T) { ) a.NoError(err, "CreateCompletionStream returned error") defer stream.Close() - expectedResponses := []groq.ChatCompletionStreamResponse{ { ID: "1", Object: "completion", Created: 1598069254, - Model: groq.ModelLlama38B8192, + Model: models.ModelLlama38B8192, SystemFingerprint: "fp_d9767fc5b9", Choices: []groq.ChatCompletionStreamChoice{ { @@ -222,7 +216,7 @@ func TestCreateChatCompletionStream(t *testing.T) { ID: "2", Object: "completion", Created: 1598069255, - Model: groq.ModelLlama38B8192, + Model: models.ModelLlama38B8192, SystemFingerprint: "fp_d9767fc5b9", Choices: []groq.ChatCompletionStreamChoice{ { @@ -234,14 +228,12 @@ func TestCreateChatCompletionStream(t *testing.T) { }, }, } - for ix, expectedResponse := range expectedResponses { b, _ := json.Marshal(expectedResponse) t.Logf("%d: %s", ix, string(b)) - receivedResponse, streamErr := stream.Recv() a.NoError(streamErr, "stream.Recv() failed") - if !compareChatResponses(t, expectedResponse, receivedResponse) { + if !compareChatResponses(t, expectedResponse, *receivedResponse) { t.Errorf( "Stream response %v is %v, expected %v", ix, @@ -250,19 +242,15 @@ func TestCreateChatCompletionStream(t *testing.T) { ) } } - _, streamErr := stream.Recv() if !errors.Is(streamErr, io.EOF) { t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) } - _, streamErr = stream.Recv() if !errors.Is(streamErr, io.EOF) { t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) } - _, streamErr = stream.Recv() - a.ErrorIs( streamErr, io.EOF, @@ -286,7 +274,6 @@ func TestCreateChatCompletionStreamError(t *testing.T) { "/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") - // Send test responses dataBytes := []byte{} dataStr := []string{ @@ -302,17 +289,15 @@ func TestCreateChatCompletionStreamError(t *testing.T) { for _, str := range dataStr { dataBytes = append(dataBytes, []byte(str+"\n")...) } - _, err := w.Write(dataBytes) a.NoError(err, "Write error") }, ) - stream, err := client.CreateChatCompletionStream( context.Background(), groq.ChatCompletionRequest{ MaxTokens: 5, - Model: groq.ModelLlama38B8192, + Model: models.ModelLlama38B8192, Messages: []groq.ChatCompletionMessage{ { Role: groq.ChatMessageRoleUser, @@ -324,17 +309,14 @@ func TestCreateChatCompletionStreamError(t *testing.T) { ) a.NoError(err, "CreateCompletionStream returned error") defer stream.Close() - _, streamErr := stream.Recv() a.Error(streamErr, "stream.Recv() did not return error") - - var apiErr *groq.APIError + var apiErr *groqerr.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } t.Logf("%+v\n", apiErr) } - func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { a := assert.New(t) client, server, teardown := setupGroqTestServer() @@ -346,23 +328,20 @@ func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set(xCustomHeader, xCustomHeaderValue) - // Send test responses dataBytes := []byte( `data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`, ) dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) - _, err := w.Write(dataBytes) a.NoError(err, "Write error") }, ) - stream, err := client.CreateChatCompletionStream( context.Background(), groq.ChatCompletionRequest{ MaxTokens: 5, - Model: groq.ModelLlama38B8192, + Model: models.ModelLlama38B8192, Messages: []groq.ChatCompletionMessage{ { Role: groq.ChatMessageRoleUser, @@ -374,13 +353,11 @@ func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { ) a.NoError(err, "CreateCompletionStream returned error") defer stream.Close() - value := stream.Header.Get(xCustomHeader) if value != xCustomHeaderValue { t.Errorf("expected %s to be %s", xCustomHeaderValue, value) } } - func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { client, server, teardown := setupGroqTestServer() a := assert.New(t) @@ -405,23 +382,20 @@ func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { w.Header().Set(k, fmt.Sprintf("%s", v)) } } - // Send test responses dataBytes := []byte( `data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`, ) dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) - _, err := w.Write(dataBytes) a.NoError(err, "Write error") }, ) - stream, err := client.CreateChatCompletionStream( context.Background(), groq.ChatCompletionRequest{ MaxTokens: 5, - Model: groq.ModelLlama38B8192, + Model: models.ModelLlama38B8192, Messages: []groq.ChatCompletionMessage{ { Role: groq.ChatMessageRoleUser, @@ -433,7 +407,6 @@ func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { ) a.NoError(err, "CreateCompletionStream returned error") defer stream.Close() - headers := newRateLimitHeaders(stream.Header) bs1, _ := json.Marshal(headers) bs2, _ := json.Marshal(rateLimitHeaders) @@ -457,7 +430,6 @@ func newRateLimitHeaders(h http.Header) groq.RateLimitHeaders { ResetTokens: groq.ResetTime(h.Get("x-ratelimit-reset-tokens")), } } - func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { a := assert.New(t) client, server, teardown := setupGroqTestServer() @@ -466,23 +438,20 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { "/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") - // Send test responses dataBytes := []byte( `data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`, ) dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) - _, err := w.Write(dataBytes) a.NoError(err, "Write error") }, ) - stream, err := client.CreateChatCompletionStream( context.Background(), groq.ChatCompletionRequest{ MaxTokens: 5, - Model: groq.ModelLlama38B8192, + Model: models.ModelLlama38B8192, Messages: []groq.ChatCompletionMessage{ { Role: groq.ChatMessageRoleUser, @@ -494,17 +463,14 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { ) a.NoError(err, "CreateCompletionStream returned error") defer stream.Close() - _, streamErr := stream.Recv() a.Error(streamErr, "stream.Recv() did not return error") - - var apiErr *groq.APIError + var apiErr *groqerr.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } t.Logf("%+v\n", apiErr) } - func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { a := assert.New(t) client, server, teardown := setupGroqTestServer() @@ -514,14 +480,12 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) - // Send test responses dataBytes := []byte(`{"error":{` + `"message": "You are sending requests too quickly.",` + `"type":"rate_limit_reached",` + `"param":null,` + `"code":"rate_limit_reached"}}`) - _, err := w.Write(dataBytes) a.NoError(err, "Write error") }, @@ -530,7 +494,7 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { context.Background(), groq.ChatCompletionRequest{ MaxTokens: 5, - Model: groq.ModelLlama38B8192, + Model: models.ModelLlama38B8192, Messages: []groq.ChatCompletionMessage{ { Role: groq.ChatMessageRoleUser, @@ -540,7 +504,7 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { Stream: true, }, ) - var apiErr *groq.APIError + var apiErr *groqerr.APIError if !errors.As(err, &apiErr) { t.Errorf( "TestCreateChatCompletionStreamRateLimitError did not return APIError", @@ -548,40 +512,32 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { } t.Logf("%+v\n", apiErr) } - func TestCreateChatCompletionStreamStreamOptions(t *testing.T) { a := assert.New(t) client, server, teardown := setupGroqTestServer() defer teardown() - server.RegisterHandler( "/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") - // Send test responses var dataBytes []byte data := `{"id":"1","object":"completion","created":1598069254,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}],"usage":null}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - data = `{"id":"2","object":"completion","created":1598069255,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}],"usage":null}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - data = `{"id":"3","object":"completion","created":1598069256,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) - _, err := w.Write(dataBytes) a.NoError(err, "Write error") }, ) - stream, err := client.CreateChatCompletionStream( context.Background(), groq.ChatCompletionRequest{ MaxTokens: 5, - Model: groq.ModelLlama38B8192, + Model: models.ModelLlama38B8192, Messages: []groq.ChatCompletionMessage{ { Role: groq.ChatMessageRoleUser, @@ -601,7 +557,7 @@ func TestCreateChatCompletionStreamStreamOptions(t *testing.T) { ID: "1", Object: "completion", Created: 1598069254, - Model: groq.ModelLlama38B8192, + Model: models.ModelLlama38B8192, SystemFingerprint: "fp_d9767fc5b9", Choices: []groq.ChatCompletionStreamChoice{ { @@ -616,7 +572,7 @@ func TestCreateChatCompletionStreamStreamOptions(t *testing.T) { ID: "2", Object: "completion", Created: 1598069255, - Model: groq.ModelLlama38B8192, + Model: models.ModelLlama38B8192, SystemFingerprint: "fp_d9767fc5b9", Choices: []groq.ChatCompletionStreamChoice{ { @@ -631,7 +587,7 @@ func TestCreateChatCompletionStreamStreamOptions(t *testing.T) { ID: "3", Object: "completion", Created: 1598069256, - Model: groq.ModelLlama38B8192, + Model: models.ModelLlama38B8192, SystemFingerprint: "fp_d9767fc5b9", Choices: []groq.ChatCompletionStreamChoice{}, Usage: &groq.Usage{ @@ -641,17 +597,15 @@ func TestCreateChatCompletionStreamStreamOptions(t *testing.T) { }, }, } - for ix, expectedResponse := range expectedResponses { ix++ b, _ := json.Marshal(expectedResponse) t.Logf("%d: %s", ix, string(b)) - receivedResponse, streamErr := stream.Recv() if !errors.Is(streamErr, io.EOF) { a.NoError(streamErr, "stream.Recv() failed") } - if !compareChatResponses(t, expectedResponse, receivedResponse) { + if !compareChatResponses(t, expectedResponse, *receivedResponse) { t.Errorf( "Stream response %v: %v,BUT expected %v", ix, @@ -660,14 +614,11 @@ func TestCreateChatCompletionStreamStreamOptions(t *testing.T) { ) } } - _, streamErr := stream.Recv() if !errors.Is(streamErr, io.EOF) { t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) } - _, streamErr = stream.Recv() - a.ErrorIs( streamErr, io.EOF, @@ -720,7 +671,6 @@ func compareChatResponses( } return true } - func compareChatStreamResponseChoices( c1, c2 groq.ChatCompletionStreamChoice, ) bool { @@ -742,7 +692,6 @@ func TestAudio(t *testing.T) { defer teardown() server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) - testcases := []struct { name string createFn func(context.Context, groq.AudioRequest) (groq.AudioResponse, error) @@ -756,44 +705,37 @@ func TestAudio(t *testing.T) { client.CreateTranslation, }, } - ctx := context.Background() - dir, cleanup := test.CreateTestDirectory(t) defer cleanup() - a := assert.New(t) for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { path := filepath.Join(dir, "fake.mp3") test.CreateTestFile(t, path) - req := groq.AudioRequest{ FilePath: path, - Model: groq.ModelWhisperLargeV3, + Model: models.ModelWhisperLargeV3, } _, err := tc.createFn(ctx, req) a.NoError(err, "audio API error") }) - t.Run(tc.name+" (with reader)", func(t *testing.T) { req := groq.AudioRequest{ FilePath: "fake.webm", Reader: bytes.NewBuffer([]byte(`some webm binary data`)), - Model: groq.ModelWhisperLargeV3, + Model: models.ModelWhisperLargeV3, } _, err := tc.createFn(ctx, req) a.NoError(err, "audio API error") }) } } - func TestAudioWithOptionalArgs(t *testing.T) { client, server, teardown := setupGroqTestServer() defer teardown() server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) - testcases := []struct { name string createFn func(context.Context, groq.AudioRequest) (groq.AudioResponse, error) @@ -807,25 +749,21 @@ func TestAudioWithOptionalArgs(t *testing.T) { client.CreateTranslation, }, } - ctx := context.Background() - dir, cleanup := test.CreateTestDirectory(t) defer cleanup() - for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { a := assert.New(t) path := filepath.Join(dir, "fake.mp3") test.CreateTestFile(t, path) - req := groq.AudioRequest{ FilePath: path, - Model: groq.ModelWhisperLargeV3, + Model: models.ModelWhisperLargeV3, Prompt: "用简体中文", Temperature: 0.5, Language: "zh", - Format: groq.AudioResponseFormatSRT, + Format: groq.FormatSRT, } _, err := tc.createFn(ctx, req) a.NoError(err, "audio API error") @@ -836,28 +774,23 @@ func TestAudioWithOptionalArgs(t *testing.T) { // handleAudioEndpoint Handles the completion endpoint by the test server. func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) { var err error - // audio endpoints only accept POST requests if r.Method != "POST" { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) } - mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) if err != nil { http.Error(w, "failed to parse media type", http.StatusBadRequest) return } - if !strings.HasPrefix(mediaType, "multipart") { http.Error(w, "request is not multipart", http.StatusBadRequest) } - boundary, ok := params["boundary"] if !ok { http.Error(w, "no boundary in params", http.StatusBadRequest) return } - fileData := &bytes.Buffer{} mr := multipart.NewReader(r.Body, boundary) part, err := mr.NextPart() @@ -869,13 +802,11 @@ func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) { http.Error(w, "failed to copy file", http.StatusInternalServerError) return } - if len(fileData.Bytes()) == 0 { w.WriteHeader(http.StatusInternalServerError) http.Error(w, "received empty file data", http.StatusBadRequest) return } - if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil { http.Error(w, "failed to write body", http.StatusInternalServerError) return