Skip to content

Commit

Permalink
Evals - cont
Browse files Browse the repository at this point in the history
  • Loading branch information
ddebowczyk committed Sep 25, 2024
1 parent 6c72697 commit 8b59038
Show file tree
Hide file tree
Showing 14 changed files with 164 additions and 67 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ $translation = (new Instructor)->respond(
);

assert($translation instanceof Email); // true
dd($translation);
dump($translation);
// Email {
// address: "joe@gmail",
// subject: "Actualización de estado",
Expand Down
2 changes: 1 addition & 1 deletion config/llm.php
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
'apiUrl' => 'https://openrouter.ai/api/v1/',
'apiKey' => Env::get('OPENROUTER_API_KEY', ''),
'endpoint' => '/chat/completions',
'defaultModel' => 'microsoft/phi-3.5-mini-128k-instruct',
'defaultModel' => 'qwen/qwen-2.5-72b-instruct', //'microsoft/phi-3.5-mini-128k-instruct',
'defaultMaxTokens' => 1024,
'connectTimeout' => 3,
'requestTimeout' => 30,
Expand Down
70 changes: 44 additions & 26 deletions evals/LLMModes/CompareModes.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,38 @@ private function schema() : array {
'type' => 'object',
'description' => 'User data',
'properties' => [
'answer' => [
'age' => [
'type' => 'integer',
'description' => 'Age',
],
'name' => [
'type' => 'string',
'description' => 'City',
'description' => 'Name',
],
],
'required' => ['answer'],
'required' => ['name', 'age'],
'additionalProperties' => false,
];
}

private function tools() : array {
return [[
'type' => 'function',
'description' => 'answer',
'function' => [
'name' => 'answer',
'parameters' => $this->schema,
'name' => 'store_user',
'description' => 'Save user data',
'parameters' => $this->schema(),
],
]];
}

