Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add logic to calculate how much space to allocate for completion requests #205

Merged
merged 6 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/AzureExtension/AzureExtension.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
<PackageReference Include="Microsoft.Data.Sqlite" Version="7.0.4" />
<PackageReference Include="Microsoft.Identity.Client" Version="4.56.0" />
<PackageReference Include="Microsoft.Identity.Client.Extensions.Msal" Version="4.56.0" />
<PackageReference Include="Microsoft.ML.Tokenizers" Version="0.22.0-preview.24271.1" />
adrastogi marked this conversation as resolved.
Show resolved Hide resolved
<PackageReference Include="Microsoft.Toolkit.Uwp.Notifications" Version="7.1.3" />
<PackageReference Include="Microsoft.Windows.CsWin32" Version="0.2.206-beta" />
<PackageReference Include="Microsoft.Windows.CsWinRT" Version="2.0.4" />
Expand All @@ -85,6 +86,7 @@
<PackageReference Include="Serilog.Sinks.Debug" Version="2.0.0" />
<PackageReference Include="Serilog.Sinks.File" Version="5.0.0" />
<PackageReference Include="System.Numerics.Tensors" Version="8.0.0" />
<PackageReference Include="System.Text.Json" Version="8.0.0" />
<PackageReference Include="YamlDotNet" Version="15.1.2" />
</ItemGroup>

Expand Down
42 changes: 38 additions & 4 deletions src/AzureExtension/QuickStartPlayground/AzureOpenAIService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,32 @@
using Azure;
using Azure.AI.OpenAI;
using AzureExtension.Contracts;
using Microsoft.ML.Tokenizers;
using Newtonsoft.Json;
using Serilog;

namespace AzureExtension.QuickStartPlayground;

public sealed class AzureOpenAIService : IAzureOpenAIService
{
private readonly ILogger _log = Serilog.Log.ForContext("SourceContext", nameof(AzureOpenAIService));

private const string AzureOpenAIEmbeddingFile = "ms-appx:///AzureExtension/Assets/QuickStartPlayground/docsEmbeddings-AzureOpenAI.json";
private const string OpenAIEmbeddingFile = "ms-appx:///AzureExtension/Assets/QuickStartPlayground/docsEmbeddings-OpenAI.json";

private readonly OpenAIEndpoint _endpoint;

// We use a Microsoft-published library that implements OpenAI's TikToken algorithm to figure out how much space to
// allocate for the output.
private readonly Tokenizer _gpt3Tokenizer = Tokenizer.CreateTiktokenForModel("gpt-35-turbo-instruct");

// This is the publicly documented gpt-35-turbo-instruct context window length (shared between the input and output)
private readonly int _contextWindowMaxLength = 4096;

// This is a fudge factor in case there is a discrepancy between the tokenizer's output and what the
// model will internally calculate.
private readonly int _contextWindowPadding = 100;

public AzureOpenAIService(IAICredentialService aiCredentialService, OpenAIEndpoint endpoint)
{
AICredentialService = aiCredentialService;
Expand Down Expand Up @@ -127,24 +142,43 @@ public async Task<string> GetAICompletionAsync(string systemInstructions, string
var openAIClient = OpenAIClient;
ArgumentNullException.ThrowIfNull(openAIClient);

var prompts = systemInstructions + "\n\n" + userMessage;
adrastogi marked this conversation as resolved.
Show resolved Hide resolved
adrastogi marked this conversation as resolved.
Show resolved Hide resolved
var maxTokens = ComputeMaxTokens(prompts);

var response = await openAIClient.GetCompletionsAsync(
new CompletionsOptions()
{
DeploymentName = CompletionDeploymentName,
Prompts =
{
systemInstructions + "\n\n" + userMessage,
prompts,
},
Temperature = 0.01F,
MaxTokens = 2000,
MaxTokens = maxTokens,
});

if (response.Value.Choices[0].FinishReason == CompletionsFinishReason.TokenLimitReached)
{
// TODO: Need to handle this
Console.WriteLine("Cut off due to length constraints");
throw new InvalidDataException("Token limit reached while generating response");
}

return response.Value.Choices[0].Text;
}

/// <summary>
/// Computes how many tokens can be requested for the completion so that the prompt and the
/// response together fit inside the model's context window.
/// </summary>
/// <param name="inputPrompt">The full prompt text that will be sent to the model.</param>
/// <returns>
/// The context window length minus the safety padding and the prompt's token count.
/// The returned value is always greater than zero.
/// </returns>
/// <exception cref="InvalidDataException">
/// Thrown when the prompt (plus padding) leaves no tokens available for the response.
/// </exception>
private int ComputeMaxTokens(string inputPrompt)
{
    var promptTokens = _gpt3Tokenizer.CountTokens(inputPrompt);
    var maxTokensForResponse = _contextWindowMaxLength - _contextWindowPadding - promptTokens;

    // A budget of exactly zero is just as unusable as a negative one: it leaves no room for any
    // response tokens, so fail here with a clear message instead of sending a request that
    // cannot produce output.
    if (maxTokensForResponse <= 0)
    {
        throw new InvalidDataException("Input prompt has taken up the entire context window.");
    }

    // NOTE(review): the full prompt is written to the log at Information level; confirm that
    // this is acceptable if prompts can contain user-provided content.
    _log.Information("Input Prompt:\n{inputPrompt}", inputPrompt);
    _log.Information("Tokens used: {promptTokens}", promptTokens);
    _log.Information("Max tokens for response: {maxTokensForResponse}", maxTokensForResponse);

    return maxTokensForResponse;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public static string FindExecutableInPath(string executableName)
}
}

_log.Information("${executableName} not found in PATH.");
_log.Information($"{executableName} not found in PATH.");
return string.Empty;
}

Expand Down
Loading