Skip to content

Commit

Permalink
add pkg/rag tests
Browse files Browse the repository at this point in the history
Signed-off-by: hmoazzem <moazzem@edgeflare.io>
  • Loading branch information
hmoazzem committed Oct 21, 2024
1 parent 31525ee commit 8ca611e
Show file tree
Hide file tree
Showing 12 changed files with 691 additions and 448 deletions.
133 changes: 0 additions & 133 deletions examples/pgvector101/hello-world.go

This file was deleted.

3 changes: 0 additions & 3 deletions examples/rag101/README.md

This file was deleted.

87 changes: 18 additions & 69 deletions examples/rag101/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,100 +2,49 @@ package main

import (
"context"
"encoding/json"
"fmt"
"log"

"github.com/edgeflare/pgo/pkg/rag"
"github.com/edgeflare/pgo/pkg/util"
"github.com/jackc/pgx/v5"
)

func main() {
// insertEmbeddings()
generateWithRetrieval()
}

func generateWithRetrieval() {
// Database connection parameters
connConfig, err := pgx.ParseConfig("postgres://postgres:secret@localhost:5431/testdb")
if err != nil {
log.Fatalf("unable to parse connection config: %v", err)
}

ctx := context.Background()

// Create a new connection to the database
conn, err := pgx.ConnectConfig(ctx, connConfig)
conn, err := pgx.Connect(ctx, util.GetEnvOrDefault("DATABASE_URL", "postgres://postgres:secret@localhost:5432/postgres"))
if err != nil {
log.Fatalf("unable to connect to database: %v", err)
log.Fatalf("Failed to connect to database: %v", err)
}
defer conn.Close(ctx)

// Create a new RAG client with default configuration
client := rag.NewClient(conn, rag.DefaultConfig())
// Create a new RAG client
client, err := rag.NewClient(conn, rag.DefaultConfig())
client.Config.TableName = "example_table"

response, err := client.Generate(ctx, "what is the cat doing?")
if err != nil {
log.Fatalln(err)
log.Fatalf("Failed to create RAG client: %v", err)
}

fmt.Printf("%s", response)
}

func insertEmbeddings() {
// Database connection parameters
connConfig, err := pgx.ParseConfig("postgres://postgres:secret@localhost:5431/testdb")
err = client.CreateEmbedding(ctx, "")
if err != nil {
log.Fatalf("unable to parse connection config: %v", err)
log.Fatalf("Failed to create embeddings: %v", err)
}

ctx := context.Background()
fmt.Println("Embeddings have been successfully created.")

// Create a new connection to the database
conn, err := pgx.ConnectConfig(ctx, connConfig)
if err != nil {
log.Fatalf("unable to connect to database: %v", err)
}
defer conn.Close(ctx)
// retrieval example
input := "example input text"
limit := 2

// Create a new RAG client with default configuration
client := rag.NewClient(conn, rag.DefaultConfig())

// Sample input texts
input := []string{
"The dog is barking",
"The cat is purring",
"The bear is growling",
// Add more examples here
}

// Fetch embeddings from the API
embeddingsResponse, err := client.FetchEmbeddings(ctx, input)
if err != nil {
log.Fatalf("failed to fetch embeddings: %v", err)
}

// Create metadata for all embeddings
metadata := json.RawMessage(`{
"source": "sample",
"batch": "animal_sounds",
"timestamp": "2024-10-19"
}`)

// Create tags for all embeddings
tags := []string{"animal", "sound", "example"}

// Create vector embeddings using the helper function
embeddings, err := rag.ToVectorEmbedding(input, embeddingsResponse, tags, metadata)
results, err := client.Retrieve(ctx, input, limit)
if err != nil {
log.Fatalf("failed to create embeddings: %v", err)
log.Fatalf("Failed to retrieve content: %v", err)
}

// Insert embeddings with progress tracking
err = client.InsertEmbeddings(ctx, embeddings, 1)
if err != nil {
log.Fatalf("failed to insert embeddings: %v", err)
// Print the retrieved results
for _, r := range results {
fmt.Printf("ID: %v\nContent: %s\nEmbedding: %v\n", r.PK, r.Content, r.Embedding.Slice()[0])
}

fmt.Println("\nAll embeddings inserted successfully!")
}
Loading

0 comments on commit 8ca611e

Please sign in to comment.