private function responseFormatJsonSchema() : array {
return [
'type' => 'json_schema',
'description' => 'User data',
'json_schema' => [
'name' => 'answer',
'schema' => $this->schema,
'name' => 'store_user',
'schema' => $this->schema(),
'strict' => true,
],
];
Expand All @@ -61,15 +66,15 @@ private function responseFormatJsonSchema() : array {
private function responseFormatJson() : array {
return [
'type' => 'json_object',
'schema' => $this->schema,
'schema' => $this->schema(),
];
}

private function toolChoice() : array {
return [
'type' => 'function',
'function' => [
'name' => 'answer'
'name' => 'store_user'
]
];
}
Expand All @@ -80,14 +85,14 @@ public function executeAll(array $connections, array $modes, array $streamingMod
foreach ($streamingModes as $isStreamed) {
$this->before($mode, $connection, $isStreamed);
$evalResponse = $this->execute($connection, $mode, $isStreamed);
$this->after($evalResponse->answer, $evalResponse->isCorrect, $evalResponse->timeElapsed);
$this->after($evalResponse);
}
}
}

if (!empty($this->exceptions)) {
Console::println('');
Console::println(' EXCEPTIONS ', [Color::BG_RED, Color::YELLOW]);
Console::println(' EXCEPTIONS ', [Color::BG_MAGENTA, Color::WHITE, Color::BOLD]);
foreach($this->exceptions as $key => $exception) {
$exLine = str_replace("\n", '\n', $exception);
echo Console::columns([
Expand All @@ -108,9 +113,10 @@ private function execute(string $connection, Mode $mode, bool $isStreamed) : Eva
$time = microtime(true);
$answer = $this->callInferenceFor($this->query, $mode, $connection, $this->schema(), $isStreamed);
$timeElapsed = microtime(true) - $time;
$isCorrect = ($this->evalFn)(new EvalRequest(
$evalRequest = new EvalRequest(
$answer, $this->query, $this->schema(), $mode, $connection, $isStreamed
));
);
$isCorrect = ($this->evalFn)($evalRequest);
$evalResponse = new EvalResponse(
id: $key,
answer: $answer,
Expand All @@ -127,9 +133,6 @@ private function execute(string $connection, Mode $mode, bool $isStreamed) : Eva
timeElapsed: $timeElapsed,
exception: $e,
);
Console::print(' ');
Console::print(' !!!! ', [Color::RED, Color::BG_BLACK]);
Console::println($this->exc2txt($e, 80), [Color::RED, Color::BG_BLACK]);
}
$this->responses[] = $evalResponse;
return $evalResponse;
Expand All @@ -143,8 +146,7 @@ public function callInferenceFor(string $query, Mode $mode, string $connection,
Mode::MdJson => $this->forModeMdJson($query, $connection, $schema, $isStreamed),
Mode::Text => $this->forModeText($query, $connection, $isStreamed),
};
$answer = $this->getValue($inferenceResult, $isStreamed);
return $answer;
return $this->getValue($inferenceResult, $isStreamed);
}

private function getValue(InferenceResponse $response, bool $isStreamed) : string {
Expand All @@ -167,13 +169,29 @@ private function before(Mode $mode, string $connection, bool $isStreamed) : void
Console::print('', [Color::GRAY, Color::BG_BLACK]);
}

private function after(string $answer, bool $isCorrect, float $timeElapsed) : void {
$answerLine = str_replace("\n", '\n', $answer);
echo Console::columns([
[9, $this->timeFormat($timeElapsed), STR_PAD_LEFT, [Color::DARK_YELLOW]],
[5, $isCorrect ? ' OK ' : ' FAIL', STR_PAD_RIGHT, $isCorrect ? [Color::BG_GREEN, Color::WHITE] : [Color::BG_RED, Color::WHITE]],
[60, ' '.$answerLine, STR_PAD_RIGHT, [Color::WHITE, Color::BG_BLACK]],
], 120);
private function after(EvalResponse $evalResponse) : void {
$answer = $evalResponse->answer;
$isCorrect = $evalResponse->isCorrect;
$timeElapsed = $evalResponse->timeElapsed;
$exception = $evalResponse->exception;

if ($exception) {
//Console::print(' ');
//Console::print(' !!!! ', [Color::RED, Color::BG_BLACK]);
//Console::println(, [Color::RED, Color::BG_BLACK]);
echo Console::columns([
[9, '', STR_PAD_LEFT, [Color::DARK_YELLOW]],
[5, ' !!!!', STR_PAD_RIGHT, [Color::WHITE, COLOR::BOLD, Color::BG_MAGENTA]],
[60, ' ' . $this->exc2txt($exception, 80), STR_PAD_RIGHT, [Color::RED, Color::BG_BLACK]],
], 120);
} else {
$answerLine = str_replace("\n", '\n', $answer);
echo Console::columns([
[9, $this->timeFormat($timeElapsed), STR_PAD_LEFT, [Color::DARK_YELLOW]],
[5, $isCorrect ? ' OK ' : ' FAIL', STR_PAD_RIGHT, $isCorrect ? [Color::BG_GREEN, Color::WHITE] : [Color::BG_RED, Color::WHITE]],
[60, ' ' . $answerLine, STR_PAD_RIGHT, [Color::WHITE, Color::BG_BLACK]],
], 120);
}
echo "\n";
}

Expand Down
34 changes: 17 additions & 17 deletions evals/LLMModes/run.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,31 @@
use Cognesy\Instructor\Utils\Str;

$connections = [
// 'anthropic',
// 'azure',
// 'cohere',
// 'fireworks',
'anthropic',
'azure',
'cohere',
'fireworks',
'gemini',
// 'groq',
// 'mistral',
// 'ollama',
// 'openai',
// 'openrouter',
// 'together'
'groq',
'mistral',
'ollama',
'openai',
'openrouter',
'together'
];

$streamingModes = [false];
$streamingModes = [true];
$modes = [
Mode::Text,
Mode::MdJson,
Mode::Json,
Mode::JsonSchema,
Mode::Tools,
// Mode::MdJson,
// Mode::Json,
// Mode::JsonSchema,
// Mode::Tools,
];

(new CompareModes(
query: 'What is the capital of France?',
evalFn: fn(EvalRequest $er) => Str::contains($er->answer, 'Paris'),
query: 'Our user Jason is 28 yo. What is the name and age of the user?',
evalFn: fn(EvalRequest $er) => Str::contains($er->answer, ['28', 'Jason']),
//debug: true,
))->executeAll(
connections: $connections,
Expand Down
5 changes: 4 additions & 1 deletion examples/A05_Extras/LLM/run.php
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
use Cognesy\Instructor\Extras\LLM\Inference;
use Cognesy\Instructor\Utils\Str;

$answer = (new Inference)->create('What is capital of France?')->toText();
$answer = (new Inference)
->withConnection('openai') // optional, default is set in /config/llm.php
->create(messages: 'What is capital of France?')
->toText();

assert(Str::contains($answer, 'Paris'));
echo $answer;
Expand Down
13 changes: 13 additions & 0 deletions src/ApiClient/Data/ToolCall.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,23 @@

namespace Cognesy\Instructor\ApiClient\Data;

use Cognesy\Instructor\Utils\Json\Json;

class ToolCall
{
public function __construct(
public string $name,
public string $args,
) {}

public static function fromArray(array $toolCall) : ToolCall {
return new ToolCall(
name: $toolCall['name'],
args: match(true) {
is_array($toolCall['arguments'] ?? false) => Json::encode($toolCall['arguments']),
is_string($toolCall['arguments'] ?? false) => $toolCall['arguments'],
default => throw new \InvalidArgumentException('ToolCall args must be a string or an array')
}
);
}
}
20 changes: 20 additions & 0 deletions src/ApiClient/Data/ToolCalls.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,26 @@ class ToolCalls
*/
private array $toolCalls = [];

public function __construct(array $toolCalls = []) {
$this->toolCalls = $toolCalls;
}

public static function fromArray(array $toolCalls) : ToolCalls {
$list = [];
foreach ($toolCalls as $toolCall) {
$list[] = ToolCall::fromArray($toolCall);
}
return new ToolCalls($list);
}

public static function fromMapper(array $toolCalls, callable $mapper) : ToolCalls {
$list = [];
foreach ($toolCalls as $toolCall) {
$list[] = $mapper($toolCall);
}
return new ToolCalls($list);
}

public function count() : int {
return count($this->toolCalls);
}
Expand Down
7 changes: 6 additions & 1 deletion src/Extras/LLM/Drivers/AnthropicDriver.php
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
<?php
namespace Cognesy\Instructor\Extras\LLM\Drivers;

use Cognesy\Instructor\ApiClient\Data\ToolCall;
use Cognesy\Instructor\ApiClient\Data\ToolCalls;
use Cognesy\Instructor\ApiClient\Responses\ApiResponse;
use Cognesy\Instructor\ApiClient\Responses\PartialApiResponse;
use Cognesy\Instructor\Data\Messages\Messages;
Expand Down Expand Up @@ -28,7 +30,10 @@ public function toApiResponse(array $data): ApiResponse {
responseData: $data,
toolName: $data['content'][0]['name'] ?? '',
finishReason: $data['stop_reason'] ?? '',
toolCalls: null,
toolCalls: ToolCalls::fromMapper(array_map(
callback: fn(array $call) => $call,
array: $data['content'] ?? []
), fn($call) => ToolCall::fromArray(['name' => $call['name'] ?? '', 'arguments' => $call['input'] ?? ''])),
inputTokens: $data['usage']['input_tokens'] ?? 0,
outputTokens: $data['usage']['output_tokens'] ?? 0,
cacheCreationTokens: $data['usage']['cache_creation_input_tokens'] ?? 0,
Expand Down
23 changes: 16 additions & 7 deletions src/Extras/LLM/Drivers/CohereDriver.php
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
<?php
namespace Cognesy\Instructor\Extras\LLM\Drivers;

use Cognesy\Instructor\ApiClient\Data\ToolCall;
use Cognesy\Instructor\ApiClient\Data\ToolCalls;
use Cognesy\Instructor\ApiClient\Responses\ApiResponse;
use Cognesy\Instructor\ApiClient\Responses\PartialApiResponse;
use Cognesy\Instructor\Data\Messages\Messages;
use Cognesy\Instructor\Enums\Mode;
use Cognesy\Instructor\Extras\LLM\Data\LLMConfig;
use Cognesy\Instructor\Extras\LLM\Contracts\CanHandleInference;
use Cognesy\Instructor\Extras\LLM\InferenceRequest;
use Cognesy\Instructor\Utils\Json\Json;
use Cognesy\Instructor\Utils\Str;
use GuzzleHttp\Client;

class CohereDriver implements CanHandleInference
Expand All @@ -21,11 +25,17 @@ public function __construct(

public function toApiResponse(array $data): ApiResponse {
return new ApiResponse(
content: $data['text'] ?? '',
content: ($data['text'] ?? '') . (!empty($data['tool_calls'])
? ("\n" . Json::encode($data['tool_calls']))
: ''
),
responseData: $data,
toolName: '',
toolName: $data['tool_calls'][0]['name'] ?? '',
finishReason: $data['finish_reason'] ?? '',
toolCalls: null,
toolCalls: ToolCalls::fromMapper(
$data['tool_calls'] ?? [],
fn($call) => ToolCall::fromArray(['name' => $call['name'] ?? '', 'arguments' => $call['parameters'] ?? ''])
),
inputTokens: $data['meta']['tokens']['input_tokens'] ?? 0,
outputTokens: $data['meta']['tokens']['output_tokens'] ?? 0,
cacheCreationTokens: 0,
Expand Down Expand Up @@ -113,19 +123,18 @@ protected function toTools(array $tools): array {
$parameters = [];
foreach ($tool['function']['parameters']['properties'] as $name => $param) {
$parameters[$name] = array_filter([
//'name' => $name,
'description' => $param['description'] ?? '',
'type' => $this->toCohereType($param),
'required' => in_array(
needle: $name,
haystack: $tools['function']['parameters']['required'] ?? [],
haystack: $tool['function']['parameters']['required'] ?? [],
),
]);
}
$result[] = [
'name' => $tool['function']['name'],
'description' => $tool['function']['description'] ?? 'Extract data from context',
'parameters_definitions' => $parameters,
'description' => $tool['function']['description'] ?? '',
'parameterDefinitions' => $parameters,
];
}
return $result;
Expand Down
Loading

0 comments on commit 8b59038

Please sign in to comment.