diff --git a/examples/pgvector101/hello-world.go b/examples/pgvector101/hello-world.go deleted file mode 100644 index 37e86ef..0000000 --- a/examples/pgvector101/hello-world.go +++ /dev/null @@ -1,133 +0,0 @@ -// adopted from https://github.com/pgvector/pgvector-go/blob/master/examples/openai/main.go -package main - -import ( - "context" - "encoding/json" - "fmt" - "log" - "net/http" - "time" - - "github.com/edgeflare/pgo/pkg/util" - "github.com/edgeflare/pgo/pkg/util/httpclient" - "github.com/jackc/pgx/v5" - "github.com/pgvector/pgvector-go" - pgxvector "github.com/pgvector/pgvector-go/pgx" -) - -var ( - apiUrl = "http://127.0.0.1:11434/v1/embeddings" // ollama - apikey = util.GetEnvOrDefault("API_KEY", "") - modelId = "llama3.2:latest" - dimensions = 3072 // for llama3.2. 1536 for openai -) - -func main() { - ctx := context.Background() - - conn, err := pgx.Connect(ctx, "postgres://postgres:secret@localhost:5432/postgres") - if err != nil { - panic(err) - } - defer conn.Close(ctx) - - _, err = conn.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS vector") - if err != nil { - panic(err) - } - - err = pgxvector.RegisterTypes(ctx, conn) - if err != nil { - panic(err) - } - - _, err = conn.Exec(ctx, "DROP TABLE IF EXISTS documents") - if err != nil { - panic(err) - } - - _, err = conn.Exec(ctx, fmt.Sprintf("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(%v))", dimensions)) - if err != nil { - panic(err) - } - - input := []string{ - "The dog is barking", - "The cat is purring", - "The bear is growling", - } - embeddings, err := FetchEmbeddings(input, apikey) - if err != nil { - panic(err) - } - - for i, content := range input { - _, err := conn.Exec(ctx, "INSERT INTO documents (content, embedding) VALUES ($1, $2)", content, pgvector.NewVector(embeddings[i])) - if err != nil { - panic(err) - } - } - - documentId := 1 - rows, err := conn.Query(ctx, "SELECT id, content FROM documents WHERE id != $1 ORDER BY embedding <=> (SELECT embedding FROM documents WHERE id = $1) LIMIT 5", documentId) - if err != nil { - panic(err) - } - defer rows.Close() - - for rows.Next() { - var id int64 - var content string - err = rows.Scan(&id, &content) - if err != nil { - panic(err) - } - fmt.Println(id, content) - } - - if rows.Err() != nil { - panic(rows.Err()) - } -} - -type apiRequest struct { - Input []string `json:"input"` - Model string `json:"model"` -} - -type apiResponse struct { - Data []struct { - Embedding []float32 `json:"embedding"` - } `json:"data"` -} - -func FetchEmbeddings(input []string, apiKey string) ([][]float32, error) { - data := &apiRequest{ - Input: input, - Model: modelId, - } - - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - defer cancel() - - headers := map[string][]string{ - "Authorization": {fmt.Sprintf("Bearer %s", apiKey)}, - } - - body, err := httpclient.Request(ctx, http.MethodPost, apiUrl, data, headers) - if err != nil { - log.Fatal(err) - } - - var response apiResponse - if err := json.Unmarshal(body, &response); err != nil { - return nil, err - } - - var embeddings [][]float32 - for _, item := range response.Data { - embeddings = append(embeddings, item.Embedding) - } - return embeddings, nil -} diff --git a/examples/rag101/README.md b/examples/rag101/README.md deleted file mode 100644 index 32492f2..0000000 --- a/examples/rag101/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Retrieval-Augmented Generation (RAG) - -![RAG](../../docs/rag.svg) \ No newline at end of file diff --git a/examples/rag101/main.go b/examples/rag101/main.go index 76795a7..26bca5d 100644 --- a/examples/rag101/main.go +++ b/examples/rag101/main.go @@ -2,100 +2,49 @@ package main import ( "context" - "encoding/json" "fmt" "log" "github.com/edgeflare/pgo/pkg/rag" + "github.com/edgeflare/pgo/pkg/util" "github.com/jackc/pgx/v5" ) func main() { - // insertEmbeddings() - generateWithRetrieval() -} - -func generateWithRetrieval() { - // Database connection parameters - connConfig, err := pgx.ParseConfig("postgres://postgres:secret@localhost:5431/testdb") - if err != nil { - log.Fatalf("unable to parse connection config: %v", err) - } - ctx := context.Background() - // Create a new connection to the database - conn, err := pgx.ConnectConfig(ctx, connConfig) + conn, err := pgx.Connect(ctx, util.GetEnvOrDefault("DATABASE_URL", "postgres://postgres:secret@localhost:5432/postgres")) if err != nil { - log.Fatalf("unable to connect to database: %v", err) + log.Fatalf("Failed to connect to database: %v", err) } defer conn.Close(ctx) - // Create a new RAG client with default configuration - client := rag.NewClient(conn, rag.DefaultConfig()) + // Create a new RAG client + client, err := rag.NewClient(conn, rag.DefaultConfig()) + client.Config.TableName = "example_table" - response, err := client.Generate(ctx, "what is the cat doing?") if err != nil { - log.Fatalln(err) + log.Fatalf("Failed to create RAG client: %v", err) } - fmt.Printf("%s", response) -} - -func insertEmbeddings() { - // Database connection parameters - connConfig, err := pgx.ParseConfig("postgres://postgres:secret@localhost:5431/testdb") + err = client.CreateEmbedding(ctx, "") if err != nil { - log.Fatalf("unable to parse connection config: %v", err) + log.Fatalf("Failed to create embeddings: %v", err) } - ctx := context.Background() + fmt.Println("Embeddings have been successfully created.") - // Create a new connection to the database - conn, err := pgx.ConnectConfig(ctx, connConfig) - if err != nil { - log.Fatalf("unable to connect to database: %v", err) - } - defer conn.Close(ctx) + // retrieval example + input := "example input text" + limit := 2 - // Create a new RAG client with default configuration - client := rag.NewClient(conn, rag.DefaultConfig()) - - // Sample input texts - input := []string{ - "The dog is barking", - "The cat is purring", - "The bear is growling", - // Add more examples here - } - - // Fetch embeddings from the API - embeddingsResponse, err := client.FetchEmbeddings(ctx, input) - if err != nil { - log.Fatalf("failed to fetch embeddings: %v", err) - } - - // Create metadata for all embeddings - metadata := json.RawMessage(`{ - "source": "sample", - "batch": "animal_sounds", - "timestamp": "2024-10-19" - }`) - - // Create tags for all embeddings - tags := []string{"animal", "sound", "example"} - - // Create vector embeddings using the helper function - embeddings, err := rag.ToVectorEmbedding(input, embeddingsResponse, tags, metadata) + results, err := client.Retrieve(ctx, input, limit) if err != nil { - log.Fatalf("failed to create embeddings: %v", err) + log.Fatalf("Failed to retrieve content: %v", err) } - // Insert embeddings with progress tracking - err = client.InsertEmbeddings(ctx, embeddings, 1) - if err != nil { - log.Fatalf("failed to insert embeddings: %v", err) + // Print the retrieved results + for _, r := range results { + fmt.Printf("ID: %v\nContent: %s\nEmbedding: %v\n", r.PK, r.Content, r.Embedding.Slice()[0]) } - - fmt.Println("\nAll embeddings inserted successfully!") } diff --git a/pkg/rag/client.go b/pkg/rag/client.go index 466ecf7..8173b7b 100644 --- a/pkg/rag/client.go +++ b/pkg/rag/client.go @@ -1,40 +1,152 @@ package rag import ( + "context" + "fmt" + "github.com/edgeflare/pgo/pkg/util" "github.com/jackc/pgx/v5" + pgxvector "github.com/pgvector/pgvector-go/pgx" + "go.uber.org/zap" ) // Config holds the configuration for the RAG package type Config struct { - TableName string - Dimensions int - ModelId string - ApiUrl string - ApiKey string + TableName string + TablePrimaryKeyCol string + Dimensions int + ModelId string + ApiUrl string + ApiKey string + EmbeddingsPath string + GeneratePath string + BatchSize int } // DefaultConfig returns a Config with default values func DefaultConfig() Config { return Config{ - ModelId: "llama3.2:3b", - ApiKey: util.GetEnvOrDefault("RAG_API_KEY", ""), - TableName: "embeddings", - Dimensions: 3072, // 1536 for OpenAI - ApiUrl: util.GetEnvOrDefault("RAG_API_URL", "http://127.0.0.1:11434"), + ModelId: "llama3.2:3b", + ApiKey: util.GetEnvOrDefault("LLM_API_KEY", ""), + TableName: "embeddings", + TablePrimaryKeyCol: "id", + Dimensions: 3072, // for llama3.2:3b, 1536 for OpenAI + ApiUrl: util.GetEnvOrDefault("LLM_API_URL", "http://127.0.0.1:11434"), + EmbeddingsPath: "/v1/embeddings", + GeneratePath: "/api/generate", + BatchSize: 100, } } // Client handles the RAG operations type Client struct { conn *pgx.Conn - config Config + Config Config + logger *zap.Logger } // NewClient creates a new RAG client -func NewClient(conn *pgx.Conn, config Config) *Client { - return &Client{ +func NewClient(conn *pgx.Conn, config Config, loggers ...*zap.Logger) (*Client, error) { + var logger *zap.Logger + if len(loggers) > 0 && loggers[0] != nil { + logger = loggers[0] + } else { + var err error + logger, err = zap.NewDevelopment() + if err != nil { + return nil, fmt.Errorf("failed to create logger: %w", err) + } + } + + client := &Client{ conn: conn, - config: config, + Config: config, + logger: logger, + } + + if err := client.initialize(); err != nil { + return nil, fmt.Errorf("failed to initialize client: %w", err) } + + return client, nil +} + +func (c *Client) initialize() error { + ctx := context.Background() + + _, err := c.conn.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS vector") + if err != nil { + return fmt.Errorf("failed to create vector extension: %w", err) + } + + err = pgxvector.RegisterTypes(ctx, c.conn) + if err != nil { + return fmt.Errorf("failed to register vector types: %w", err) + } + + return nil +} + +// ============================================ With retry ============================================ +/* +func (c *Client) ExecuteWithRetry(ctx context.Context, operation func(context.Context) error) error { + var err error + for attempt := 0; attempt < c.config.RetryAttempts; attempt++ { + err = operation(ctx) + if err == nil { + return nil + } + if !isRetryableError(err) { + return err + } + time.Sleep(c.config.RetryBackoff * time.Duration(attempt+1)) + } + return fmt.Errorf("operation failed after %d attempts: %w", c.config.RetryAttempts, err) +} + +func (c *Client) QueryWithMetrics(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error) { + if c.config.TracingEnabled { + var span trace.Span + ctx, span = otel.Tracer("rag").Start(ctx, "QueryWithMetrics") + defer span.End() + } + + start := time.Now() + rows, err := c.conn.Query(ctx, query, args...) + duration := time.Since(start) + + if c.config.MetricsEnabled { + // Record metrics (e.g., query duration, success/failure) + // This could integrate with Prometheus, StatsD, etc. + c.logger.Info("query_duration", zap.Duration("duration", duration)) + } + + return rows, err +} + +// isRetryableError determines if the given error is retryable +func isRetryableError(err error) bool { + // Check for specific error types that are retryable + if pgErr, ok := err.(*pgconn.PgError); ok { + // PostgreSQL error codes that are typically retryable + retryableCodes := map[string]bool{ + "40001": true, // serialization_failure + "40P01": true, // deadlock_detected + "53300": true, // too_many_connections + "53400": true, // configuration_limit_exceeded + "08006": true, // connection_failure + "08001": true, // sqlclient_unable_to_establish_sqlconnection + "08004": true, // sqlserver_rejected_establishment_of_sqlconnection + } + return retryableCodes[pgErr.Code] + } + + // Check for network-related errors + if netErr, ok := err.(net.Error); ok { + return netErr.Timeout() + } + + // By default, consider the error as non-retryable + return false } +*/ diff --git a/pkg/rag/client_test.go b/pkg/rag/client_test.go new file mode 100644 index 0000000..c44f992 --- /dev/null +++ b/pkg/rag/client_test.go @@ -0,0 +1,27 @@ +package rag + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/v5" +) + +func TestNewClient(t *testing.T) { + ctx := context.Background() + conn, err := pgx.Connect(ctx, os.Getenv("TEST_DATABASE_URL")) + if err != nil { + t.Fatalf("failed to connect to database: %v", err) + } + + client, err := NewClient(conn, Config{ + TableName: "test_table", + Dimensions: 3072, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + defer client.conn.Close(ctx) +} diff --git a/pkg/rag/doc.go b/pkg/rag/doc.go new file mode 100644 index 0000000..c79d264 --- /dev/null +++ b/pkg/rag/doc.go @@ -0,0 +1,7 @@ +/* +Package rag provides functions to integrate Retrieval Augmented Generation (RAG) +capabilities in PostgreSQL tables using the pgvector extension. It facilitates +operations for adding embeddings on tables from popular language model APIs such +as OpenAI, Ollama, and LMStudio. +*/ +package rag diff --git a/pkg/rag/embedding.go b/pkg/rag/embedding.go index 53efc79..5e5ed01 100644 --- a/pkg/rag/embedding.go +++ b/pkg/rag/embedding.go @@ -9,40 +9,51 @@ import ( "github.com/edgeflare/pgo/pkg/util/httpclient" ) -// EmbeddingRequest is passed as body in FetchEmbeddings +// EmbeddingRequest is the request body for the FetchEmbedding function type EmbeddingRequest struct { Model string `json:"model"` Input []string `json:"input"` } -type EmbeddingsResponse struct { +// EmbeddingResponse is the response body for the FetchEmbedding function +// https://platform.openai.com/docs/api-reference/embeddings/create +// https://github.com/ollama/ollama/blob/main/docs/api.md#embeddings +type EmbeddingResponse struct { Data []struct { Embedding []float32 `json:"embedding"` } `json:"data"` - Model string `json:"model"` - Usage json.RawMessage `json:"usage"` } -// Embeddings fetches embeddings from the API -func (c *Client) FetchEmbeddings(ctx context.Context, input []string) (EmbeddingsResponse, error) { +// FetchEmbedding fetches embeddings from the LLM API +func (c *Client) FetchEmbedding(ctx context.Context, input []string) ([][]float32, error) { + // check if input is empty + if len(input) == 0 { + return [][]float32{}, fmt.Errorf("input is empty") + } + data := &EmbeddingRequest{ Input: input, - Model: c.config.ModelId, + Model: c.Config.ModelId, } headers := map[string][]string{ - "Authorization": {fmt.Sprintf("Bearer %s", c.config.ApiKey)}, + "Authorization": {fmt.Sprintf("Bearer %s", c.Config.ApiKey)}, } - body, err := httpclient.Request(ctx, http.MethodPost, fmt.Sprintf("%s/v1/embeddings", c.config.ApiUrl), data, headers) + body, err := httpclient.Request(ctx, http.MethodPost, fmt.Sprintf("%s%s", c.Config.ApiUrl, c.Config.EmbeddingsPath), data, headers) if err != nil { - return EmbeddingsResponse{}, fmt.Errorf("failed to fetch embeddings: %w", err) + return [][]float32{}, fmt.Errorf("failed to fetch embeddings: %w", err) } - var response EmbeddingsResponse + var response EmbeddingResponse if err := json.Unmarshal(body, &response); err != nil { - return EmbeddingsResponse{}, fmt.Errorf("failed to unmarshal response: %w", err) + return [][]float32{}, fmt.Errorf("failed to unmarshal response: %w", err) + } + + embeddings := make([][]float32, len(response.Data)) + for i, d := range response.Data { + embeddings[i] = d.Embedding } - return response, nil + return embeddings, nil } diff --git a/pkg/rag/embedding_test.go b/pkg/rag/embedding_test.go new file mode 100644 index 0000000..d164612 --- /dev/null +++ b/pkg/rag/embedding_test.go @@ -0,0 +1,52 @@ +package rag + +import ( + "context" + "testing" +) + +func TestFetchEmbedding(t *testing.T) { + // Create a test client with default config + config := DefaultConfig() + + ctx := context.Background() + conn, err := setupTestDatabase(t) + if err != nil { + t.Fatalf("Failed to set up test database: %v", err) + } + defer conn.Close(ctx) + + c, err := NewClient(conn, config) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + // Test input + input := []string{"Hello", "World"} + + // Call FetchEmbedding + embeddings, err := c.FetchEmbedding(context.Background(), input) + + // Check for errors + if err != nil { + t.Fatalf("FetchEmbedding returned an error: %v", err) + } + + // Check the number of embeddings + if len(embeddings) != len(input) { + t.Errorf("Expected %d embeddings, got %d", len(input), len(embeddings)) + } + + // Check that embeddings are not empty + for i, embedding := range embeddings { + if len(embedding) == 0 { + t.Errorf("Embedding %d is empty", i) + } + } + + // Print the first few values of each embedding for manual inspection + t.Logf("Embeddings (first 5 values):") + for i, embedding := range embeddings { + t.Logf(" Input '%s': %v", input[i], embedding[:5]) + } +} diff --git a/pkg/rag/generate.go b/pkg/rag/generate.go index 95ac820..27c5174 100644 --- a/pkg/rag/generate.go +++ b/pkg/rag/generate.go @@ -63,46 +63,19 @@ type GenerateRequest struct { func (c *Client) Generate(ctx context.Context, prompt string) ([]byte, error) { data := GenerateRequest{ Prompt: prompt, - Model: c.config.ModelId, + Model: c.Config.ModelId, Stream: false, Format: "json", } headers := map[string][]string{ - "Authorization": {fmt.Sprintf("Bearer %s", c.config.ApiKey)}, + "Authorization": {fmt.Sprintf("Bearer %s", c.Config.ApiKey)}, } - body, err := httpclient.Request(ctx, http.MethodPost, fmt.Sprintf("%s/api/generate", c.config.ApiUrl), data, headers, time.Minute*1) + body, err := httpclient.Request(ctx, http.MethodPost, fmt.Sprintf("%s%s", c.Config.ApiUrl, c.Config.GeneratePath), data, headers, time.Minute*1) if err != nil { return nil, fmt.Errorf("API request failed: %w", err) } return body, nil } - -// GenerateWithRetrieval sends a generation request to the API after retrieving relevant information -// based on the provided prompt. -func (c *Client) GenerateWithRetrieval(ctx context.Context, prompt string, limit int) ([]byte, error) { - // Retrieve embeddings for user promt - retrievedEmbeddings, err := c.Retrieve(ctx, prompt, limit) - if err != nil { - return nil, fmt.Errorf("failed to retrieve embeddings: %w", err) - } - - // Construct the new prompt - var contextBuilder string - for _, embedding := range retrievedEmbeddings { - contextBuilder += fmt.Sprintf("Content: %s\n", embedding.Content) - } - - // Combine the context with the original prompt - newPrompt := fmt.Sprintf("Here are some relevant pieces of information:\n%s\n\nUsing this context, %s", contextBuilder, prompt) - - // Send the generation request - response, err := c.Generate(ctx, newPrompt) - if err != nil { - return nil, fmt.Errorf("failed to generate response: %w", err) - } - - return response, nil -} diff --git a/pkg/rag/pgvector.go b/pkg/rag/pgvector.go new file mode 100644 index 0000000..6ce84df --- /dev/null +++ b/pkg/rag/pgvector.go @@ -0,0 +1,330 @@ +package rag + +import ( + "context" + "fmt" + "strings" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/pgvector/pgvector-go" + "go.uber.org/zap" +) + +// Embedding represents a row with columns id, content, and embedding. +type Embedding struct { + // PK is the primary key column name + PK interface{} + // Content is what vector embeddings are created from + Content string + // Embedding is the vector embedding for the content + Embedding pgvector.Vector +} + +// CreateEmbedding inserts embeddings into the table based on the contentSelectQuery +// If no contentSelectQuery is supplied, it assumes the content column of the table is already populated, and use content column directly +// If an empty contentSelectQuery is provided, it queries the table rows and constructs the content column as col1_name:col1_value,col2_name:col2_value,... +// For a non-empty contentSelectQuery, it expects the query to return two columns: primary key and content. +func (c *Client) CreateEmbedding(ctx context.Context, contentSelectQuery ...string) error { + var query string + + // Ensure the table exists and has the required columns + if err := c.ensureTableConfig(ctx); err != nil { + return err + } + + // Check the query conditions + if len(contentSelectQuery) == 0 { + // Case 1: No query supplied, assume content is already populated + query = fmt.Sprintf("SELECT %s, content FROM %s", c.Config.TablePrimaryKeyCol, c.Config.TableName) + } else if contentSelectQuery[0] == "" { + // Case 2: Empty string query, construct content column + schema, tableName := splitSchemaTableName(c.Config.TableName) + c.logger.Info("Query is empty, using table columns as content", zap.String("table", c.Config.TableName)) + columns, err := c.queryAndFilterColumnNames(ctx, schema, tableName, []string{"embedding", "content"}) + if err != nil { + return err + } + query = fmt.Sprintf("SELECT %s FROM %s", strings.Join(columns, ", "), c.Config.TableName) + } else { + // Case 3: Non-empty query + query = contentSelectQuery[0] + } + + // queryAndProcessEmbeddingContents to process contents and ids + rows, err := c.queryAndProcessEmbeddingContents(ctx, query) + if err != nil { + return err + } + c.logger.Debug("contents", zap.Any("contents", rows)) + + contents := []string{} + for _, embedding := range rows { + contents = append(contents, embedding.Content) + } + + embeddings, err := c.FetchEmbedding(ctx, contents) + if err != nil { + return fmt.Errorf("failed to fetch embeddings: %w", err) + } + c.logger.Debug("embeddings", zap.Int("embeddings", len(embeddings))) + + // Ensure the lengths match + if len(contents) != len(embeddings) { + return fmt.Errorf("mismatch between contents and embeddings length: %d vs %d", len(contents), len(embeddings)) + } + + // update the embeddings into the database + for i := range contents { // Use index to access both contents and ids + embedding := embeddings[i] + pk := rows[i].PK + + c.logger.Debug("embedding", zap.Any("id", pk), zap.String("content", contents[i]), zap.Any("embedding[0]", embedding[0])) + + // Call the helper function to update the embedding + if err := c.updateEmbedding(ctx, embedding, contents[i], pk); err != nil { + return err + } + } + + return nil +} + +// Retrieve retrieves the most similar rows to the input +func (c *Client) Retrieve(ctx context.Context, input string, limit int) ([]Embedding, error) { + // Step 1: Fetch the embedding for the input text + embedding, err := c.FetchEmbedding(ctx, []string{input}) + if err != nil { + return nil, fmt.Errorf("failed to fetch embedding for input: %w", err) + } + + // Step 2: Query the database using the fetched embedding + queryStr := fmt.Sprintf( + "SELECT id, content, embedding FROM %s ORDER BY embedding <=> $1 LIMIT $2", + c.Config.TableName, + ) + + // Execute the query with the embedding + rows, err := c.conn.Query(ctx, queryStr, pgvector.NewVector(embedding[0]), limit) + if err != nil { + return nil, fmt.Errorf("failed to execute query: %w", err) + } + defer rows.Close() + + // Collect the results + var results []Embedding + for rows.Next() { + var embedding Embedding + if err := rows.Scan(&embedding.PK, &embedding.Content, &embedding.Embedding); err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + results = append(results, embedding) + } + + return results, nil +} + +// ensureTableConfig ensures the table exists and has the required columns +func (c *Client) ensureTableConfig(ctx context.Context) error { + tableExists, err := c.tableExists(ctx, c.Config.TableName) + if err != nil { + return fmt.Errorf("failed to check if table exists: %w", err) + } + + if !tableExists { + err = c.createTable(ctx, c.Config.TableName) + if err != nil { + // Check if the error is because the table already exists + if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "42P07" { + // Table already exists, so we can proceed + tableExists = true + } else { + return fmt.Errorf("failed to create table: %w", err) + } + } + } + + if tableExists { + err = c.ensureRequiredColumnsExist(ctx, c.Config.TableName) + if err != nil { + return fmt.Errorf("failed to ensure embedding column: %w", err) + } + } + + return nil +} + +// ensureRequiredColumnsExist ensures the table has the required columns +func (c *Client) ensureRequiredColumnsExist(ctx context.Context, tableName string) error { + // Check for 'embedding' column + var embeddingColumnExists bool + err := c.conn.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name=$1 AND column_name='embedding')", tableName).Scan(&embeddingColumnExists) + if err != nil { + return fmt.Errorf("failed to check for embedding column: %w", err) + } + + // Check for 'content' column + var contentColumnExists bool + err = c.conn.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name=$1 AND column_name='content')", tableName).Scan(&contentColumnExists) + if err != nil { + return fmt.Errorf("failed to check for content column: %w", err) + } + + // Check for primary key column + var pkColumnExists bool + err = c.conn.QueryRow(ctx, fmt.Sprintf("SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name=$1 AND column_name='%s')", c.Config.TablePrimaryKeyCol), tableName).Scan(&pkColumnExists) + if err != nil { + return fmt.Errorf("failed to check for id column: %w", err) + } + + // Add 'embedding' column if it doesn't exist + if !embeddingColumnExists { + query := fmt.Sprintf("ALTER TABLE %s ADD COLUMN embedding VECTOR(%d)", tableName, c.Config.Dimensions) + _, err = c.conn.Exec(ctx, query) + if err != nil { + if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "42701" { + } else { + return fmt.Errorf("failed to add embedding column: %w", err) + } + } + } + + // Add 'content' column if it doesn't exist + if !contentColumnExists { + query := fmt.Sprintf("ALTER TABLE %s ADD COLUMN content TEXT", tableName) + _, err = c.conn.Exec(ctx, query) + if err != nil { + if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "42701" { + } else { + return fmt.Errorf("failed to add content column: %w", err) + } + } + } + + // Add primary key column if it doesn't exist + if !pkColumnExists { + query := fmt.Sprintf("ALTER TABLE %s ADD COLUMN IF NOT EXISTS %s BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY", tableName, c.Config.TablePrimaryKeyCol) + _, err = c.conn.Exec(ctx, query) + if err != nil { + if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "42701" { + } else { + return fmt.Errorf("failed to add content column: %w", err) + } + } + } + + return nil +} + +// tableExists checks if the table exists +func (c *Client) tableExists(ctx context.Context, tableName string) (bool, error) { + var exists bool + err := c.conn.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name=$1)", tableName).Scan(&exists) + return exists, err +} + +// createTable creates the table if it doesn't exist +func (c *Client) createTable(ctx context.Context, tableName string) error { + query := fmt.Sprintf(` + CREATE TABLE %s ( + id BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, + content TEXT, + embedding VECTOR(%d) + )`, tableName, c.Config.Dimensions) + + _, err := c.conn.Exec(ctx, query) + return err +} + +// formatRowValues formats the row values into a string +func formatRowValues(fields []pgconn.FieldDescription, values []interface{}) string { + var pairs []string + for i, field := range fields { + pairs = append(pairs, fmt.Sprintf("%s:%v", field.Name, values[i])) + } + return strings.Join(pairs, ",") +} + +func splitSchemaTableName(tableName string) (string, string) { + parts := strings.Split(tableName, ".") + if len(parts) == 2 { + return parts[0], parts[1] + } + return "public", tableName +} + +func (c *Client) updateEmbedding(ctx context.Context, embedding []float32, content string, id interface{}) error { + // Prepare update query + query := fmt.Sprintf(` + UPDATE %s + SET embedding = $1, content = $2 + WHERE %s = $3 + `, c.Config.TableName, c.Config.TablePrimaryKeyCol) + + // Execute the query + _, err := c.conn.Exec(ctx, query, pgvector.NewVector(embedding), content, id) + if err != nil { + return fmt.Errorf("failed to update embedding: %w", err) + } + return nil +} + +// queryAndProcessEmbeddingContents queries the database and processes the rows to populate contents and ids. +func (c *Client) queryAndProcessEmbeddingContents(ctx context.Context, selectQuery string) ([]Embedding, error) { + var contents []Embedding + + rows, err := c.conn.Query(ctx, selectQuery) + if err != nil { + return nil, fmt.Errorf("failed to query database: %w", err) + } + defer rows.Close() + + for rows.Next() { + values, err := rows.Values() + if err != nil { + return nil, fmt.Errorf("failed to get row values: %w", err) + } + + content := formatRowValues(rows.FieldDescriptions(), values) + contents = append(contents, Embedding{ + PK: values[0], + Content: content, + // Embedding: values[2].(pgvector.Vector), + }) + } + + return contents, nil +} + +func (c *Client) queryAndFilterColumnNames(ctx context.Context, schema, tableName string, toRemove []string) ([]string, error) { + columnsQuery := fmt.Sprintf(` + SELECT column_name + FROM information_schema.columns + WHERE table_schema = '%s' + AND table_name = '%s' + `, schema, tableName) + + rows, err := c.conn.Query(ctx, columnsQuery) + if err != nil { + return nil, fmt.Errorf("failed to get column names: %w", err) + } + defer rows.Close() + + // Create a map for columns to remove + removeMap := make(map[string]struct{}) + for _, col := range toRemove { + removeMap[col] = struct{}{} + } + + var filtered []string + for rows.Next() { + var column string + if err := rows.Scan(&column); err != nil { + return nil, fmt.Errorf("failed to scan column name: %w", err) + } + // Only append if the column is not in the remove map + if _, found := removeMap[column]; !found { + filtered = append(filtered, column) + } + } + return filtered, nil +} diff --git a/pkg/rag/pgvector_test.go b/pkg/rag/pgvector_test.go new file mode 100644 index 0000000..5d6fb48 --- /dev/null +++ b/pkg/rag/pgvector_test.go @@ -0,0 +1,104 @@ +package rag + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/jackc/pgx/v5" +) + +func TestEnsureTableConfig(t *testing.T) { + ctx := context.Background() + conn, err := setupTestDatabase(t) + if err != nil { + t.Fatalf("Failed to set up test database: %v", err) + } + defer conn.Close(ctx) + + c, err := NewClient(conn, Config{ + TableName: "test_table", + Dimensions: 3072, + TablePrimaryKeyCol: "id", + }) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + // Ensure the test table is dropped before and after the test + dropTable := func() { + _, err := conn.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", c.Config.TableName)) + if err != nil { + t.Fatalf("Failed to drop test table: %v", err) + } + } + dropTable() + defer dropTable() + + // Test case 1: Ensure column is added when the table doesn't exist + t.Run("AddColumnWhenTableNotExists", func(t *testing.T) { + err = c.ensureTableConfig(ctx) + if err != nil { + t.Fatalf("Failed to ensure embedding column: %v", err) + } + + // Verify that the table was created with the correct columns + var tableExists bool + err = conn.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name=$1)", c.Config.TableName).Scan(&tableExists) + if err != nil { + t.Fatalf("Failed to check if table exists: %v", err) + } + if !tableExists { + t.Errorf("Table was not created") + } + + // Verify that the table has the correct primary key column + var primaryKeyCol string + err = conn.QueryRow(ctx, "SELECT column_name FROM information_schema.columns WHERE table_name=$1 AND column_name='id'", c.Config.TableName).Scan(&primaryKeyCol) + if err != nil { + t.Fatalf("Failed to check if primary key column exists: %v", err) + } + if primaryKeyCol != c.Config.TablePrimaryKeyCol { + t.Errorf("Primary key column was not added correctly") + } + + // Verify that the embedding column exists + var columnExists bool + err = conn.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name=$1 AND column_name='embedding')", c.Config.TableName).Scan(&columnExists) + if err != nil { + t.Fatalf("Failed to check if column exists: %v", err) + } + if !columnExists { + t.Errorf("Embedding column was not added") + } + + // Verify that the table has the correct content column + var contentCol string + err = conn.QueryRow(ctx, "SELECT column_name FROM information_schema.columns WHERE table_name=$1 AND column_name='content'", c.Config.TableName).Scan(&contentCol) + if err != nil { + t.Fatalf("Failed to check if content column exists: %v", err) + } + if contentCol != "content" { + t.Errorf("Content column was not added correctly") + } + }) + + // Test case 2: Ensure no error when column already exists + t.Run("NoErrorWhenColumnExists", func(t *testing.T) { + err = c.ensureTableConfig(ctx) + if err != nil { + t.Fatalf("Unexpected error when ensuring existing column: %v", err) + } + }) +} + +func setupTestDatabase(t *testing.T) (*pgx.Conn, error) { + ctx := context.Background() + conn, err := pgx.Connect(ctx, os.Getenv("TEST_DATABASE_URL")) + if err != nil { + t.Fatalf("failed to connect to database: %v", err) + } + + return conn, nil +} diff --git a/pkg/rag/vector.go b/pkg/rag/vector.go deleted file mode 100644 index 8b4352a..0000000 --- a/pkg/rag/vector.go +++ /dev/null @@ -1,186 +0,0 @@ -package rag - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/jackc/pgx/v5" - "github.com/pgvector/pgvector-go" - pgxvector "github.com/pgvector/pgvector-go/pgx" -) - -// VectorEmbedding represents a single embedding with metadata -type VectorEmbedding struct { - ID int64 `json:"id"` - Tags *[]string `json:"tags,omitempty"` - Metadata *json.RawMessage `json:"metadata,omitempty"` - Content string `json:"content"` - Embedding pgvector.Vector `json:"embedding"` -} - -// CreateTable creates the necessary database extensions and table -func (c *Client) CreateTable(ctx context.Context) error { - _, err := c.conn.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS vector") - if err != nil { - return fmt.Errorf("failed to create vector extension: %w", err) - } - - err = pgxvector.RegisterTypes(ctx, c.conn) - if err != nil { - return fmt.Errorf("failed to register vector types: %w", err) - } - - query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( - id BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, - tags TEXT[], - metadata JSONB, - content TEXT, - embedding VECTOR(%v) - )`, c.config.TableName, c.config.Dimensions) - - _, err = c.conn.Exec(ctx, query) - if err != nil { - return fmt.Errorf("failed to create table: %w", err) - } - - return nil -} - -// InsertEmbeddings takes a slice of VectorEmbedding and inserts them into the database -// using the provided configuration -func (c *Client) InsertEmbeddings(ctx context.Context, embeddings []VectorEmbedding, batchSize ...int) error { - if len(embeddings) == 0 { - return nil - } - - // Ensure table exists - if err := c.CreateTable(ctx); err != nil { - return fmt.Errorf("failed to set up database: %w", err) - } - - // Prepare the query - query := fmt.Sprintf(` - INSERT INTO %s (content, embedding, tags, metadata) - VALUES ($1, $2, $3, $4)`, - c.config.TableName, - ) - - // Process embeddings in batches - for i := 0; i < len(embeddings); i += batchSize[0] { - end := i + batchSize[0] - if end > len(embeddings) { - end = len(embeddings) - } - - batch := embeddings[i:end] - - // Begin transaction for this batch - tx, err := c.conn.Begin(ctx) - if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) - } - - // Create a new batch - pgxBatch := &pgx.Batch{} - for _, embedding := range batch { - pgxBatch.Queue(query, - embedding.Content, - embedding.Embedding, - embedding.Tags, - embedding.Metadata, - ) - } - - // Execute batch - batchResults := tx.SendBatch(ctx, pgxBatch) - - // Process all results before closing - for j := 0; j < pgxBatch.Len(); j++ { - _, err := batchResults.Exec() - if err != nil { - batchResults.Close() - tx.Rollback(ctx) - return fmt.Errorf("failed to insert embedding at index %d: %w", i+j, err) - } - } - - // Close batch results before committing - if err := batchResults.Close(); err != nil { - tx.Rollback(ctx) - return fmt.Errorf("failed to close batch results: %w", err) - } - - // Commit transaction - if err := tx.Commit(ctx); err != nil { - return fmt.Errorf("failed to commit transaction: %w", err) - } - } - - return nil -} - -// Retrieve returns []VectorEmbedding similar to a query using the RAG API and database -func (c *Client) Retrieve(ctx context.Context, query string, limit int) ([]VectorEmbedding, error) { - // Get embeddings for the query - embeddingResponse, err := c.FetchEmbeddings(ctx, []string{query}) - if err != nil { - return nil, fmt.Errorf("failed to fetch embeddings: %w", err) - } - - if len(embeddingResponse.Data) == 0 { - return nil, fmt.Errorf("no embeddings returned from API") - } - - // Construct the query using the configured table name - queryStr := fmt.Sprintf( - "SELECT id, content, tags, metadata, embedding FROM %s ORDER BY embedding <=> $1 LIMIT $2", - c.config.TableName, - ) - - // Execute the query - rows, err := c.conn.Query(ctx, queryStr, - pgvector.NewVector(embeddingResponse.Data[0].Embedding), - limit, - ) - if err != nil { - return nil, fmt.Errorf("failed to execute similarity search: %w", err) - } - defer rows.Close() - - // Process the results - var results []VectorEmbedding - for rows.Next() { - var doc VectorEmbedding - if err := rows.Scan(&doc.ID, &doc.Content, &doc.Tags, &doc.Metadata, &doc.Embedding); err != nil { - return nil, fmt.Errorf("failed to scan row: %w", err) - } - results = append(results, doc) - } - - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("error iterating over rows: %w", err) - } - - return results, nil -} - -// Helper function to create embeddings from API response -func ToVectorEmbedding(texts []string, response EmbeddingsResponse, tags []string, metadata json.RawMessage) ([]VectorEmbedding, error) { - if len(texts) != len(response.Data) { - return nil, fmt.Errorf("mismatch between texts (%d) and embeddings (%d)", len(texts), len(response.Data)) - } - - embeddings := make([]VectorEmbedding, len(texts)) - for i, text := range texts { - vector := pgvector.NewVector(response.Data[i].Embedding) - embeddings[i] = VectorEmbedding{ - Content: text, - Embedding: vector, - Tags: &tags, - Metadata: &metadata, - } - } - - return embeddings, nil -}