Skip to content

Commit

Permalink
Fixes in driver setup code
Browse files Browse the repository at this point in the history
  • Loading branch information
ddebowczyk committed Oct 7, 2024
1 parent a3fcba4 commit d769a42
Show file tree
Hide file tree
Showing 15 changed files with 68 additions and 38 deletions.
3 changes: 1 addition & 2 deletions config/embed.php
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@
'defaultModel' => 'nomic-embed-text',
'defaultDimensions' => 1024,
'maxInputs' => 512,
'connectionTimeout' => 5,
'requestTimeout' => 30,
'httpClient' => 'guzzle-ollama',
],
'openai' => [
'providerType' => LLMProviderType::OpenAI->value,
Expand Down
7 changes: 6 additions & 1 deletion config/http.php
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
],
'symfony' => [
'httpClientType' => HttpClientType::Symfony->value,
'connectTimeout' => 90, // Symfony HttpClient does not allow to set connect timeout, set it to request timeout
'requestTimeout' => 30,
'idleTimeout' => -1,
],
'guzzle-ollama' => [
'httpClientType' => HttpClientType::Guzzle->value,
'connectTimeout' => 5,
'requestTimeout' => 90,
],
]
Expand Down
3 changes: 1 addition & 2 deletions config/llm.php
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@
'endpoint' => '/chat/completions',
'defaultModel' => 'llama3.2:3b', //'gemma2:2b',
'defaultMaxTokens' => 1024,
'connectTimeout' => 30,
'requestTimeout' => 60,
'httpClient' => 'guzzle-ollama',
],
'openai' => [
'providerType' => LLMProviderType::OpenAI->value,
Expand Down
20 changes: 10 additions & 10 deletions src/Extras/Embeddings/Data/EmbeddingsConfig.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
class EmbeddingsConfig
{
public function __construct(
public string $apiUrl = '',
public string $apiKey = '',
public string $endpoint = '',
public string $model = '',
public int $dimensions = 0,
public int $maxInputs = 0,
public array $metadata = [],
public string $apiUrl = '',
public string $apiKey = '',
public string $endpoint = '',
public string $model = '',
public int $dimensions = 0,
public int $maxInputs = 0,
public array $metadata = [],
public string $httpClient = '',
public ?LLMProviderType $providerType = null,
) {}

Expand All @@ -31,9 +32,8 @@ public static function load(string $connection) : EmbeddingsConfig {
dimensions: Settings::get('embed', "connections.$connection.defaultDimensions", 0),
maxInputs: Settings::get('embed', "connections.$connection.maxInputs", 1),
metadata: Settings::get('embed', "connections.$connection.metadata", []),
providerType: LLMProviderType::from(
Settings::get('embed', "connections.$connection.providerType")
),
httpClient: Settings::get('embed', "connections.$connection.httpClient", ''),
providerType: LLMProviderType::from(Settings::get('embed', "connections.$connection.providerType")),
);
}
}
19 changes: 15 additions & 4 deletions src/Extras/Embeddings/Embeddings.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

namespace Cognesy\Instructor\Extras\Embeddings;

