diff --git a/docs/cookbook/examples/advanced/context_cache_llm.mdx b/docs/cookbook/examples/advanced/context_cache_llm.mdx
index 25711d17..c87a4332 100644
--- a/docs/cookbook/examples/advanced/context_cache_llm.mdx
+++ b/docs/cookbook/examples/advanced/context_cache_llm.mdx
@@ -45,7 +45,7 @@ $inference = (new Inference)->withConnection('anthropic')->withCachedContext(
 $response = $inference->create(
     messages: [['role' => 'user', 'content' => 'CTO of lead gen software vendor']],
     options: ['max_tokens' => 256],
-)->asLLMResponse();
+)->response();
 
 print("----------------------------------------\n");
 print("\n# Summary for CTO of lead gen vendor\n");
@@ -60,7 +60,7 @@ assert(Str::contains($response->content, 'lead', false));
 $response2 = $inference->create(
     messages: [['role' => 'user', 'content' => 'CIO of insurance company']],
     options: ['max_tokens' => 256],
-)->asLLMResponse();
+)->response();
 
 print("----------------------------------------\n");
 print("\n# Summary for CIO of insurance company\n");
diff --git a/evals/LLMModes/CompareModes.php b/evals/LLMModes/CompareModes.php
index ddd0a805..84efda12 100644
--- a/evals/LLMModes/CompareModes.php
+++ b/evals/LLMModes/CompareModes.php
@@ -66,8 +66,8 @@ private function execute(string $connection, Mode $mode, bool $isStreamed) : Eva
                 notes: $answer,
                 isCorrect: $isCorrect,
                 timeElapsed: $timeElapsed,
-                inputTokens: $llmResponse->inputTokens,
-                outputTokens: $llmResponse->outputTokens,
+                inputTokens: $llmResponse->usage()->inputTokens,
+                outputTokens: $llmResponse->usage()->outputTokens,
             );
         } catch(Exception $e) {
             $timeElapsed = microtime(true) - $time;
diff --git a/evals/LLMModes/Modes.php b/evals/LLMModes/Modes.php
index 57b0f754..1370387b 100644
--- a/evals/LLMModes/Modes.php
+++ b/evals/LLMModes/Modes.php
@@ -34,7 +34,7 @@ public function callInferenceFor(string|array $query, Mode $mode, string $connec
             Mode::MdJson => $this->forModeMdJson($query, $connection, $schema, $isStreamed),
             Mode::Text => $this->forModeText($query, $connection, $isStreamed),
         };
-        return $inferenceResponse->asLLMResponse();
+        return $inferenceResponse->response();
     }
 
     public function forModeTools(string|array $query, string $connection, array $schema, bool $isStreamed) : InferenceResponse {
diff --git a/examples/A02_Advanced/ContextCacheLLM/run.php b/examples/A02_Advanced/ContextCacheLLM/run.php
index 25711d17..c87a4332 100644
--- a/examples/A02_Advanced/ContextCacheLLM/run.php
+++ b/examples/A02_Advanced/ContextCacheLLM/run.php
@@ -45,7 +45,7 @@
 $response = $inference->create(
     messages: [['role' => 'user', 'content' => 'CTO of lead gen software vendor']],
     options: ['max_tokens' => 256],
-)->asLLMResponse();
+)->response();
 
 print("----------------------------------------\n");
 print("\n# Summary for CTO of lead gen vendor\n");
@@ -60,7 +60,7 @@
 $response2 = $inference->create(
     messages: [['role' => 'user', 'content' => 'CIO of insurance company']],
     options: ['max_tokens' => 256],
-)->asLLMResponse();
+)->response();
 
 print("----------------------------------------\n");
 print("\n# Summary for CIO of insurance company\n");
diff --git a/examples/A03_Troubleshooting/TokenUsage/run.php b/examples/A03_Troubleshooting/TokenUsage/run.php
index d1f8cd46..b7c704e0 100644
--- a/examples/A03_Troubleshooting/TokenUsage/run.php
+++ b/examples/A03_Troubleshooting/TokenUsage/run.php
@@ -6,16 +6,11 @@
 ## Overview
 
 Some use cases require tracking the token usage of the API responses.
-Currently, this can be done by listening to the `LLMResponseReceived`
-and `PartialLLMResponseReceived` events and summing the token usage
-of the responses.
+This can be done by retrieving the `Usage` object from the Instructor
+LLM response object.
 
-Code below demonstrates how it can be implemented using Instructor
-event listeners.
-
-> Note: OpenAI API requires `stream_options` to be set to
-> `['include_usage' => true]` to include token usage in the streamed
-> responses.
+The code below demonstrates how to retrieve it for both sync and
+streamed requests.
 
 ## Example
 
@@ -24,10 +19,7 @@
 $loader = require 'vendor/autoload.php';
 $loader->add('Cognesy\\Instructor\\', __DIR__ . '../../src/');
 
-use Cognesy\Instructor\Events\Inference\LLMResponseReceived;
-use Cognesy\Instructor\Events\Inference\PartialLLMResponseReceived;
-use Cognesy\Instructor\Features\LLM\Data\LLMResponse;
-use Cognesy\Instructor\Features\LLM\Data\PartialLLMResponse;
+use Cognesy\Instructor\Features\LLM\Data\Usage;
 use Cognesy\Instructor\Instructor;
 
 class User {
@@ -35,64 +27,40 @@ class User {
     public string $name;
 }
 
-class TokenCounter {
-    public int $input = 0;
-    public int $output = 0;
-    public int $cacheCreation = 0;
-    public int $cacheRead = 0;
-
-    public function add(LLMResponse|PartialLLMResponse $response) {
-        $this->input += $response->inputTokens;
-        $this->output += $response->outputTokens;
-        $this->cacheCreation += $response->cacheCreationTokens;
-        $this->cacheRead += $response->cacheReadTokens;
-    }
-
-    public function reset() {
-        $this->input = 0;
-        $this->output = 0;
-        $this->cacheCreation = 0;
-        $this->cacheRead = 0;
-    }
-
-    public function print() {
-        echo "Input tokens: $this->input\n";
-        echo "Output tokens: $this->output\n";
-        echo "Cache creation tokens: $this->cacheCreation\n";
-        echo "Cache read tokens: $this->cacheRead\n";
-    }
+function printUsage(Usage $usage) : void {
+    echo "Input tokens: $usage->inputTokens\n";
+    echo "Output tokens: $usage->outputTokens\n";
+    echo "Cache creation tokens: $usage->cacheWriteTokens\n";
+    echo "Cache read tokens: $usage->cacheReadTokens\n";
+    echo "Reasoning tokens: $usage->reasoningTokens\n";
 }
 
-$counter = new TokenCounter();
-
 echo "COUNTING TOKENS FOR SYNC RESPONSE\n";
 $text = "Jason is 25 years old and works as an engineer.";
-$instructor = (new Instructor)
-    ->onEvent(LLMResponseReceived::class, fn(LLMResponseReceived $e) => $counter->add($e->llmResponse))
-    ->respond(
+$response = (new Instructor)
+    ->request(
         messages: $text,
        responseModel: User::class,
-    );
+    )->response();
+
 echo "\nTEXT: $text\n";
-assert($counter->input > 0);
-assert($counter->output > 0);
-$counter->print();
+assert($response->usage()->total() > 0);
+printUsage($response->usage());
 
-// Reset the counter
-$counter->reset();
 echo "\n\nCOUNTING TOKENS FOR STREAMED RESPONSE\n";
 $text = "Anna is 19 years old.";
-$instructor = (new Instructor)
    ->onEvent(PartialLLMResponseReceived::class, fn(PartialLLMResponseReceived $e) => $counter->add($e->partialLLMResponse))
-    ->respond(
+$stream = (new Instructor)
+    ->request(
         messages: $text,
         responseModel: User::class,
-        options: ['stream' => true, 'stream_options' => ['include_usage' => true]],
-);
+        options: ['stream' => true],
+    )
+    ->stream();
+
+$response = $stream->final();
 echo "\nTEXT: $text\n";
-assert($counter->input > 0);
-assert($counter->output > 0);
-$counter->print();
+assert($stream->usage()->total() > 0);
+printUsage($stream->usage());
 ?>
 ```
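The examples above switch from `asLLMResponse()` to `response()`, and the token-usage docs now point at the `Usage` object instead of event listeners. A minimal sketch of the resulting read path (the connection name is a placeholder; `Inference` is assumed to live in `Cognesy\Instructor\Features\LLM` alongside `InferenceResponse`):

```php
<?php
use Cognesy\Instructor\Features\LLM\Inference;

// Sync request; response() returns an LLMResponse that carries a Usage object.
$response = (new Inference)
    ->withConnection('anthropic') // placeholder connection name
    ->create(
        messages: [['role' => 'user', 'content' => 'Say hello']],
        options: ['max_tokens' => 32],
    )
    ->response();

$usage = $response->usage();
echo "Input tokens: {$usage->inputTokens}\n";
echo "Output tokens: {$usage->outputTokens}\n";
// Cache accounting now lives on the same object:
echo "Cache read tokens: {$usage->cacheReadTokens}\n";
```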
diff --git a/examples/A05_Extras/LLM/run.php b/examples/A05_Extras/LLM/run.php
index 54a4f8b7..e29ce866 100644
--- a/examples/A05_Extras/LLM/run.php
+++ b/examples/A05_Extras/LLM/run.php
@@ -56,7 +56,8 @@
         messages: [['role' => 'user', 'content' => 'Describe capital of Brasil']],
         options: ['max_tokens' => 128, 'stream' => true]
     )
-    ->stream();
+    ->stream()
+    ->responses();
 
 echo "USER: Describe capital of Brasil\n";
 echo "ASSISTANT: ";
diff --git a/src/Extras/Embeddings/Drivers/AzureOpenAIDriver.php b/src/Extras/Embeddings/Drivers/AzureOpenAIDriver.php
index 628a4c44..2f0f4384 100644
--- a/src/Extras/Embeddings/Drivers/AzureOpenAIDriver.php
+++ b/src/Extras/Embeddings/Drivers/AzureOpenAIDriver.php
@@ -7,6 +7,7 @@
 use Cognesy\Instructor\Extras\Embeddings\EmbeddingsResponse;
 use Cognesy\Instructor\Features\Http\Contracts\CanHandleHttp;
 use Cognesy\Instructor\Features\Http\HttpClient;
+use Cognesy\Instructor\Features\LLM\Data\Usage;
 
 class AzureOpenAIDriver implements CanVectorize
 {
@@ -67,6 +68,12 @@ protected function toResponse(array $response) : EmbeddingsResponse {
                 callback: fn($item) => new Vector(values: $item['embedding'], id: $item['index']),
                 array: $response['data']
             ),
+            usage: $this->makeUsage($response),
+        );
+    }
+
+    protected function makeUsage(array $response): Usage {
+        return new Usage(
             inputTokens: $response['usage']['prompt_tokens'] ?? 0,
             outputTokens: ($response['usage']['total_tokens'] ?? 0) - ($response['usage']['prompt_tokens'] ?? 0),
         );
diff --git a/src/Extras/Embeddings/Drivers/CohereDriver.php b/src/Extras/Embeddings/Drivers/CohereDriver.php
index 0b8c1049..ef02a7a3 100644
--- a/src/Extras/Embeddings/Drivers/CohereDriver.php
+++ b/src/Extras/Embeddings/Drivers/CohereDriver.php
@@ -8,6 +8,7 @@
 use Cognesy\Instructor\Extras\Embeddings\EmbeddingsResponse;
 use Cognesy\Instructor\Features\Http\Contracts\CanHandleHttp;
 use Cognesy\Instructor\Features\Http\HttpClient;
+use Cognesy\Instructor\Features\LLM\Data\Usage;
 
 class CohereDriver implements CanVectorize
 {
@@ -57,8 +58,14 @@ protected function toResponse(array $response) : EmbeddingsResponse {
         }
         return new EmbeddingsResponse(
             vectors: $vectors,
+            usage: $this->makeUsage($response),
+        );
+    }
+
+    private function makeUsage(array $response) : Usage {
+        return new Usage(
             inputTokens: $response['meta']['billed_units']['input_tokens'] ?? 0,
-            outputTokens: 0,
+            outputTokens: $response['meta']['billed_units']['output_tokens'] ?? 0,
         );
     }
 }
diff --git a/src/Extras/Embeddings/Drivers/GeminiDriver.php b/src/Extras/Embeddings/Drivers/GeminiDriver.php
index b02ee957..3ef574aa 100644
--- a/src/Extras/Embeddings/Drivers/GeminiDriver.php
+++ b/src/Extras/Embeddings/Drivers/GeminiDriver.php
@@ -7,6 +7,7 @@
 use Cognesy\Instructor\Extras\Embeddings\EmbeddingsResponse;
 use Cognesy\Instructor\Features\Http\Contracts\CanHandleHttp;
 use Cognesy\Instructor\Features\Http\HttpClient;
+use Cognesy\Instructor\Features\LLM\Data\Usage;
 
 class GeminiDriver implements CanVectorize
 {
@@ -64,12 +65,18 @@ protected function toResponse(array $response) : EmbeddingsResponse {
         }
         return new EmbeddingsResponse(
             vectors: $vectors,
-            inputTokens: $this->inputCharacters,
-            outputTokens: 0,
+            usage: $this->makeUsage($response),
         );
     }
 
     private function countCharacters(array $input) : int {
         return array_sum(array_map(fn($item) => strlen($item), $input));
     }
+
+    private function makeUsage(array $response) : Usage {
+        return new Usage(
+            inputTokens: $this->inputCharacters,
+            outputTokens: 0,
+        );
+    }
 }
diff --git a/src/Extras/Embeddings/Drivers/JinaDriver.php b/src/Extras/Embeddings/Drivers/JinaDriver.php
index 8284bcd7..95e81e8e 100644
--- a/src/Extras/Embeddings/Drivers/JinaDriver.php
+++ b/src/Extras/Embeddings/Drivers/JinaDriver.php
@@ -8,6 +8,7 @@
 use Cognesy\Instructor\Extras\Embeddings\EmbeddingsResponse;
 use Cognesy\Instructor\Features\Http\Contracts\CanHandleHttp;
 use Cognesy\Instructor\Features\Http\HttpClient;
+use Cognesy\Instructor\Features\LLM\Data\Usage;
 
 class JinaDriver implements CanVectorize
 {
@@ -60,6 +61,12 @@ protected function toResponse(array $response) : EmbeddingsResponse {
                 fn($item) => new Vector(values: $item['embedding'], id: $item['index']),
                 $response['data']
             ),
+            usage: $this->makeUsage($response),
+        );
+    }
+
+    private function makeUsage(array $response) : Usage {
+        return new Usage(
             inputTokens: $response['usage']['prompt_tokens'] ?? 0,
             outputTokens: ($response['usage']['total_tokens'] ?? 0) - ($response['usage']['prompt_tokens'] ?? 0),
         );
diff --git a/src/Extras/Embeddings/Drivers/OpenAIDriver.php b/src/Extras/Embeddings/Drivers/OpenAIDriver.php
index 1e165384..d2630e08 100644
--- a/src/Extras/Embeddings/Drivers/OpenAIDriver.php
+++ b/src/Extras/Embeddings/Drivers/OpenAIDriver.php
@@ -8,6 +8,7 @@
 use Cognesy\Instructor\Extras\Embeddings\EmbeddingsResponse;
 use Cognesy\Instructor\Features\Http\Contracts\CanHandleHttp;
 use Cognesy\Instructor\Features\Http\HttpClient;
+use Cognesy\Instructor\Features\LLM\Data\Usage;
 
 class OpenAIDriver implements CanVectorize
 {
@@ -54,6 +55,12 @@ protected function toResponse(array $response) : EmbeddingsResponse {
                 callback: fn($item) => new Vector(values: $item['embedding'], id: $item['index']),
                 array: $response['data']
             ),
+            usage: $this->toUsage($response),
+        );
+    }
+
+    private function toUsage(array $response) : Usage {
+        return new Usage(
             inputTokens: $response['usage']['prompt_tokens'] ?? 0,
             outputTokens: ($response['usage']['total_tokens'] ?? 0) - ($response['usage']['prompt_tokens'] ?? 0),
         );
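All the OpenAI-style embeddings drivers now funnel provider payloads through a private `makeUsage()`/`toUsage()` helper, deriving output tokens as `total_tokens - prompt_tokens`. A self-contained sketch of that arithmetic (the `$payload` array is an invented sample, not a captured API response):

```php
<?php
use Cognesy\Instructor\Features\LLM\Data\Usage;

// Hypothetical sample of an OpenAI-style embeddings usage block.
$payload = ['usage' => ['prompt_tokens' => 21, 'total_tokens' => 21]];

$usage = new Usage(
    inputTokens: $payload['usage']['prompt_tokens'] ?? 0,
    // Embeddings produce no completion, so this normally resolves to 0.
    outputTokens: ($payload['usage']['total_tokens'] ?? 0) - ($payload['usage']['prompt_tokens'] ?? 0),
);

assert($usage->inputTokens === 21);
assert($usage->outputTokens === 0);
assert($usage->total() === 21);
```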
diff --git a/src/Extras/Embeddings/EmbeddingsResponse.php b/src/Extras/Embeddings/EmbeddingsResponse.php
index 98b4e327..57e073b7 100644
--- a/src/Extras/Embeddings/EmbeddingsResponse.php
+++ b/src/Extras/Embeddings/EmbeddingsResponse.php
@@ -3,14 +3,14 @@
 namespace Cognesy\Instructor\Extras\Embeddings;
 
 use Cognesy\Instructor\Extras\Embeddings\Data\Vector;
+use Cognesy\Instructor\Features\LLM\Data\Usage;
 
 class EmbeddingsResponse
 {
     public function __construct(
         /** @var Vector[] */
         public array $vectors,
-        public int $inputTokens,
-        public int $outputTokens,
+        public ?Usage $usage,
     ) {}
 
     public function first() : Vector {
@@ -25,6 +25,10 @@ public function all() : array {
         return $this->vectors;
     }
 
+    public function usage() : Usage {
+        return $this->usage ?? new Usage();
+    }
+
     /**
      * @param int $index
      * @return EmbeddingsResponse[]
@@ -32,14 +36,18 @@ public function all() : array {
     public function split(int $index) : array {
         return [
             new EmbeddingsResponse(
-                array_slice($this->vectors, 0, $index),
-                $this->inputTokens,
-                $this->outputTokens,
+                vectors: array_slice($this->vectors, 0, $index),
+                usage: new Usage(
+                    inputTokens: $this->usage()->inputTokens,
+                    outputTokens: $this->usage()->outputTokens,
+                ),
             ),
             new EmbeddingsResponse(
-                array_slice($this->vectors, $index),
-                0,
-                0,
+                vectors: array_slice($this->vectors, $index),
+                usage: new Usage(
+                    inputTokens: 0,
+                    outputTokens: 0,
+                ),
             ),
         ];
     }
@@ -52,6 +60,6 @@ public function toValuesArray() : array {
     }
 
     public function totalTokens() : int {
-        return $this->inputTokens + $this->outputTokens;
+        return $this->usage()->total();
     }
 }
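`EmbeddingsResponse` now carries a single `Usage` value, and `split()` keeps the token accounting on the first chunk while the remainder reports zero. A quick sketch with invented vectors and counts:

```php
<?php
use Cognesy\Instructor\Extras\Embeddings\Data\Vector;
use Cognesy\Instructor\Extras\Embeddings\EmbeddingsResponse;
use Cognesy\Instructor\Features\LLM\Data\Usage;

$response = new EmbeddingsResponse(
    vectors: [
        new Vector(values: [0.1, 0.2], id: 0),
        new Vector(values: [0.3, 0.4], id: 1),
        new Vector(values: [0.5, 0.6], id: 2),
    ],
    usage: new Usage(inputTokens: 12, outputTokens: 0),
);

[$first, $rest] = $response->split(2);
assert(count($first->all()) === 2);
assert($first->usage()->inputTokens === 12); // usage stays with the first chunk
assert($rest->usage()->total() === 0);       // remainder carries zero usage
assert($response->totalTokens() === 12);
```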
diff --git a/src/Features/Core/InstructorResponse.php b/src/Features/Core/InstructorResponse.php
index f85807e6..8d5982b1 100644
--- a/src/Features/Core/InstructorResponse.php
+++ b/src/Features/Core/InstructorResponse.php
@@ -5,6 +5,7 @@
 use Cognesy\Instructor\Events\EventDispatcher;
 use Cognesy\Instructor\Events\Instructor\InstructorDone;
 use Cognesy\Instructor\Features\Core\Data\Request;
+use Cognesy\Instructor\Features\LLM\Data\LLMResponse;
 use Exception;
 
 class InstructorResponse
@@ -35,15 +36,24 @@ public function get() : mixed {
         return $result->value();
     }
 
+    /**
+     * Executes the request and returns the LLM response object
+     */
+    public function response() : LLMResponse {
+        $response = $this->requestHandler->responseFor($this->request);
+        $this->events->dispatch(new InstructorDone(['result' => $response->value()]));
+        return $response;
+    }
+
     /**
      * Executes the request and returns the response stream
      */
-    public function stream() : Stream {
+    public function stream() : InstructorStream {
+        // TODO: do we need this? cannot we just turn streaming on?
         if (!$this->request->isStream()) {
             throw new Exception('Instructor::stream() method requires response streaming: set "stream" = true in the request options.');
         }
         $stream = $this->requestHandler->streamResponseFor($this->request);
-        return new Stream($stream, $this->events);
+        return new InstructorStream($stream, $this->events);
     }
 }
\ No newline at end of file
diff --git a/src/Features/Core/Stream.php b/src/Features/Core/InstructorStream.php
similarity index 57%
rename from src/Features/Core/Stream.php
rename to src/Features/Core/InstructorStream.php
index 21f8c138..7a931b3e 100644
--- a/src/Features/Core/Stream.php
+++ b/src/Features/Core/InstructorStream.php
@@ -8,39 +8,57 @@
 use Cognesy\Instructor\Extras\Sequence\Sequence;
 use Cognesy\Instructor\Features\LLM\Data\LLMResponse;
 use Cognesy\Instructor\Features\LLM\Data\PartialLLMResponse;
+use Cognesy\Instructor\Features\LLM\Data\Usage;
 use Exception;
+use Generator;
 
-class Stream
+class InstructorStream
 {
-    private PartialLLMResponse|LLMResponse|null $lastUpdate = null;
+    private PartialLLMResponse|LLMResponse|null $lastResponse = null;
+    private Usage $usage;
 
+    /**
+     * @param Generator<PartialLLMResponse> $stream
+     * @param EventDispatcher $events
+     */
     public function __construct(
-        private Iterable $stream,
+        private Generator $stream,
         private EventDispatcher $events,
-    ) {}
+    ) {
+        $this->usage = new Usage();
+    }
+
+    /**
+     * Returns current token usage for the stream
+     */
+    public function usage() : Usage {
+        return $this->usage;
+    }
 
     /**
      * Returns last received response object, can be used to retrieve
      * current or final response from the stream
      */
     public function getLastUpdate() : mixed {
-        return $this->lastUpdate->data();
+        return $this->lastResponse->value();
     }
 
     /**
-     * Returns raw stream for custom processing
+     * Returns last received LLM response object, which contains
+     * detailed information from the LLM API response
      */
-    public function getIterator() : Iterable {
-        return $this->stream;
+    public function getLastResponse() : LLMResponse|PartialLLMResponse {
+        return $this->lastResponse;
     }
 
     /**
-     * Returns a stream of partial updates
+     * Returns a stream of partial updates.
      */
     public function partials() : Iterable {
         foreach ($this->stream as $partialResponse) {
-            $this->lastUpdate = $partialResponse;
+            $this->lastResponse = $partialResponse;
             $result = $partialResponse->value();
+            $this->usage->accumulate($partialResponse->usage());
             yield $result;
         }
         $this->events->dispatch(new ResponseGenerated($result));
@@ -48,26 +66,30 @@ public function partials() : Iterable {
     }
 
     /**
-     * Processes response stream and returns the final update
+     * Processes the response stream and returns only the final update.
      */
     public function final() : mixed {
         foreach ($this->stream as $partialResponse) {
-            $this->lastUpdate = $partialResponse;
+            $this->lastResponse = $partialResponse;
+            $this->usage->accumulate($partialResponse->usage());
         }
-        $result = $this->lastUpdate->data();
+        $result = $this->lastResponse->value();
         $this->events->dispatch(new ResponseGenerated($result));
         $this->events->dispatch(new InstructorDone(['result' => $result]));
         return $result;
     }
 
     /**
-     * Processes response stream, returning updated sequence for each completed item
+     * Returns a single update for each completed item of the sequence.
+     * This method is useful when you want to process only fully updated
+     * sequence items, e.g. for visualization or further processing.
      */
     public function sequence() : Iterable {
         $lastSequence = null;
         $lastSequenceCount = 1;
         foreach ($this->stream as $partialResponse) {
-            $this->lastUpdate = $partialResponse;
+            $this->lastResponse = $partialResponse;
+            $this->usage->accumulate($partialResponse->usage());
             $update = $partialResponse->value();
             if (!($update instanceof Sequence)) {
                 throw new Exception('Expected a sequence update, got ' . get_class($update));
@@ -85,6 +107,30 @@ public function sequence() : Iterable {
         $this->events->dispatch(new InstructorDone(['result' => $lastSequence]));
     }
 
+    /**
+     * Returns a generator of partial LLM responses, which contain more detailed
+     * information about the response, including usage data.
+     * @return Generator<PartialLLMResponse>
+     */
+    public function responses() : Generator {
+        foreach ($this->stream as $partialResponse) {
+            $this->usage->accumulate($partialResponse->usage());
+            $this->lastResponse = $partialResponse;
+            yield $partialResponse;
+        }
+        $this->events->dispatch(new ResponseGenerated($this->lastResponse->value()));
+        $this->events->dispatch(new InstructorDone(['result' => $this->lastResponse->value()]));
+    }
+
+    /**
+     * Returns the raw stream for custom processing.
+     * Processing with this method does not trigger any events or dispatch any notifications.
+     * It also does not update usage data on the stream object.
+     */
+    public function getIterator() : Iterable {
+        return $this->stream;
+    }
+
     // INTERNAL //////////////////////////////////////////////////////////////
 
     // NOT YET AVAILABLE ////////////////////////////////////////////////////
@@ -94,11 +140,11 @@ public function sequence() : Iterable {
      */
     private function each(callable $callback) : void {
         foreach ($this->stream as $partialResponse) {
-            $this->lastUpdate = $partialResponse;
+            $this->lastResponse = $partialResponse;
             $callback($partialResponse->value());
         }
-        $this->events->dispatch(new ResponseGenerated($this->lastUpdate));
-        $this->events->dispatch(new InstructorDone(['result' => $this->lastUpdate]));
+        $this->events->dispatch(new ResponseGenerated($this->lastResponse));
+        $this->events->dispatch(new InstructorDone(['result' => $this->lastResponse]));
     }
 
     /**
@@ -107,10 +153,10 @@ private function each(callable $callback) : void {
     private function map(callable $callback) : array {
         $result = [];
         foreach ($this->stream as $partialResponse) {
-            $this->lastUpdate = $partialResponse;
+            $this->lastResponse = $partialResponse;
             $result[] = $callback($partialResponse->value());
         }
-        $this->events->dispatch(new ResponseGenerated($this->lastUpdate));
+        $this->events->dispatch(new ResponseGenerated($this->lastResponse));
         $this->events->dispatch(new InstructorDone(['result' => $result]));
         return $result;
     }
@@ -121,7 +167,7 @@ private function map(callable $callback) : array {
      */
     private function flatMap(callable $callback, mixed $initial) : mixed {
         $result = $initial;
         foreach ($this->stream as $partialResponse) {
-            $this->lastUpdate = $partialResponse;
+            $this->lastResponse = $partialResponse;
             $result = $callback($partialResponse->value(), $result);
         }
         $this->events->dispatch(new ResponseGenerated($result));
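On the Instructor side, `InstructorStream` accumulates `Usage` no matter which consumption method drives the stream (`partials()`, `final()`, `sequence()`, or `responses()`). A sketch of the typical consumer loop (the `City` class and input text are invented):

```php
<?php
use Cognesy\Instructor\Instructor;

class City {
    public string $name;
    public int $population;
}

// Hypothetical streamed request; connection/model defaults apply.
$stream = (new Instructor)->request(
    messages: 'Warsaw has 1.8M inhabitants.',
    responseModel: City::class,
    options: ['stream' => true],
)->stream();

// Consume partial updates; usage accumulates as chunks arrive.
foreach ($stream->partials() as $partialCity) {
    // render progress here, e.g. redraw the partially hydrated $partialCity
}

echo "Tokens so far: {$stream->usage()->total()}\n";
```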
diff --git a/src/Features/Core/PartialsGenerator.php b/src/Features/Core/PartialsGenerator.php
index 168852e8..5ecc4c31 100644
--- a/src/Features/Core/PartialsGenerator.php
+++ b/src/Features/Core/PartialsGenerator.php
@@ -109,7 +109,7 @@ public function getPartialResponses(Generator $stream, ResponseModel $responseMo
             $this->events->dispatch(new PartialJsonReceived($this->responseJson));
             yield $partialResponse
-                ->withData($result->unwrap())
+                ->withValue($result->unwrap())
                 ->withContent($this->responseText);
         }
         $this->events->dispatch(new StreamedResponseFinished($this->lastPartialResponse()));
diff --git a/src/Features/Core/RequestHandler.php b/src/Features/Core/RequestHandler.php
index 28766685..0997f472 100644
--- a/src/Features/Core/RequestHandler.php
+++ b/src/Features/Core/RequestHandler.php
@@ -49,7 +49,7 @@ public function __construct(
     public function responseFor(Request $request) : LLMResponse {
         $processingResult = Result::failure("No response generated");
         while ($processingResult->isFailure() && !$this->maxRetriesReached($request)) {
-            $llmResponse = $this->getInference($request)->asLLMResponse();
+            $llmResponse = $this->getInference($request)->response();
             $llmResponse->content = match($request->mode()) {
                 Mode::Text => $llmResponse->content,
                 default => Json::from($llmResponse->content)->toString(),
@@ -66,12 +66,12 @@ public function responseFor(Request $request) : LLMResponse {
     /**
      * Yields response value versions based on streamed responses
      * @param Request $request
-     * @return Generator
+     * @return Generator<PartialLLMResponse>
      */
     public function streamResponseFor(Request $request) : Generator {
         $processingResult = Result::failure("No response generated");
         while ($processingResult->isFailure() && !$this->maxRetriesReached($request)) {
-            $stream = $this->getInference($request)->asPartialLLMResponses();
+            $stream = $this->getInference($request)->stream()->responses();
 
             yield from $this->partialsGenerator->getPartialResponses($stream, $request->responseModel());
             $llmResponse = $this->partialsGenerator->getCompleteResponse();
diff --git a/src/Features/LLM/Data/LLMResponse.php b/src/Features/LLM/Data/LLMResponse.php
index e7fe423e..88fbb23e 100644
--- a/src/Features/LLM/Data/LLMResponse.php
+++ b/src/Features/LLM/Data/LLMResponse.php
@@ -14,11 +14,10 @@ public function __construct(
         public array $toolsData = [],
         public string $finishReason = '',
         public ?ToolCalls $toolCalls = null,
-        public int $inputTokens = 0,
-        public int $outputTokens = 0,
-        public int $cacheCreationTokens = 0,
-        public int $cacheReadTokens = 0,
-    ) {}
+        public ?Usage $usage = null,
+    ) {
+        $this->usage = $usage ?? new Usage();
+    }
 
     // STATIC ////////////////////////////////////////////////
 
@@ -45,6 +44,10 @@ public function hasContent() : bool {
         return $this->content !== '';
     }
 
+    public function content() : string {
+        return $this->content;
+    }
+
     public function json(): string {
         if (!$this->hasContent()) {
             return '';
@@ -56,6 +59,10 @@ public function hasToolCalls() : bool {
         return !empty($this->toolCalls);
     }
 
+    public function usage() : Usage {
+        return $this->usage ?? new Usage();
+    }
+
     // INTERNAL //////////////////////////////////////////////
 
     /**
@@ -74,10 +81,11 @@ private function makeFromPartialResponses(array $partialResponses = []) : self {
         }
         $content .= $partialResponse->contentDelta;
         $this->responseData[] = $partialResponse->responseData;
-        $this->inputTokens += $partialResponse->inputTokens;
-        $this->outputTokens += $partialResponse->outputTokens;
-        $this->cacheCreationTokens += $partialResponse->cacheCreationTokens;
-        $this->cacheReadTokens += $partialResponse->cacheReadTokens;
+        $this->usage()->inputTokens += $partialResponse->usage()->inputTokens;
+        $this->usage()->outputTokens += $partialResponse->usage()->outputTokens;
+        $this->usage()->cacheWriteTokens += $partialResponse->usage()->cacheWriteTokens;
+        $this->usage()->cacheReadTokens += $partialResponse->usage()->cacheReadTokens;
+        $this->usage()->reasoningTokens += $partialResponse->usage()->reasoningTokens;
         $this->finishReason = $partialResponse->finishReason;
     }
     $this->content = $content;
diff --git a/src/Features/LLM/Data/PartialLLMResponse.php b/src/Features/LLM/Data/PartialLLMResponse.php
index cf57959c..9904b357 100644
--- a/src/Features/LLM/Data/PartialLLMResponse.php
+++ b/src/Features/LLM/Data/PartialLLMResponse.php
@@ -6,7 +6,7 @@
 
 class PartialLLMResponse
 {
-    private mixed $data = null; // data extracted from response or tool calls
+    private mixed $value = null; // data extracted from response or tool calls
     private string $content = '';
 
     public function __construct(
@@ -15,25 +15,22 @@ public function __construct(
         public string $toolName = '',
         public string $toolArgs = '',
         public string $finishReason = '',
-        public int $inputTokens = 0,
-        public int $outputTokens = 0,
-        public int $cacheCreationTokens = 0,
-        public int $cacheReadTokens = 0,
+        public ?Usage $usage = null,
     ) {}
 
     // PUBLIC ////////////////////////////////////////////////
 
-    public function hasData() : bool {
-        return $this->data !== null;
+    public function hasValue() : bool {
+        return $this->value !== null;
     }
 
-    public function withData(mixed $value) : self {
-        $this->data = $value;
+    public function withValue(mixed $value) : self {
+        $this->value = $value;
         return $this;
     }
 
-    public function data() : mixed {
-        return $this->data;
+    public function value() : mixed {
+        return $this->value;
     }
 
     public function hasContent() : bool {
@@ -45,11 +42,6 @@ public function hasContent() : bool {
         return $this;
     }
 
-    public function withFinishReason(string $finishReason) : self {
-        $this->finishReason = $finishReason;
-        return $this;
-    }
-
     public function content() : string {
         return $this->content;
     }
@@ -60,4 +52,13 @@ public function json(): string {
         }
         return Json::fromPartial($this->content)->toString();
     }
+
+    public function withFinishReason(string $finishReason) : self {
+        $this->finishReason = $finishReason;
+        return $this;
+    }
+
+    public function usage() : Usage {
+        return $this->usage ?? new Usage();
+    }
 }
diff --git a/src/Features/LLM/Data/Usage.php b/src/Features/LLM/Data/Usage.php
new file mode 100644
index 00000000..914c83b3
--- /dev/null
+++ b/src/Features/LLM/Data/Usage.php
@@ -0,0 +1,30 @@
+<?php
+
+namespace Cognesy\Instructor\Features\LLM\Data;
+
+class Usage
+{
+    public function __construct(
+        public int $inputTokens = 0,
+        public int $outputTokens = 0,
+        public int $cacheWriteTokens = 0,
+        public int $cacheReadTokens = 0,
+        public int $reasoningTokens = 0,
+    ) {}
+
+    public function total() : int {
+        return $this->inputTokens
+            + $this->outputTokens
+            + $this->cacheWriteTokens
+            + $this->cacheReadTokens
+            + $this->reasoningTokens;
+    }
+
+    public function accumulate(Usage $usage) : void {
+        $this->inputTokens += $usage->inputTokens;
+        $this->outputTokens += $usage->outputTokens;
+        $this->cacheWriteTokens += $usage->cacheWriteTokens;
+        $this->cacheReadTokens += $usage->cacheReadTokens;
+        $this->reasoningTokens += $usage->reasoningTokens;
+    }
+}
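The new `Usage` value object is the single accounting unit shared by drivers, responses, and streams; `accumulate()` folds another reading into the running counters and `total()` sums all five. A standalone sketch with made-up numbers:

```php
<?php
use Cognesy\Instructor\Features\LLM\Data\Usage;

$running = new Usage();

// Simulate two streamed chunks reporting usage (numbers are invented).
$running->accumulate(new Usage(inputTokens: 120, cacheReadTokens: 80));
$running->accumulate(new Usage(outputTokens: 45, reasoningTokens: 10));

assert($running->inputTokens === 120);
assert($running->outputTokens === 45);
assert($running->total() === 255); // 120 + 45 + 0 + 80 + 10
```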
diff --git a/src/Features/LLM/Drivers/AnthropicDriver.php b/src/Features/LLM/Drivers/AnthropicDriver.php
index 5ffbcb4f..093b378f 100644
--- a/src/Features/LLM/Drivers/AnthropicDriver.php
+++ b/src/Features/LLM/Drivers/AnthropicDriver.php
@@ -12,6 +12,7 @@
 use Cognesy\Instructor\Features\LLM\Data\PartialLLMResponse;
 use Cognesy\Instructor\Features\LLM\Data\ToolCall;
 use Cognesy\Instructor\Features\LLM\Data\ToolCalls;
+use Cognesy\Instructor\Features\LLM\Data\Usage;
 use Cognesy\Instructor\Features\LLM\InferenceRequest;
 use Cognesy\Instructor\Utils\Json\Json;
 use Cognesy\Instructor\Utils\Messages\Messages;
@@ -98,10 +99,7 @@ public function toLLMResponse(array $data): ?LLMResponse {
             toolsData: $this->mapToolsData($data),
             finishReason: $data['stop_reason'] ?? '',
             toolCalls: $this->makeToolCalls($data),
-            inputTokens: $data['usage']['input_tokens'] ?? 0,
-            outputTokens: $data['usage']['output_tokens'] ?? 0,
-            cacheCreationTokens: $data['usage']['cache_creation_input_tokens'] ?? 0,
-            cacheReadTokens: $data['usage']['cache_read_input_tokens'] ?? 0,
+            usage: $this->makeUsage($data),
         );
     }
 
@@ -115,10 +113,7 @@ public function toPartialLLMResponse(array $data) : ?PartialLLMResponse {
             toolName: $data['content_block']['name'] ?? '',
             toolArgs: $data['delta']['partial_json'] ?? '',
             finishReason: $data['delta']['stop_reason'] ?? $data['message']['stop_reason'] ?? '',
-            inputTokens: $data['message']['usage']['input_tokens'] ?? $data['usage']['input_tokens'] ?? 0,
-            outputTokens: $data['message']['usage']['output_tokens'] ?? $data['usage']['output_tokens'] ?? 0,
-            cacheCreationTokens: $data['message']['usage']['cache_creation_input_tokens'] ?? $data['usage']['cache_creation_input_tokens'] ?? 0,
-            cacheReadTokens: $data['message']['usage']['cache_read_input_tokens'] ?? $data['usage']['cache_read_input_tokens'] ?? 0,
+            usage: $this->makeUsage($data),
         );
     }
 
@@ -284,4 +279,22 @@ private function setCacheMarker(array $messages): array {
         $messages[$lastIndex] = $lastMessage;
         return $messages;
     }
+
+    private function makeUsage(array $data) : Usage {
+        return new Usage(
+            inputTokens: $data['usage']['input_tokens']
+                ?? $data['message']['usage']['input_tokens']
+                ?? 0,
+            outputTokens: $data['usage']['output_tokens']
+                ?? $data['message']['usage']['output_tokens']
+                ?? 0,
+            cacheWriteTokens: $data['usage']['cache_creation_input_tokens']
+                ?? $data['message']['usage']['cache_creation_input_tokens']
+                ?? 0,
+            cacheReadTokens: $data['usage']['cache_read_input_tokens']
+                ?? $data['message']['usage']['cache_read_input_tokens']
+                ?? 0,
+            reasoningTokens: 0,
+        );
+    }
 }
diff --git a/src/Features/LLM/Drivers/CohereV1Driver.php b/src/Features/LLM/Drivers/CohereV1Driver.php
index 927310cd..156c8a2a 100644
--- a/src/Features/LLM/Drivers/CohereV1Driver.php
+++ b/src/Features/LLM/Drivers/CohereV1Driver.php
@@ -12,6 +12,7 @@
 use Cognesy\Instructor\Features\LLM\Data\PartialLLMResponse;
 use Cognesy\Instructor\Features\LLM\Data\ToolCall;
 use Cognesy\Instructor\Features\LLM\Data\ToolCalls;
+use Cognesy\Instructor\Features\LLM\Data\Usage;
 use Cognesy\Instructor\Features\LLM\InferenceRequest;
 use Cognesy\Instructor\Utils\Json\Json;
 use Cognesy\Instructor\Utils\Messages\Messages;
@@ -89,10 +90,7 @@ public function toLLMResponse(array $data): LLMResponse {
             toolsData: $this->mapToolsData($data),
             finishReason: $data['finish_reason'] ?? '',
             toolCalls: $this->makeToolCalls($data),
-            inputTokens: $data['meta']['tokens']['input_tokens'] ?? 0,
-            outputTokens: $data['meta']['tokens']['output_tokens'] ?? 0,
-            cacheCreationTokens: 0,
-            cacheReadTokens: 0,
+            usage: $this->makeUsage($data),
         );
     }
 
@@ -103,10 +101,7 @@ public function toPartialLLMResponse(array $data) : PartialLLMResponse {
             toolName: $this->makeToolNameDelta($data),
             toolArgs: $this->makeToolArgsDelta($data),
             finishReason: $data['response']['finish_reason'] ?? $data['delta']['finish_reason'] ?? '',
-            inputTokens: $data['response']['meta']['tokens']['input_tokens'] ?? $data['delta']['tokens']['input_tokens'] ?? 0,
-            outputTokens: $data['response']['meta']['tokens']['output_tokens'] ?? $data['delta']['tokens']['input_tokens'] ?? 0,
-            cacheCreationTokens: 0,
-            cacheReadTokens: 0,
+            usage: $this->makeUsage($data),
         );
     }
 
@@ -243,4 +238,20 @@ private function withCachedContext(InferenceRequest $request): InferenceRequest
         $cloned->responseFormat = empty($request->responseFormat) ? $request->cachedContext->responseFormat : $request->responseFormat;
         return $cloned;
     }
+
+    private function makeUsage(array $data) : Usage {
+        return new Usage(
+            inputTokens: $data['meta']['tokens']['input_tokens']
+                ?? $data['response']['meta']['tokens']['input_tokens']
+                ?? $data['delta']['tokens']['input_tokens']
+                ?? 0,
+            outputTokens: $data['meta']['tokens']['output_tokens']
+                ?? $data['response']['meta']['tokens']['output_tokens']
+                ?? $data['delta']['tokens']['output_tokens']
+                ?? 0,
+            cacheWriteTokens: 0,
+            cacheReadTokens: 0,
+            reasoningTokens: 0,
+        );
+    }
 }
diff --git a/src/Features/LLM/Drivers/CohereV2Driver.php b/src/Features/LLM/Drivers/CohereV2Driver.php
index 30c97fac..a36ca116 100644
--- a/src/Features/LLM/Drivers/CohereV2Driver.php
+++ b/src/Features/LLM/Drivers/CohereV2Driver.php
@@ -6,6 +6,7 @@
 use Cognesy\Instructor\Features\LLM\Data\LLMResponse;
 use Cognesy\Instructor\Features\LLM\Data\PartialLLMResponse;
 use Cognesy\Instructor\Features\LLM\Data\ToolCalls;
+use Cognesy\Instructor\Features\LLM\Data\Usage;
 use Cognesy\Instructor\Utils\Arrays;
 use Cognesy\Instructor\Utils\Json\Json;
 
@@ -50,10 +51,7 @@ public function toLLMResponse(array $data): LLMResponse {
             toolsData: $this->makeToolsData($data),
             finishReason: $data['finish_reason'] ?? '',
             toolCalls: $this->makeToolCalls($data),
-            inputTokens: $data['meta']['billed_units']['input_tokens'] ?? 0,
-            outputTokens: $data['meta']['billed_units']['output_tokens'] ?? 0,
-            cacheCreationTokens: 0,
-            cacheReadTokens: 0,
+            usage: $this->makeUsage($data),
         );
     }
 
@@ -67,10 +65,7 @@ public function toPartialLLMResponse(array|null $data) : ?PartialLLMResponse {
             toolName: $data['delta']['message']['tool_calls']['function']['name'] ?? '',
             toolArgs: $data['delta']['message']['tool_calls']['function']['arguments'] ?? '',
             finishReason: $data['delta']['finish_reason'] ?? '',
-            inputTokens: $data['delta']['usage']['billed_units']['input_tokens'] ?? 0,
-            outputTokens: $data['delta']['usage']['billed_units']['output_tokens'] ?? 0,
-            cacheCreationTokens: 0,
-            cacheReadTokens: 0,
+            usage: $this->makeUsage($data),
         );
     }
 
@@ -162,4 +157,18 @@ private function makeContentDelta(array $data): string {
     private function normalizeContent(array|string $content) : string {
         return is_array($content) ? $content['text'] : $content;
     }
+
+    private function makeUsage(array $data) : Usage {
+        return new Usage(
+            inputTokens: $data['meta']['billed_units']['input_tokens']
+                ?? $data['delta']['usage']['billed_units']['input_tokens']
+                ?? 0,
+            outputTokens: $data['meta']['billed_units']['output_tokens']
+                ?? $data['delta']['usage']['billed_units']['output_tokens']
+                ?? 0,
+            cacheWriteTokens: 0,
+            cacheReadTokens: 0,
+            reasoningTokens: 0,
+        );
+    }
 }
diff --git a/src/Features/LLM/Drivers/GeminiDriver.php b/src/Features/LLM/Drivers/GeminiDriver.php
index 342d4d6a..9aeec139 100644
--- a/src/Features/LLM/Drivers/GeminiDriver.php
+++ b/src/Features/LLM/Drivers/GeminiDriver.php
@@ -12,6 +12,7 @@
 use Cognesy\Instructor\Features\LLM\Data\PartialLLMResponse;
 use Cognesy\Instructor\Features\LLM\Data\ToolCall;
 use Cognesy\Instructor\Features\LLM\Data\ToolCalls;
+use Cognesy\Instructor\Features\LLM\Data\Usage;
 use Cognesy\Instructor\Features\LLM\InferenceRequest;
 use Cognesy\Instructor\Utils\Arrays;
 use Cognesy\Instructor\Utils\Json\Json;
@@ -104,10 +105,7 @@ public function toLLMResponse(array $data): ?LLMResponse {
             toolsData: $this->mapToolsData($data),
             finishReason: $data['candidates'][0]['finishReason'] ?? '',
             toolCalls: $this->makeToolCalls($data),
-            inputTokens: $data['usageMetadata']['promptTokenCount'] ?? 0,
-            outputTokens: $data['usageMetadata']['candidatesTokenCount'] ?? 0,
-            cacheCreationTokens: 0,
-            cacheReadTokens: 0,
+            usage: $this->makeUsage($data),
         );
     }
 
@@ -121,10 +119,7 @@ public function toPartialLLMResponse(array $data) : ?PartialLLMResponse {
             toolName: $this->makeToolName($data),
             toolArgs: $this->makeToolArgs($data),
             finishReason: $data['candidates'][0]['finishReason'] ?? '',
-            inputTokens: $data['usageMetadata']['promptTokenCount'] ?? 0,
-            outputTokens: $data['usageMetadata']['candidatesTokenCount'] ?? 0,
-            cacheCreationTokens: 0,
-            cacheReadTokens: 0,
+            usage: $this->makeUsage($data),
         );
     }
 
@@ -309,4 +304,14 @@ private function makeToolArgs(array $data) : string {
         $value = $data['candidates'][0]['content']['parts'][0]['functionCall']['args'] ?? '';
         return is_array($value) ? Json::encode($value) : '';
     }
+
+    private function makeUsage(array $data) : Usage {
+        return new Usage(
+            inputTokens: $data['usageMetadata']['promptTokenCount'] ?? 0,
+            outputTokens: $data['usageMetadata']['candidatesTokenCount'] ?? 0,
+            cacheWriteTokens: 0,
+            cacheReadTokens: 0,
+            reasoningTokens: 0,
+        );
+    }
 }
diff --git a/src/Features/LLM/Drivers/OpenAIDriver.php b/src/Features/LLM/Drivers/OpenAIDriver.php
index 87ac6bb0..dcb46c3f 100644
--- a/src/Features/LLM/Drivers/OpenAIDriver.php
+++ b/src/Features/LLM/Drivers/OpenAIDriver.php
@@ -12,6 +12,7 @@
 use Cognesy\Instructor\Features\LLM\Data\LLMResponse;
 use Cognesy\Instructor\Features\LLM\Data\PartialLLMResponse;
 use Cognesy\Instructor\Features\LLM\Data\ToolCalls;
+use Cognesy\Instructor\Features\LLM\Data\Usage;
 use Cognesy\Instructor\Features\LLM\InferenceRequest;
 use Cognesy\Instructor\Utils\Json\Json;
 
@@ -92,10 +93,7 @@ public function toLLMResponse(array $data): ?LLMResponse {
             toolsData: $this->makeToolsData($data),
             finishReason: $data['choices'][0]['finish_reason'] ?? '',
             toolCalls: $this->makeToolCalls($data),
-            inputTokens: $this->makeInputTokens($data),
-            outputTokens: $this->makeOutputTokens($data),
-            cacheCreationTokens: 0,
-            cacheReadTokens: $data['usage']['prompt_tokens_details']['cached_tokens'] ?? 0,
+            usage: $this->makeUsage($data),
         );
     }
 
@@ -109,10 +107,7 @@ public function toPartialLLMResponse(array|null $data) : ?PartialLLMResponse {
             toolName: $this->makeToolNameDelta($data),
             toolArgs: $this->makeToolArgsDelta($data),
             finishReason: $data['choices'][0]['finish_reason'] ?? '',
-            inputTokens: $this->makeInputTokens($data),
-            outputTokens: $this->makeOutputTokens($data),
-            cacheCreationTokens: 0,
-            cacheReadTokens: $data['usage']['prompt_tokens_details']['cached_tokens'] ?? 0,
+            usage: $this->makeUsage($data),
         );
     }
 
@@ -155,6 +150,19 @@ private function applyMode(
         return $request;
     }
 
+    private function withCachedContext(InferenceRequest $request): InferenceRequest {
+        if (!isset($request->cachedContext)) {
+            return $request;
+        }
+
+        $cloned = clone $request;
+        $cloned->messages = array_merge($request->cachedContext->messages, $request->messages);
+        $cloned->tools = empty($request->tools) ? $request->cachedContext->tools : $request->tools;
+        $cloned->toolChoice = empty($request->toolChoice) ? $request->cachedContext->toolChoice : $request->toolChoice;
+        $cloned->responseFormat = empty($request->responseFormat) ? $request->cachedContext->responseFormat : $request->responseFormat;
+        return $cloned;
+    }
+
     private function makeToolCalls(array $data) : ToolCalls {
         return ToolCalls::fromArray(array_map(
             callback: fn(array $call) => $call['function'] ?? [],
@@ -192,18 +200,6 @@ private function makeContentDelta(array $data): string {
         };
     }
 
-    private function makeInputTokens(array $data): int {
-        return $data['usage']['prompt_tokens']
-            ?? $data['x_groq']['usage']['prompt_tokens']
-            ?? 0;
-    }
-
-    private function makeOutputTokens(array $data): int {
-        return $data['usage']['completion_tokens']
-            ?? $data['x_groq']['usage']['completion_tokens']
-            ?? 0;
-    }
-
     private function makeToolNameDelta(array $data) : string {
         return $data['choices'][0]['delta']['tool_calls'][0]['function']['name'] ?? '';
     }
@@ -212,16 +208,17 @@ private function makeToolArgsDelta(array $data) : string {
         return $data['choices'][0]['delta']['tool_calls'][0]['function']['arguments'] ?? '';
     }
 
-    private function withCachedContext(InferenceRequest $request): InferenceRequest {
-        if (!isset($request->cachedContext)) {
-            return $request;
-        }
-
-        $cloned = clone $request;
-        $cloned->messages = array_merge($request->cachedContext->messages, $request->messages);
-        $cloned->tools = empty($request->tools) ? $request->cachedContext->tools : $request->tools;
-        $cloned->toolChoice = empty($request->toolChoice) ? $request->cachedContext->toolChoice : $request->toolChoice;
-        $cloned->responseFormat = empty($request->responseFormat) ? $request->cachedContext->responseFormat : $request->responseFormat;
-        return $cloned;
+    private function makeUsage(array $data): Usage {
+        return new Usage(
+            inputTokens: $data['usage']['prompt_tokens']
+                ?? $data['x_groq']['usage']['prompt_tokens']
+                ?? 0,
+            outputTokens: $data['usage']['completion_tokens']
+                ?? $data['x_groq']['usage']['completion_tokens']
+                ?? 0,
+            cacheWriteTokens: 0,
+            cacheReadTokens: $data['usage']['prompt_tokens_details']['cached_tokens'] ?? 0,
+            reasoningTokens: 0,
+        );
     }
 }
diff --git a/src/Features/Http/IterableReader.php b/src/Features/LLM/EventStreamReader.php
similarity index 90%
rename from src/Features/Http/IterableReader.php
rename to src/Features/LLM/EventStreamReader.php
index 51977e26..42cc9efa 100644
--- a/src/Features/Http/IterableReader.php
+++ b/src/Features/LLM/EventStreamReader.php
@@ -1,6 +1,6 @@
 <?php
 
-namespace Cognesy\Instructor\Features\Http;
+namespace Cognesy\Instructor\Features\LLM;
 
 use Cognesy\Instructor\Events\EventDispatcher;
 use Cognesy\Instructor\Events\Inference\StreamDataReceived;
@@ ... @@
-class IterableReader
+class EventStreamReader
 {
@@ ... @@
     /**
      * @param Generator<string> $stream
      * @return Generator
      */
-    public function toStreamEvents(Generator $stream): Generator {
+    public function eventsFrom(Generator $stream): Generator {
         foreach ($this->readLines($stream) as $line) {
             $processedData = $this->processLine($line);
             if ($processedData !== null) {
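With every chat driver now emitting `Usage` and the SSE reader exposed as `EventStreamReader::eventsFrom()`, the low-level streaming loop from `examples/A05_Extras/LLM/run.php` reads like this (a sketch; the connection name is a placeholder):

```php
<?php
use Cognesy\Instructor\Features\LLM\Inference;

$responses = (new Inference)
    ->withConnection('openai') // placeholder connection name
    ->create(
        messages: [['role' => 'user', 'content' => 'Describe capital of Brasil']],
        options: ['max_tokens' => 128, 'stream' => true],
    )
    ->stream()
    ->responses();

$lastResponse = null;
foreach ($responses as $partialResponse) {
    echo $partialResponse->contentDelta; // print deltas as they arrive
    $lastResponse = $partialResponse;
}

// Each enriched partial carries accumulated content and the last finish reason.
echo "\nFULL TEXT: {$lastResponse?->content()}\n";
echo "FINISH REASON: {$lastResponse?->finishReason}\n";
```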
diff --git a/src/Features/LLM/InferenceResponse.php b/src/Features/LLM/InferenceResponse.php
index bd51ec62..533aaa74 100644
--- a/src/Features/LLM/InferenceResponse.php
+++ b/src/Features/LLM/InferenceResponse.php
@@ -4,32 +4,34 @@
 
 use Cognesy\Instructor\Events\EventDispatcher;
 use Cognesy\Instructor\Events\Inference\LLMResponseReceived;
-use Cognesy\Instructor\Events\Inference\PartialLLMResponseReceived;
 use Cognesy\Instructor\Features\Http\Contracts\CanAccessResponse;
-use Cognesy\Instructor\Features\Http\IterableReader;
 use Cognesy\Instructor\Features\LLM\Contracts\CanHandleInference;
 use Cognesy\Instructor\Features\LLM\Data\LLMConfig;
 use Cognesy\Instructor\Features\LLM\Data\LLMResponse;
-use Cognesy\Instructor\Features\LLM\Data\PartialLLMResponse;
 use Cognesy\Instructor\Utils\Json\Json;
-use Generator;
 use InvalidArgumentException;
 
 class InferenceResponse
 {
     protected EventDispatcher $events;
-    protected IterableReader $reader;
+    protected CanAccessResponse $response;
+    protected CanHandleInference $driver;
     protected string $responseContent = '';
+    protected LLMConfig $config;
+    protected bool $isStreamed = false;
 
     public function __construct(
-        protected CanAccessResponse $response,
-        protected CanHandleInference $driver,
-        protected LLMConfig $config,
-        protected bool $isStreamed = false,
-        ?EventDispatcher $events = null,
+        CanAccessResponse $response,
+        CanHandleInference $driver,
+        LLMConfig $config,
+        bool $isStreamed = false,
+        ?EventDispatcher $events = null,
     ) {
         $this->events = $events ?? new EventDispatcher();
-        $this->reader = new IterableReader($this->driver->getData(...), $this->events);
+        $this->driver = $driver;
+        $this->config = $config;
+        $this->isStreamed = $isStreamed;
+        $this->response = $response;
     }
 
     public function isStreamed() : bool {
@@ -38,8 +40,8 @@ public function isStreamed() : bool {
 
     public function toText() : string {
         return match($this->isStreamed) {
-            false => $this->asLLMResponse()->content,
-            true => $this->finalResponse($this->asPartialLLMResponses())->content(),
+            false => $this->makeLLMResponse()->content(),
+            true => $this->stream()->final()?->content() ?? '',
         };
     }
 
@@ -48,102 +50,50 @@ public function toJson() : array {
     }
 
     /**
-     * @return Generator
+     * @return InferenceStream
     */
-    public function stream() : Generator {
+    public function stream() : InferenceStream {
         if (!$this->isStreamed) {
             throw new InvalidArgumentException('Trying to read response stream for request with no streaming');
         }
-        foreach ($this->asPartialLLMResponses() as $partialLLMResponse) {
-            yield $partialLLMResponse;
-        }
+        return new InferenceStream($this->response, $this->driver, $this->config, $this->events);
     }
 
     // AS API RESPONSE OBJECTS //////////////////////////////////
 
-    public function asLLMResponse() : LLMResponse {
+    public function response() : LLMResponse {
         $response = match($this->isStreamed) {
-            false => $this->driver->toLLMResponse($this->responseData()),
-            true => LLMResponse::fromPartialResponses($this->allPartialLLMResponses()),
+            false => $this->makeLLMResponse(),
+            true => LLMResponse::fromPartialResponses($this->stream()->all()),
         };
-        $this->events->dispatch(new LLMResponseReceived($response));
         return $response;
     }
 
-    /**
-     * @return Generator<PartialLLMResponse>
-     */
-    public function asPartialLLMResponses() : Generator {
-        $content = '';
-        $finishReason = '';
-        foreach ($this->reader->toStreamEvents($this->response->streamContents()) as $streamEvent) {
-            if ($streamEvent === false) {
-                continue;
-            }
-            $data = Json::decode($streamEvent, []);
-            $partialResponse = $this->driver->toPartialLLMResponse($data);
-            if ($partialResponse === null) {
-                continue;
-            }
-            if ($partialResponse->finishReason !== '') {
-                $finishReason = $partialResponse->finishReason;
-            }
-            $content .= $partialResponse->contentDelta;
-            // add accumulated content and last finish reason
-            $enrichedResponse = $partialResponse
-                ->withContent($content)
-                ->withFinishReason($finishReason);
-            $this->events->dispatch(new PartialLLMResponseReceived($enrichedResponse));
-            yield $enrichedResponse;
-        }
-    }
-
-    /**
-     * @return array[]
-     */
-    public function asArray() : array {
-        return match($this->isStreamed) {
-            false => $this->responseData(),
-            true => $this->allStreamResponses(),
-        };
-    }
-
     // INTERNAL /////////////////////////////////////////////////
 
-    protected function responseData() : array {
-        if (empty($this->responseContent)) {
-            $this->responseContent = $this->response->getContents();
-        }
-        return Json::decode($this->responseContent) ?? [];
+    private function makeLLMResponse() : LLMResponse {
+        $response = $this->driver->toLLMResponse($this->getResponseData());
+        $this->events->dispatch(new LLMResponseReceived($response));
+        return $response;
     }
 
+    // PRIVATE /////////////////////////////////////////////////
+
     /**
-     * @return array[]
+     * @return array
     */
-    protected function allStreamResponses() : array {
-        $content = [];
-        foreach ($this->reader->toStreamEvents($this->response->streamContents()) as $partialData) {
-            $content[] = Json::decode($partialData);
-        }
-        return $content;
+    private function getResponseData() : array {
+        return Json::decode($this->getResponseContent()) ?? [];
     }
 
-    /**
-     * @return PartialLLMResponse[]
-     */
-    protected function allPartialLLMResponses() : array {
-        $partialResponses = [];
-        foreach ($this->asPartialLLMResponses() as $partialResponse) {
-            $partialResponses[] = $partialResponse;
+    private function getResponseContent() : string {
+        if (empty($this->responseContent)) {
+            $this->responseContent = $this->readFromResponse();
         }
-        return $partialResponses;
+        return $this->responseContent;
     }
 
-    protected function finalResponse(Generator $partialResponses) : PartialLLMResponse {
-        $lastPartial = null;
-        foreach ($partialResponses as $partialResponse) {
-            $lastPartial = $partialResponse;
-        }
-        return $lastPartial;
+    private function readFromResponse() : string {
+        return $this->response->getContents();
     }
 }
\ No newline at end of file
diff --git a/src/Features/LLM/InferenceStream.php b/src/Features/LLM/InferenceStream.php
new file mode 100644
index 00000000..d98547f6
--- /dev/null
+++ b/src/Features/LLM/InferenceStream.php
@@ -0,0 +1,147 @@
+<?php
+
+namespace Cognesy\Instructor\Features\LLM;
+
+use Cognesy\Instructor\Events\EventDispatcher;
+use Cognesy\Instructor\Events\Inference\PartialLLMResponseReceived;
+use Cognesy\Instructor\Features\Http\Contracts\CanAccessResponse;
+use Cognesy\Instructor\Features\LLM\Contracts\CanHandleInference;
+use Cognesy\Instructor\Features\LLM\Data\LLMConfig;
+use Cognesy\Instructor\Features\LLM\Data\PartialLLMResponse;
+use Cognesy\Instructor\Utils\Json\Json;
+use Generator;
+
+class InferenceStream
+{
+    protected EventDispatcher $events;
+    protected EventStreamReader $reader;
+    protected CanAccessResponse $response;
+    protected CanHandleInference $driver;
+    protected LLMConfig $config;
+    protected Generator $stream;
+    protected bool $streamReceived = false;
+    protected array $streamEvents = [];
+
+    public function __construct(
+        CanAccessResponse $response,
+        CanHandleInference $driver,
+        LLMConfig $config,
+        ?EventDispatcher $events = null,
+    ) {
+        $this->events = $events ?? new EventDispatcher();
+        $this->driver = $driver;
+        $this->config = $config;
+        $this->response = $response;
+
+        $this->stream = $this->response->streamContents();
+        $this->reader = new EventStreamReader($this->driver->getData(...), $this->events);
+    }
+
+    /**
+     * @return Generator<PartialLLMResponse>
+     */
+    public function responses() : Generator {
+        foreach ($this->makePartialLLMResponses($this->stream) as $partialLLMResponse) {
+            yield $partialLLMResponse;
+        }
+    }
+
+    /**
+     * @return PartialLLMResponse[]
+     */
+    public function all() : array {
+        return $this->getAllPartialLLMResponses($this->stream);
+    }
+
+    /**
+     * Returns the last partial response for the stream.
+     * It will contain accumulated content and finish reason.
+     * @return ?PartialLLMResponse
+     */
+    public function final() : ?PartialLLMResponse {
+        return $this->finalResponse($this->stream);
+    }
+
+    // INTERNAL //////////////////////////////////////////////
+
+    /**
+     * @param Generator<PartialLLMResponse> $partialResponses
+     * @return ?PartialLLMResponse
+     */
+    protected function finalResponse(Generator $partialResponses) : ?PartialLLMResponse {
+        $lastPartial = null;
+        foreach ($partialResponses as $partialResponse) {
+            $lastPartial = $partialResponse;
+        }
+        return $lastPartial;
+    }
+
+    /**
+     * @param Generator<string> $stream
+     * @return PartialLLMResponse[]
+     */
+    protected function getAllPartialLLMResponses(Generator $stream) : array {
+        $partialResponses = [];
+        foreach ($this->makePartialLLMResponses($stream) as $partialResponse) {
+            $partialResponses[] = $partialResponse;
+        }
+        return $partialResponses;
+    }
+
+    /**
+     * @param Generator<string> $stream
+     * @return Generator<PartialLLMResponse>
+     */
+    private function makePartialLLMResponses(Generator $stream) : Generator {
+        $content = '';
+        $finishReason = '';
+        foreach ($this->getEventStream($stream) as $streamEvent) {
+            if ($streamEvent === false) {
+                continue;
+            }
+            $data = Json::decode($streamEvent, []);
+            $partialResponse = $this->makePartialLLMResponse($data);
+            if ($partialResponse === null) {
+                continue;
+            }
+            if ($partialResponse->finishReason !== '') {
+                $finishReason = $partialResponse->finishReason;
+            }
+            $content .= $partialResponse->contentDelta;
+            // add accumulated content and last finish reason
+            $enrichedResponse = $partialResponse
+                ->withContent($content)
+                ->withFinishReason($finishReason);
+            $this->events->dispatch(new PartialLLMResponseReceived($enrichedResponse));
+            yield $enrichedResponse;
+        }
+    }
+
+    private function makePartialLLMResponse(array $data) : ?PartialLLMResponse {
+        return $this->driver->toPartialLLMResponse($data);
+    }
+
+    /**
+     * @return Generator<string|bool>
+     */
+    private function getEventStream(Generator $stream) : Generator {
+        if (!$this->streamReceived) {
+            foreach($this->streamFromResponse($stream) as $event) {
+                $this->streamEvents[] = $event;
+                yield $event;
+            }
+            $this->streamReceived = true;
+            return;
+        }
+        reset($this->streamEvents);
+        yield from $this->streamEvents;
+    }
+
+    /**
+     * @return Generator<string>
+     */
+    private function streamFromResponse(Generator $stream) : Generator {
+        return $this->reader->eventsFrom($stream);
+    }
+}
\ No newline at end of file
diff --git a/src/Utils/Json/Json.php b/src/Utils/Json/Json.php
index 63397e72..7e78e8af 100644
--- a/src/Utils/Json/Json.php
+++ b/src/Utils/Json/Json.php
@@ -2,6 +2,8 @@
 
 namespace Cognesy\Instructor\Utils\Json;
 
+use JsonException;
+
 class Json
 {
     private string $json;
@@ -49,7 +51,14 @@ public static function decode(string $text, mixed $default = null) : mixed {
         if (empty($text)) {
             return $default;
         }
-        $decoded = json_decode($text, true, 512, JSON_THROW_ON_ERROR);
+        try {
+            $decoded = json_decode($text, true, 512, JSON_THROW_ON_ERROR);
+        } catch (JsonException $e) {
+            if ($default === null) {
+                throw $e;
+            }
+            return $default;
+        }
         return empty($decoded) ? $default : $decoded;
     }
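`Json::decode()` now only rethrows `JsonException` when no usable default was supplied; with a default, malformed input degrades gracefully. A quick sketch of the resulting contract:

```php
<?php
use Cognesy\Instructor\Utils\Json\Json;

// Valid JSON decodes as before.
assert(Json::decode('{"a":1}') === ['a' => 1]);

// Malformed JSON with a default returns the default instead of throwing.
assert(Json::decode('{oops', []) === []);

// Malformed JSON without a default still throws.
try {
    Json::decode('{oops');
    assert(false, 'should have thrown');
} catch (\JsonException $e) {
    // expected
}
```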
diff --git a/tests/Feature/Events/EventsTest.php b/tests/Feature/Events/EventsTest.php
index 327e4f57..e1de5168 100644
--- a/tests/Feature/Events/EventsTest.php
+++ b/tests/Feature/Events/EventsTest.php
@@ -29,7 +29,7 @@
         '{"name": "Jason", "age":28}',
     ]);
     $events = new EventSink();
-    $person = (new Instructor)->withDriver($mockLLM)
+    $person = (new Instructor)->withHttpClient($mockLLM)
         ->onEvent($event, fn($e) => $events->onEvent($e))
         //->wiretap(fn($e) => dump($e))
         ->respond(
@@ -85,7 +85,7 @@
     // expect exception
     $this->expectException(\Exception::class);
 
-    $person = (new Instructor)->withDriver($mockLLM)
+    $person = (new Instructor)->withHttpClient($mockLLM)
         ->onEvent($event, fn($e) => $events->onEvent($e))
         ->respond(
             messages: [['role' => 'user', 'content' => $text]],
@@ -135,7 +135,7 @@
         '{"age":28}'
     ]);
     $events = new EventSink();
-    $age = (new Instructor)->withDriver($mockLLM)
+    $age = (new Instructor)->withHttpClient($mockLLM)
         ->onEvent($event, fn($e) => $events->onEvent($e))
         ->respond(
             messages: [['role' => 'user', 'content' => $text]],
diff --git a/tests/Feature/Extras/MaybeTest.php b/tests/Feature/Extras/MaybeTest.php
index d4534c8a..90402f38 100644
--- a/tests/Feature/Extras/MaybeTest.php
+++ b/tests/Feature/Extras/MaybeTest.php
@@ -11,7 +11,7 @@
     ]);
 
     $text = "His name is Jason, he is 28 years old.";
-    $person = (new Instructor)->withDriver($mockLLM)->respond(
+    $person = (new Instructor)->withHttpClient($mockLLM)->respond(
         messages: [['role' => 'user', 'content' => $text]],
         responseModel: Person::class,
     );
diff --git a/tests/Feature/Extras/MixinTest.php b/tests/Feature/Extras/MixinTest.php
index d7140bc3..8aa4341b 100644
--- a/tests/Feature/Extras/MixinTest.php
+++ b/tests/Feature/Extras/MixinTest.php
@@ -10,7 +10,7 @@
         '{"name":"Jason","age":28}'
     ]);
 
-    $instructor = (new Instructor)->withDriver($mockLLM);
+    $instructor = (new Instructor)->withHttpClient($mockLLM);
     $person = PersonWithMixin::infer(
         messages: "His name is Jason, he is 28 years old.",
         instructor: $instructor
diff --git a/tests/Feature/Extras/ScalarsTest.php b/tests/Feature/Extras/ScalarsTest.php
index 7088c0db..e309ebc4 100644
--- a/tests/Feature/Extras/ScalarsTest.php
+++ b/tests/Feature/Extras/ScalarsTest.php
@@ -10,7 +10,7 @@
     $mockLLM = MockLLM::get(['{"age":28}']);
 
     $text = "His name is Jason, he is 28 years old.";
-    $value = (new Instructor)->withDriver($mockLLM)->respond(
+    $value = (new Instructor)->withHttpClient($mockLLM)->respond(
         messages: [
             ['role' => 'system', 'content' => $text],
             ['role' => 'user', 'content' => 'What is Jason\'s age?'],
@@ -25,7 +25,7 @@
     $mockLLM = MockLLM::get(['{"firstName":"Jason"}']);
 
     $text = "His name is Jason, he is 28 years old.";
-    $value = (new Instructor)->withDriver($mockLLM)->respond(
+    $value = (new Instructor)->withHttpClient($mockLLM)->respond(
         messages: [
             ['role' => 'system', 'content' => $text],
             ['role' => 'user', 'content' => 'What is his name?'],
@@ -40,7 +40,7 @@
     $mockLLM = MockLLM::get(['{"recordTime":11.6}']);
 
     $text = "His name is Jason, he is 28 years old and his 100m sprint record is 11.6 seconds.";
-    $value = (new Instructor)->withDriver($mockLLM)->respond(
+    $value = (new Instructor)->withHttpClient($mockLLM)->respond(
         messages: [
             ['role' => 'system', 'content' => $text],
             ['role' => 'user', 'content' => 'What is Jason\'s best 100m run time?'],
@@ -55,7 +55,7 @@
     $mockLLM = MockLLM::get(['{"isAdult":true}']);
 
     $text = "His name is Jason, he is 28 years old.";
-    $value = (new Instructor)->withDriver($mockLLM)->respond(
+    $value = (new Instructor)->withHttpClient($mockLLM)->respond(
         messages: [
             ['role' => 'system', 'content' => $text],
             ['role' => 'user', 'content' => 'Is he adult?'],
@@ -71,7 +71,7 @@
     $mockLLM = MockLLM::get(['{"citizenship":"other"}']);
 
     $text = "His name is Jason, he is 28 years old and he lives in Germany.";
-    $value = (new Instructor)->withDriver($mockLLM)->respond(
+    $value = (new Instructor)->withHttpClient($mockLLM)->respond(
         messages: [
             ['role' => 'system', 'content' => $text],
             ['role' => 'user', 'content' => 'What is Jason\'s citizenship?'],
diff --git a/tests/Feature/Extras/TransformTest.php b/tests/Feature/Extras/TransformTest.php
index e7fcae6e..545735f7 100644
--- a/tests/Feature/Extras/TransformTest.php
+++ b/tests/Feature/Extras/TransformTest.php
@@ -55,7 +55,7 @@
 //        '{"user_name": "Jason", "user_age": 28}',
 //    ]);
 //
-//    $instructor = (new Instructor)->withDriver($mockLLM);
+//    $instructor = (new Instructor)->withHttpClient($mockLLM);
 //    $predict = new Transform(
 //        signature: 'text (email containing user data) -> user_name, user_age:int',
 //        instructor: $instructor
@@ -84,7 +84,7 @@
 //
 //    $predict = new Transform(
 //        signature: EmailAnalysis::class,
-//        instructor: (new Instructor)->withDriver($mockLLM)
+//        instructor: (new Instructor)->withHttpClient($mockLLM)
 //    );
 //
 //    $analysis = $predict->withArgs(
diff --git a/tests/Feature/Instructor/ExtractionTest.php b/tests/Feature/Instructor/ExtractionTest.php
index b5c6dba2..0abd554d 100644
--- a/tests/Feature/Instructor/ExtractionTest.php
+++ b/tests/Feature/Instructor/ExtractionTest.php
@@ -18,7 +18,7 @@
     ]);
 
     $text = "His name is Jason, he is 28 years old.";
-    $person = (new Instructor)->withDriver($mockLLM)->respond(
+    $person = (new Instructor)->withHttpClient($mockLLM)->respond(
         messages: [['role' => 'user', 'content' => $text]],
         responseModel: Person::class,
     );
@@ -35,7 +35,7 @@
     ]);
 
     $text = "His name is Jason, he is 28 years old. He is self-employed.";
-    $person = (new Instructor)->withDriver($mockLLM)->respond(
+    $person = (new Instructor)->withHttpClient($mockLLM)->respond(
         messages: [['role' => 'user', 'content' => $text]],
         responseModel: PersonWithJob::class,
     );
@@ -53,7 +53,7 @@
     ]);
 
     $text = "His name is Jason, he is 28 years old. He lives in San Francisco.";
-    $person = (new Instructor)->withDriver($mockLLM)->respond(
+    $person = (new Instructor)->withHttpClient($mockLLM)->respond(
         messages: [['role' => 'user', 'content' => $text]],
         responseModel: PersonWithAddress::class,
     );
@@ -71,7 +71,7 @@
     ]);
 
     $text = "His name is Jason, he is 28 years old. He lives in USA - he works from his home office in San Francisco, he also has an apartment in New York.";
-    $person = (new Instructor)->withDriver($mockLLM)->respond(
+    $person = (new Instructor)->withHttpClient($mockLLM)->respond(
         messages: [['role' => 'user', 'content' => $text]],
         responseModel: PersonWithAddresses::class,
     );
@@ -89,7 +89,7 @@
         '{"events":[{"title":"Project Status RED","description":"Acme Insurance project to implement SalesTech CRM solution is currently in RED status due to delayed delivery of document production system, led by 3rd party vendor - Alfatech.","type":"risk","status":"open","stakeholders":[{"name":"Alfatech","role":"vendor"},{"name":"Acme","role":"customer"}],"date":"2021-09-01"},{"title":"Ecommerce Track Delay","description":"Due to dependencies, the ecommerce track will be delayed by 2 sprints because of the delayed delivery of the document production system.","type":"issue","status":"open","stakeholders":[{"name":"Acme","role":"customer"},{"name":"SysCorp","role":"system integrator"}]},{"title":"Test Data Availability Issue","description":"customer is not able to provide the test data for the ecommerce track, which will impact the stabilization schedule unless resolved by the end of the month.","type":"issue","status":"open","stakeholders":[{"name":"Acme","role":"customer"},{"name":"SysCorp","role":"system integrator"}]},{"title":"Steerco Maintains Schedule","description":"Steerco insists on maintaining the release schedule due to marketing campaign already ongoing, regardless of the project issues.","type":"issue","status":"open","stakeholders":[{"name":"Acme","role":"customer"}]},{"title":"Communication Issues","description":"SalesTech team struggling with communication issues as SysCorp team has not shown up on 2 recent calls, leading to lack of insight. This has been escalated to SysCorp\'s leadership team.","type":"issue","status":"open","stakeholders":[{"name":"SysCorp","role":"system integrator"},{"name":"Acme","role":"customer"}]},{"title":"Integration Proxy Issue Resolved","description":"The previously reported Integration Proxy connectivity issue, which was blocking the policy track, has been resolved.","type":"progress","status":"closed","stakeholders":[{"name":"SysCorp","role":"system integrator"}],"date":"2021-08-30"},{"title":"Finalized Production Deployment Plan","description":"Production deployment plan has been finalized on Aug 15th and is awaiting customer approval.","type":"progress","status":"open","stakeholders":[{"name":"Acme","role":"customer"}],"date":"2021-08-15"}]}'
     ]);
 
-    $instructor = (new Instructor)->withDriver($mockLLM); //$mockLLM
+    $instructor = (new Instructor)->withHttpClient($mockLLM); //$mockLLM
     /** @var \Tests\Examples\Complex\ProjectEvents $events */
     $events = $instructor
         ->respond(
diff --git a/tests/Feature/Instructor/FeaturesTest.php b/tests/Feature/Instructor/FeaturesTest.php
index c41d20a0..48b67e19 100644
--- a/tests/Feature/Instructor/FeaturesTest.php
+++ b/tests/Feature/Instructor/FeaturesTest.php
@@ -7,7 +7,7 @@
 it('accepts string as input', function () {
     $mockLLM = MockLLM::get(['{"name":"Jason","age":28}']);
 
-    $person = (new Instructor)->withDriver($mockLLM)->respond(
+    $person = (new Instructor)->withHttpClient($mockLLM)->respond(
         messages: "His name is Jason, he is 28 years old.",
         responseModel: Person::class,
     );
@@ -25,7 +25,7 @@
     ]);
 
     $text = "His name is JX, aka Jason, is -28 years old.";
-    $person = (new Instructor)->withDriver($mockLLM)->respond(
+    $person = (new Instructor)->withHttpClient($mockLLM)->respond(
         messages: [['role' => 'user', 'content' => $text]],
         responseModel: Person::class,
         maxRetries: 2,
diff --git a/tests/Feature/Instructor/InstructorTest.php b/tests/Feature/Instructor/InstructorTest.php
index 1d3ad102..acd4cd5c 100644
--- a/tests/Feature/Instructor/InstructorTest.php
+++ b/tests/Feature/Instructor/InstructorTest.php
@@ -14,7 +14,7 @@
 $text = "His name is Jason, he is 28 years old.";
 
 it('handles direct call', function () use ($mockLLM, $text) {
-    $instructor = (new Instructor)->withDriver($mockLLM);
+    $instructor = (new Instructor)->withHttpClient($mockLLM);
     $person = $instructor->respond(
         messages: [['role' => 'user', 'content' => $text]],
         responseModel: Person::class,
@@ -26,7 +26,7 @@
 
 it('handles onEvent()', function () use ($mockLLM, $text) {
     $events = new EventSink();
-    $instructor = (new Instructor)->withDriver($mockLLM);
+    $instructor = (new Instructor)->withHttpClient($mockLLM);
     $person = $instructor->onEvent(RequestReceived::class, fn($e) => $events->onEvent($e))->respond(
         messages: [['role' => 'user', 'content' => $text]],
         responseModel: Person::class,
@@ -39,7 +39,7 @@
 
 it('handles wiretap()', function () use ($mockLLM, $text) {
     $events = new EventSink();
-    $instructor = (new Instructor)->withDriver($mockLLM);
+    $instructor = (new Instructor)->withHttpClient($mockLLM);
     $person = $instructor->wiretap(fn($e) => $events->onEvent($e))->respond(
         messages: [['role' => 'user', 'content' => $text]],
         responseModel: Person::class,
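The `withDriver` → `withHttpClient` changes above all follow one pattern: the test double moves from the inference-driver seam down to the HTTP-client seam, so the real driver code (request building, response parsing, usage extraction) now runs in every test. A sketch of the pattern as used throughout the updated tests (`Person` stands in for whichever response model fixture a given test uses):

```php
<?php
use Cognesy\Instructor\Instructor;
use Tests\MockLLM;

// MockLLM::get() now returns a mocked HTTP client (CanHandleHttp)
// serving canned OpenAI-style payloads, rather than a mocked driver.
$mockHttp = MockLLM::get(['{"name":"Jason","age":28}']);

$person = (new Instructor)
    ->withHttpClient($mockHttp)
    ->respond(
        messages: "His name is Jason, he is 28 years old.",
        responseModel: Person::class, // any response model fixture
    );
```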
Cognesy\Instructor\Features\Http\IterableReader;
+use Cognesy\Instructor\Features\LLM\EventStreamReader;
 use Mockery as Mock;
 
 beforeEach(function () {
@@ -14,7 +14,7 @@
 });
 
 it('streams synthetic OpenAI streaming data correctly without parser', function () {
-    $reader = new IterableReader(events: $this->mockEventDispatcher);
+    $reader = new EventStreamReader(events: $this->mockEventDispatcher);
 
     $generator = function () {
         yield '{"id": "cmpl-xyz", "object": "text_completion", "choices": [{';
@@ -28,13 +28,13 @@
         '{"id": "cmpl-xyz", "object": "text_completion", "choices": [{"text": "world!", "index": 1}]}',
     ];
 
-    $result = iterator_to_array($reader->toStreamEvents($generator()));
+    $result = iterator_to_array($reader->eventsFrom($generator()));
     expect($result)->toEqual($expected);
 });
 
 it('processes synthetic OpenAI streaming data with a custom parser', function () {
     $parser = fn($line) => strtoupper($line);
-    $reader = new IterableReader(parser: $parser, events: $this->mockEventDispatcher);
+    $reader = new EventStreamReader(parser: $parser, events: $this->mockEventDispatcher);
 
     $generator = function () {
         yield '{"id": "cmpl-xyz", "object": "text_completion", "choices": [{';
@@ -48,7 +48,7 @@
         '{"ID": "CMPL-XYZ", "OBJECT": "TEXT_COMPLETION", "CHOICES": [{"TEXT": "WORLD!", "INDEX": 1}]}',
     ];
 
-    $result = iterator_to_array($reader->toStreamEvents($generator()));
+    $result = iterator_to_array($reader->eventsFrom($generator()));
     expect($result)->toEqual($expected);
 });
 
@@ -69,7 +69,7 @@
 //});
 
 it('skips empty lines correctly in synthetic OpenAI data', function () {
-    $reader = new IterableReader(events: $this->mockEventDispatcher);
+    $reader = new EventStreamReader(events: $this->mockEventDispatcher);
 
     $generator = function () {
         yield "\n";
@@ -86,12 +86,12 @@
         '{"id": "cmpl-xyz", "object": "text_completion", "choices": [{"text": "world!", "index": 1}]}',
     ];
 
-    $result = iterator_to_array($reader->toStreamEvents($generator()));
+    $result = iterator_to_array($reader->eventsFrom($generator()));
     expect($result)->toEqual($expected);
 });
 
 it('handles incomplete lines correctly in synthetic OpenAI data', function () {
-    $reader = new IterableReader(events: $this->mockEventDispatcher);
+    $reader = new EventStreamReader(events: $this->mockEventDispatcher);
 
     $generator = function () {
         yield '{"id": "cmpl-xyz", "object": "text_completion", "choices": [{';
@@ -106,6 +106,6 @@
         '{"id": "cmpl-xyz", "object": "text_completion", "choices": [{"text": "world!", "index": 1}]}',
     ];
 
-    $result = iterator_to_array($reader->toStreamEvents($generator()));
+    $result = iterator_to_array($reader->eventsFrom($generator()));
     expect($result)->toEqual($expected);
 });
\ No newline at end of file
diff --git a/tests/MockLLM.php b/tests/MockLLM.php
index a2cbd3ef..6386915f 100644
--- a/tests/MockLLM.php
+++ b/tests/MockLLM.php
@@ -2,42 +2,85 @@
 
 namespace Tests;
 
 use Cognesy\Instructor\Features\Http\Contracts\CanAccessResponse;
-use Cognesy\Instructor\Features\LLM\Contracts\CanHandleInference;
+use Cognesy\Instructor\Features\Http\Contracts\CanHandleHttp;
+use Cognesy\Instructor\Features\Http\Drivers\GuzzleDriver;
-use Cognesy\Instructor\Features\LLM\Data\LLMResponse;
-use Cognesy\Instructor\Features\LLM\Drivers\OpenAIDriver;
 use Mockery;
-use Psr\Http\Message\MessageInterface;
-use Psr\Http\Message\ResponseInterface;
-use Psr\Http\Message\StreamInterface;
 
 class MockLLM
 {
-    static public function get(array $args) : CanHandleInference {
-        $mockLLM = Mockery::mock(OpenAIDriver::class);
-        $mockResponse = Mockery::mock(CanAccessResponse::class, ResponseInterface::class, StreamInterface::class, MessageInterface::class);
+    static public function get(array $args) : CanHandleHttp {
+        $mockHttp = Mockery::mock(GuzzleDriver::class);
+        $mockResponse = Mockery::mock(CanAccessResponse::class);
+
         $list = [];
         foreach ($args as $arg) {
             $list[] = self::makeFunc($arg);
         }
 
-        //$mockLLM->shouldReceive('handle')->andReturnUsing(fn() => new OpenAIApiRequest());
-        $mockLLM->shouldReceive('getData')->andReturn('');
-        $mockLLM->shouldReceive('handle')->andReturn($mockResponse);
-        $mockLLM->shouldReceive('getEndpointUrl')->andReturn('');
-        $mockLLM->shouldReceive('getRequestHeaders')->andReturn([]);
-        $mockLLM->shouldReceive('getRequestBody')->andReturnUsing([]);
-        $mockLLM->shouldReceive('toLLMResponse')->andReturnUsing(...$list);
-        $mockLLM->shouldReceive('toPartialLLMResponse')->andReturn($mockLLM);
+        $mockHttp->shouldReceive('handle')->andReturn($mockResponse);
 
-        $mockResponse->shouldReceive('getContents')->andReturn($mockResponse);
+        $mockResponse->shouldReceive('getStatusCode')->andReturn(200);
+        $mockResponse->shouldReceive('getHeaders')->andReturn([]);
+        $mockResponse->shouldReceive('getContents')->andReturnUsing(...$list);
         $mockResponse->shouldReceive('streamContents')->andReturn($mockResponse);
 
-        return $mockLLM;
+        return $mockHttp;
     }
 
     static private function makeFunc(string $json) {
-        return fn() => new LLMResponse(
-            content: $json,
-        );
+        return fn() => json_encode(self::mockOpenAIResponse($json));
+    }
+
+    static private function mockOpenAIResponse(string $json) : array {
+        return [
+            "id" => "chatcmpl-AGH2w25Kx4hNnqUgcxqcgnqrzfIaD",
+            "object" => "chat.completion",
+            "created" => 1728442138,
+            "model" => "gpt-4o-mini-2024-07-18",
+            "choices" => [
+                0 => [
+                    "index" => 0,
+                    "message" => [
+                        "role" => "assistant",
+                        "content" => null,
+                        "tool_calls" => [
+                            0 => [
+                                "id" => "call_HGWji0nx7LQsRGGw1ckosq6S",
+                                "type" => "function",
+                                "function" => [
+                                    "name" => "extracted_data",
+                                    "arguments" => $json,
+                                ]
+                            ]
+                        ],
+                        "refusal" => null,
+                    ],
+                    "logprobs" => null,
+                    "finish_reason" => "stop",
+                ]
+            ],
+            "usage" => [
+                "prompt_tokens" => 95,
+                "completion_tokens" => 9,
+                "total_tokens" => 104,
+                "prompt_tokens_details" => [
+                    "cached_tokens" => 0,
+                ],
+                "completion_tokens_details" => [
+                    "reasoning_tokens" => 0,
+                ],
+            ],
+            "system_fingerprint" => "fp_f85bea6784",
+        ];
+    }
 }
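One detail of the new `MockLLM` worth calling out: `getContents` is wired with `andReturnUsing(...$list)`, so Mockery serves the canned payloads in order, one per HTTP call. Each entry becomes the tool-call `arguments` of a full mocked OpenAI `chat.completion` body, which is what lets retry scenarios (such as the `maxRetries: 2` test in `FeaturesTest`) feed a failing response first and a correcting one second:

```php
<?php
use Tests\MockLLM;

// Successive HTTP requests pop successive canned payloads.
$mockHttp = MockLLM::get([
    '{"name":"JX","age":-28}',   // 1st call: fails validation
    '{"name":"Jason","age":28}', // 2nd call: self-corrected result
]);
```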