diff --git a/audio.go b/audio.go index ca135a3..ffba280 100644 --- a/audio.go +++ b/audio.go @@ -71,15 +71,17 @@ func (r *audioTextResponse) SetHeader(header http.Header) { r.header = header } -// ToAudioResponse converts the audio text response to an audio response. -func (r *audioTextResponse) ToAudioResponse() AudioResponse { +// toAudioResponse converts the audio text response to an audio response. +func (r *audioTextResponse) toAudioResponse() AudioResponse { return AudioResponse{ Text: r.Text, Header: r.header, } } -// CreateTranscription — API call to create a transcription. Returns transcribed text. +// CreateTranscription calls the transcriptions endpoint with the given request. +// +// Returns transcribed text in the response_format specified in the request. func (c *Client) CreateTranscription( ctx context.Context, request AudioRequest, @@ -87,7 +89,9 @@ func (c *Client) CreateTranscription( return c.callAudioAPI(ctx, request, transcriptionsSuffix) } -// CreateTranslation — API call to translate audio into English. +// CreateTranslation calls the translations endpoint with the given request. +// +// Returns the translated text in the response_format specified in the request. func (c *Client) CreateTranslation( ctx context.Context, request AudioRequest, @@ -95,7 +99,6 @@ func (c *Client) CreateTranslation( return c.callAudioAPI(ctx, request, translationsSuffix) } -// callAudioAPI — API call to an audio endpoint. func (c *Client) callAudioAPI( ctx context.Context, request AudioRequest, @@ -118,12 +121,12 @@ func (c *Client) callAudioAPI( return AudioResponse{}, err } - if request.HasJSONResponse() { + if request.hasJSONResponse() { err = c.sendRequest(req, &response) } else { var textResponse audioTextResponse err = c.sendRequest(req, &textResponse) - response = textResponse.ToAudioResponse() + response = textResponse.toAudioResponse() } if err != nil { return AudioResponse{}, err @@ -131,25 +134,22 @@ func (c *Client) callAudioAPI( return } -// HasJSONResponse returns true if the response format is JSON. -func (r AudioRequest) HasJSONResponse() bool { +func (r AudioRequest) hasJSONResponse() bool { return r.Format == "" || r.Format == AudioResponseFormatJSON || r.Format == AudioResponseFormatVerboseJSON } // audioMultipartForm creates a form with audio file contents and the name of the model to use for // audio processing. -func audioMultipartForm(request AudioRequest, b FormBuilder) error { +func audioMultipartForm(request AudioRequest, b formBuilder) error { err := createFileField(request, b) if err != nil { return err } - err = b.WriteField("model", string(request.Model)) if err != nil { return fmt.Errorf("writing model name: %w", err) } - // Create a form field for the prompt (if provided) if request.Prompt != "" { err = b.WriteField("prompt", request.Prompt) @@ -157,7 +157,6 @@ func audioMultipartForm(request AudioRequest, b FormBuilder) error { return fmt.Errorf("writing prompt: %w", err) } } - // Create a form field for the format (if provided) if request.Format != "" { err = b.WriteField("response_format", string(request.Format)) @@ -165,7 +164,6 @@ func audioMultipartForm(request AudioRequest, b FormBuilder) error { return fmt.Errorf("writing format: %w", err) } } - // Create a form field for the temperature (if provided) if request.Temperature != 0 { err = b.WriteField( @@ -176,7 +174,6 @@ func audioMultipartForm(request AudioRequest, b FormBuilder) error { return fmt.Errorf("writing temperature: %w", err) } } - // Create a form field for the language (if provided) if request.Language != "" { err = b.WriteField("language", request.Language) @@ -184,7 +181,6 @@ func audioMultipartForm(request AudioRequest, b FormBuilder) error { return fmt.Errorf("writing language: %w", err) } } - if len(request.TimestampGranularities) > 0 { for _, tg := range request.TimestampGranularities { err = b.WriteField("timestamp_granularities[]", string(tg)) @@ -193,14 +189,13 @@ func audioMultipartForm(request AudioRequest, b FormBuilder) error { } } } - // Close the multipart writer return b.Close() } // createFileField creates the "file" form field from either an existing file or by using the reader. func createFileField( request AudioRequest, - b FormBuilder, + b formBuilder, ) (err error) { if request.Reader != nil { err := b.CreateFormFileReader("file", request.Reader, request.FilePath) diff --git a/builders.go b/builders.go index f5ab66c..42f3b36 100644 --- a/builders.go +++ b/builders.go @@ -12,8 +12,8 @@ import ( "path" ) -// FormBuilder is an interface for building a form. -type FormBuilder interface { +// formBuilder is an interface for building a form. +type formBuilder interface { CreateFormFile(fieldname string, file *os.File) error CreateFormFileReader(fieldname string, r io.Reader, filename string) error WriteField(fieldname, value string) error @@ -21,28 +21,26 @@ type FormBuilder interface { FormDataContentType() string } -// DefaultFormBuilder is a default implementation of FormBuilder. -type DefaultFormBuilder struct { +// defaultFormBuilder is a default implementation of FormBuilder. +type defaultFormBuilder struct { writer *multipart.Writer } -// NewFormBuilder creates a new DefaultFormBuilder. -func NewFormBuilder(body io.Writer) *DefaultFormBuilder { - return &DefaultFormBuilder{ +// newFormBuilder creates a new DefaultFormBuilder. +func newFormBuilder(body io.Writer) *defaultFormBuilder { + return &defaultFormBuilder{ writer: multipart.NewWriter(body), } } -// CreateFormFile creates a form file. -func (fb *DefaultFormBuilder) CreateFormFile( +func (fb *defaultFormBuilder) CreateFormFile( fieldname string, file *os.File, ) error { return fb.createFormFile(fieldname, file, file.Name()) } -// CreateFormFileReader creates a form file from a reader. -func (fb *DefaultFormBuilder) CreateFormFileReader( +func (fb *defaultFormBuilder) CreateFormFileReader( fieldname string, r io.Reader, filename string, @@ -50,8 +48,7 @@ func (fb *DefaultFormBuilder) CreateFormFileReader( return fb.createFormFile(fieldname, r, path.Base(filename)) } -// createFormFile creates a form file. -func (fb *DefaultFormBuilder) createFormFile( +func (fb *defaultFormBuilder) createFormFile( fieldname string, r io.Reader, filename string, @@ -71,23 +68,19 @@ func (fb *DefaultFormBuilder) createFormFile( return nil } -// WriteField writes a field to the form. -func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error { +func (fb *defaultFormBuilder) WriteField(fieldname, value string) error { return fb.writer.WriteField(fieldname, value) } -// Close closes the form. -func (fb *DefaultFormBuilder) Close() error { +func (fb *defaultFormBuilder) Close() error { return fb.writer.Close() } -// FormDataContentType returns the content type of the form. -func (fb *DefaultFormBuilder) FormDataContentType() string { +func (fb *defaultFormBuilder) FormDataContentType() string { return fb.writer.FormDataContentType() } -// RequestBuilder is an interface that defines the Build method. -type RequestBuilder interface { +type requestBuilder interface { Build( ctx context.Context, method, url string, @@ -96,33 +89,13 @@ type RequestBuilder interface { ) (*http.Request, error) } -// HTTPRequestBuilder is a struct that implements the RequestBuilder interface. -type HTTPRequestBuilder struct { - marshaller Marshaller -} - -// Marshaller is an interface that defines the Marshal method. -type Marshaller interface { - Marshal(v any) ([]byte, error) -} - -// JSONMarshaller is a struct that implements the Marshaller interface. -type JSONMarshaller struct{} +type httpRequestBuilder struct{} -// Marshal marshals the given value to JSON. -func (j *JSONMarshaller) Marshal(v any) ([]byte, error) { - return json.Marshal(v) -} - -// NewRequestBuilder returns a new HTTPRequestBuilder. -func NewRequestBuilder() *HTTPRequestBuilder { - return &HTTPRequestBuilder{ - marshaller: &JSONMarshaller{}, - } +func newRequestBuilder() *httpRequestBuilder { + return &httpRequestBuilder{} } -// Build builds a new request. -func (b *HTTPRequestBuilder) Build( +func (b *httpRequestBuilder) Build( ctx context.Context, method string, url string, @@ -135,7 +108,7 @@ func (b *HTTPRequestBuilder) Build( bodyReader = v } else { var reqBytes []byte - reqBytes, err = b.marshaller.Marshal(body) + reqBytes, err = json.Marshal(body) if err != nil { return } diff --git a/builders_test.go b/builders_test.go index c53cce0..f5f819c 100644 --- a/builders_test.go +++ b/builders_test.go @@ -6,6 +6,7 @@ package groq // testing private field import ( "bytes" "context" + "encoding/json" "errors" "io" "net/http" @@ -79,7 +80,7 @@ func TestFormBuilderWithFailingWriter(t *testing.T) { defer file.Close() defer os.Remove(file.Name()) - builder := NewFormBuilder(&failingWriter{}) + builder := newFormBuilder(&failingWriter{}) err = builder.CreateFormFile("file", file) a.ErrorIs( err, @@ -102,7 +103,7 @@ func TestFormBuilderWithClosedFile(t *testing.T) { defer os.Remove(file.Name()) body := &bytes.Buffer{} - builder := NewFormBuilder(body) + builder := newFormBuilder(body) err = builder.CreateFormFile("file", file) a.Error(err, "formbuilder should return error if file is closed") a.ErrorIs( @@ -115,13 +116,13 @@ func TestFormBuilderWithClosedFile(t *testing.T) { // TestRequestBuilderReturnsRequest tests the request builder returns a // request. func TestRequestBuilderReturnsRequest(t *testing.T) { - b := NewRequestBuilder() + b := newRequestBuilder() var ( ctx = context.Background() method = http.MethodPost url = "/foo" request = map[string]string{"foo": "bar"} - reqBytes, _ = b.marshaller.Marshal(request) + reqBytes, _ = json.Marshal(request) want, _ = http.NewRequestWithContext( ctx, method, @@ -146,7 +147,7 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) { url = "/foo" want, _ = http.NewRequestWithContext(ctx, method, url, nil) ) - b := NewRequestBuilder() + b := newRequestBuilder() got, _ := b.Build(ctx, method, url, nil, nil) if !reflect.DeepEqual(got, want) { t.Errorf("Build() got = %v, want %v", got, want) diff --git a/client.go b/client.go index ea764f4..9e8fa56 100644 --- a/client.go +++ b/client.go @@ -39,9 +39,9 @@ type Client struct { baseURL string // Base URL for the client. client *http.Client // Client is the HTTP client to use EmptyMessagesLimit uint // EmptyMessagesLimit is the limit for the empty messages. - requestBuilder RequestBuilder - requestFormBuilder FormBuilder - createFormBuilder func(body io.Writer) FormBuilder + requestBuilder requestBuilder + requestFormBuilder formBuilder + createFormBuilder func(body io.Writer) formBuilder logger zerolog.Logger // Logger is the logger for the client. } @@ -57,10 +57,10 @@ func NewClient(groqAPIKey string, opts ...Opts) (*Client, error) { Logger(), baseURL: groqAPIURLv1, EmptyMessagesLimit: 10, - createFormBuilder: func(body io.Writer) FormBuilder { - return NewFormBuilder(body) + createFormBuilder: func(body io.Writer) formBuilder { + return newFormBuilder(body) }, - requestBuilder: NewRequestBuilder(), + requestBuilder: newRequestBuilder(), } for _, opt := range opts { opt(c)