use Cognesy\Instructor\Events\EventDispatcher;
use Cognesy\Instructor\Extras\Embeddings\Contracts\CanVectorize;
use Cognesy\Instructor\Extras\Embeddings\Data\EmbeddingsConfig;
use Cognesy\Instructor\Extras\Embeddings\Drivers\AzureOpenAIDriver;
Expand All @@ -19,14 +20,24 @@ class Embeddings
{
use Traits\HasFinders;

protected EventDispatcher $events;
protected EmbeddingsConfig $config;
protected CanHandleHttp $httpClient;
protected CanVectorize $driver;

public function __construct() {
$this->httpClient = HttpClient::make();
$this->config = EmbeddingsConfig::load(Settings::get('embed', "defaultConnection"));
$this->driver = $this->getDriver($this->config, $this->httpClient);
public function __construct(
string $connection = '',
EmbeddingsConfig $config = null,
CanHandleHttp $httpClient = null,
CanVectorize $driver = null,
EventDispatcher $events = null,
) {
$this->events = $events ?? new EventDispatcher();
$this->config = $config ?? EmbeddingsConfig::load($connection
?: Settings::get('embed', "defaultConnection")
);
$this->httpClient = $httpClient ?? HttpClient::make($this->config->httpClient);
$this->driver = $driver ?? $this->getDriver($this->config, $this->httpClient);
}

// PUBLIC ///////////////////////////////////////////////////
Expand Down
5 changes: 3 additions & 2 deletions src/Extras/Http/Data/HttpClientConfig.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ public function __construct(
public HttpClientType $httpClientType = HttpClientType::Guzzle,
public int $connectTimeout = 3,
public int $requestTimeout = 30,
) {
}
public int $idleTimeout = 0,
) {}

public static function load(string $client) : HttpClientConfig {
if (!Settings::has('http', "clients.$client")) {
Expand All @@ -23,6 +23,7 @@ public static function load(string $client) : HttpClientConfig {
httpClientType: HttpClientType::from(Settings::get('http', "clients.$client.httpClientType")),
connectTimeout: Settings::get(group: "http", key: "clients.$client.connectTimeout", default: 30),
requestTimeout: Settings::get("http", "clients.$client.requestTimeout", 3),
idleTimeout: Settings::get(group: "http", key: "clients.$client.idleTimeout", default: 0),
);
}
}
4 changes: 2 additions & 2 deletions src/Extras/Http/Drivers/SymfonyDriver.php
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ public function handle(
options: [
'headers' => $headers,
'body' => is_array($body) ? json_encode($body) : $body,
'timeout' => $this->config->requestTimeout ?? 5,
'max_duration' => $this->config->connectTimeout ?? 30,
'timeout' => $this->config->idleTimeout ?? 0,
'max_duration' => $this->config->requestTimeout ?? 30,
'buffer' => !$streaming,
]
);
Expand Down
6 changes: 2 additions & 4 deletions src/Extras/Http/HttpClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
use Cognesy\Instructor\Extras\Http\Data\HttpClientConfig;
use Cognesy\Instructor\Extras\Http\Drivers\GuzzleDriver;
use Cognesy\Instructor\Extras\Http\Drivers\SymfonyDriver;
use Cognesy\Instructor\Extras\Http\Drivers\SymfonyNyholmDriver;
use Cognesy\Instructor\Utils\Settings;
use InvalidArgumentException;

Expand All @@ -19,13 +18,12 @@ class HttpClient

public function __construct(string $client = '', EventDispatcher $events = null) {
$this->events = $events ?? new EventDispatcher();

$config = HttpClientConfig::load($client ?: Settings::get('http', "defaultClient"));
$this->driver = $this->makeDriver($config);
}

public static function make(string $client = '') : CanHandleHttp {
return (new self($client))->get();
public static function make(string $client = '', ?EventDispatcher $events = null) : CanHandleHttp {
return (new self($client, $events))->get();
}

public function withClient(string $name) : self {
Expand Down
2 changes: 2 additions & 0 deletions src/Extras/LLM/Data/LLMConfig.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public function __construct(
public array $metadata = [],
public string $model = '',
public int $maxTokens = 1024,
public string $httpClient = '',
public LLMProviderType $providerType = LLMProviderType::OpenAICompatible,
) {}

Expand All @@ -29,6 +30,7 @@ public static function load(string $connection) : LLMConfig {
metadata: Settings::get('llm', "connections.$connection.metadata", []),
model: Settings::get('llm', "connections.$connection.defaultModel", ''),
maxTokens: Settings::get('llm', "connections.$connection.defaultMaxTokens", 1024),
httpClient: Settings::get('llm', "connections.$connection.httpClient", ''),
providerType: LLMProviderType::from(Settings::get('llm', "connections.$connection.providerType")),
);
}
Expand Down
3 changes: 3 additions & 0 deletions src/Extras/LLM/Drivers/AnthropicDriver.php
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

use Cognesy\Instructor\Data\Messages\Messages;
use Cognesy\Instructor\Enums\Mode;
use Cognesy\Instructor\Events\EventDispatcher;
use Cognesy\Instructor\Extras\Http\Contracts\CanHandleHttp;
use Cognesy\Instructor\Extras\Http\Contracts\CanHandleResponse;
use Cognesy\Instructor\Extras\Http\HttpClient;
Expand All @@ -21,7 +22,9 @@ class AnthropicDriver implements CanHandleInference
public function __construct(
protected LLMConfig $config,
protected ?CanHandleHttp $httpClient = null,
protected ?EventDispatcher $events = null,
) {
$this->events = $events ?? new EventDispatcher();
$this->httpClient = $httpClient ?? HttpClient::make();
}

Expand Down
3 changes: 3 additions & 0 deletions src/Extras/LLM/Drivers/CohereV1Driver.php
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

use Cognesy\Instructor\Data\Messages\Messages;
use Cognesy\Instructor\Enums\Mode;
use Cognesy\Instructor\Events\EventDispatcher;
use Cognesy\Instructor\Extras\Http\Contracts\CanHandleHttp;
use Cognesy\Instructor\Extras\Http\Contracts\CanHandleResponse;
use Cognesy\Instructor\Extras\Http\HttpClient;
Expand All @@ -20,7 +21,9 @@ class CohereV1Driver implements CanHandleInference
public function __construct(
protected LLMConfig $config,
protected ?CanHandleHttp $httpClient = null,
protected ?EventDispatcher $events = null,
) {
$this->events = $events ?? new EventDispatcher();
$this->httpClient = $httpClient ?? HttpClient::make();
}

Expand Down
3 changes: 3 additions & 0 deletions src/Extras/LLM/Drivers/GeminiDriver.php
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

use Cognesy\Instructor\Data\Messages\Messages;
use Cognesy\Instructor\Enums\Mode;
use Cognesy\Instructor\Events\EventDispatcher;
use Cognesy\Instructor\Extras\Http\Contracts\CanHandleHttp;
use Cognesy\Instructor\Extras\Http\Contracts\CanHandleResponse;
use Cognesy\Instructor\Extras\Http\HttpClient;
Expand All @@ -22,7 +23,9 @@ class GeminiDriver implements CanHandleInference
public function __construct(
protected LLMConfig $config,
protected ?CanHandleHttp $httpClient = null,
protected ?EventDispatcher $events = null,
) {
$this->events = $events ?? new EventDispatcher();
$this->httpClient = $httpClient ?? HttpClient::make();
}

Expand Down
3 changes: 3 additions & 0 deletions src/Extras/LLM/Drivers/OpenAIDriver.php
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
namespace Cognesy\Instructor\Extras\LLM\Drivers;

use Cognesy\Instructor\Enums\Mode;
use Cognesy\Instructor\Events\EventDispatcher;
use Cognesy\Instructor\Extras\Http\Contracts\CanHandleHttp;
use Cognesy\Instructor\Extras\Http\Contracts\CanHandleResponse;
use Cognesy\Instructor\Extras\Http\HttpClient;
Expand All @@ -19,7 +20,9 @@ class OpenAIDriver implements CanHandleInference
public function __construct(
protected LLMConfig $config,
protected ?CanHandleHttp $httpClient = null,
protected ?EventDispatcher $events = null,
) {
$this->events = $events ?? new EventDispatcher();
$this->httpClient = $httpClient ?? HttpClient::make();
}

Expand Down
21 changes: 10 additions & 11 deletions src/Extras/LLM/Inference.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class Inference
{
protected LLMConfig $config;

protected EventDispatcher $events;
protected CanHandleInference $driver;
protected CanHandleHttp $httpClient;
protected EventDispatcher $events;
protected CachedContext $cachedContext;

public function __construct(
Expand All @@ -39,11 +39,10 @@ public function __construct(
EventDispatcher $events = null,
) {
$this->events = $events ?? new EventDispatcher();

$this->httpClient = $httpClient ?? HttpClient::make();
$this->config = $config ?? LLMConfig::load(connection: $connection
?: Settings::get('llm', "defaultConnection")
);
$this->httpClient = $httpClient ?? HttpClient::make($this->config->httpClient);
$this->driver = $driver ?? $this->makeDriver($this->config, $this->httpClient);
}

Expand Down Expand Up @@ -135,20 +134,20 @@ public function create(

protected function makeDriver(LLMConfig $config, CanHandleHttp $httpClient): CanHandleInference {
return match ($config->providerType) {
LLMProviderType::Anthropic => new AnthropicDriver($config, $httpClient),
LLMProviderType::Azure => new AzureOpenAIDriver($config, $httpClient),
LLMProviderType::CohereV1 => new CohereV1Driver($config, $httpClient),
LLMProviderType::CohereV2 => new CohereV2Driver($config, $httpClient),
LLMProviderType::Gemini => new GeminiDriver($config, $httpClient),
LLMProviderType::Mistral => new MistralDriver($config, $httpClient),
LLMProviderType::OpenAI => new OpenAIDriver($config, $httpClient),
LLMProviderType::Anthropic => new AnthropicDriver($config, $httpClient, $this->events),
LLMProviderType::Azure => new AzureOpenAIDriver($config, $httpClient, $this->events),
LLMProviderType::CohereV1 => new CohereV1Driver($config, $httpClient, $this->events),
LLMProviderType::CohereV2 => new CohereV2Driver($config, $httpClient, $this->events),
LLMProviderType::Gemini => new GeminiDriver($config, $httpClient, $this->events),
LLMProviderType::Mistral => new MistralDriver($config, $httpClient, $this->events),
LLMProviderType::OpenAI => new OpenAIDriver($config, $httpClient, $this->events),
LLMProviderType::Fireworks,
LLMProviderType::Groq,
LLMProviderType::Ollama,
LLMProviderType::OpenAICompatible,
LLMProviderType::OpenRouter,
LLMProviderType::Together,
=> new OpenAICompatibleDriver($config, $httpClient),
=> new OpenAICompatibleDriver($config, $httpClient, $this->events),
default => throw new InvalidArgumentException("Client not supported: {$config->providerType->value}"),
};
}
Expand Down
4 changes: 4 additions & 0 deletions src/Utils/Settings.php
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ public static function get(string $group, string $key, mixed $default = null) :
self::$settings[$group] = dot(self::loadGroup($group));
}

if ($default === null && !self::has($group, $key)) {
throw new Exception("Settings key not found: $key in group: $group and no default value provided");
}

return self::$settings[$group]->get($key, $default);
}

Expand Down

0 comments on commit d769a42

Please sign in to comment.