Commit
Inference canceling all around.
edgett committed Dec 12, 2023
1 parent a828b2e commit 81a6d54
Showing 4 changed files with 70 additions and 23 deletions.
8 changes: 4 additions & 4 deletions PalmHill.BlazorChat/Client/Services/ChatService.cs
@@ -90,6 +90,9 @@ public async Task SendToWebSocketChat()
public async Task AskDocumentApi()
{

CanSend = false;
CanStop = true;

var prompt = new WebSocketChatMessage();
prompt.Prompt = UserInput;
WebsocketChatMessages.Add(prompt);
@@ -115,12 +118,9 @@ public async Task AskDocumentApi()
}
else
{
prompt.AddResponseString("Error.");
prompt.CompleteResponse(false);
}

StateHasChanged();

SetReady();
}

public async Task SaveSettings()
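The client hunk also implies a Stop action (CanStop is enabled while the request runs) that is not shown in this diff. A minimal sketch of what a cancel call against the CancelChat endpoint in ApiChat.cs could look like; _httpClient, ConversationId, and the exact route prefix are assumptions, not code from this commit:

    // Sketch only, not part of this commit: _httpClient and ConversationId are assumed
    // members of ChatService, and the URL depends on the controller's route prefix.
    public async Task CancelInference()
    {
        CanStop = false;

        // Ask the server to cancel the running inference for this conversation; the
        // in-flight request then completes with the cancellation status and SetReady()
        // runs on its normal path.
        await _httpClient.DeleteAsync($"api/chat/cancel/{ConversationId}");
    }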
2 changes: 1 addition & 1 deletion PalmHill.BlazorChat/Server/SignalR/WebSocketChat.cs
@@ -109,7 +109,7 @@ private async Task DoInferenceAndRespondToClient(ISingleClientProxy respondToCli
modelContext.Dispose();
inferenceStopwatch.Stop();

throw new OperationCanceledException();
throw new OperationCanceledException(cancellationToken);
}

totalTokens++;
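The only functional change in this file is that the OperationCanceledException now carries the token that triggered it, so upstream handlers can correlate the cancellation. For reference, an equivalent way to express the same check (placeholder locals, not this method's literal body) lets the token throw for itself and keeps the cleanup in a finally:

    // Illustration only — asyncResponse, modelContext, etc. are placeholder locals.
    try
    {
        await foreach (var piece in asyncResponse)
        {
            // Throws an OperationCanceledException carrying cancellationToken when cancellation is requested.
            cancellationToken.ThrowIfCancellationRequested();

            totalTokens++;
            fullResponse += piece;
        }
    }
    finally
    {
        // Runs on normal completion and on cancellation alike.
        modelContext.Dispose();
        inferenceStopwatch.Stop();
    }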
54 changes: 40 additions & 14 deletions PalmHill.BlazorChat/Server/WebApi/ApiChat.cs
@@ -60,15 +60,15 @@ public async Task<ActionResult<string>> Chat([FromBody] InferenceRequest convers

try
{
await ThreadLock.InferenceLock.WaitAsync();
var response = await DoInference(conversation);
await ThreadLock.InferenceLock.WaitAsync(cancellationTokenSource.Token);
var response = await DoInference(conversation, cancellationTokenSource.Token);
return Ok(response);
}
catch (OperationCanceledException)
{
errorText = $"Inference for {conversationId} was canceled.";
Console.WriteLine(errorText);
return BadRequest(errorText);
return StatusCode(444, errorText);
}
catch (Exception ex)
{
@@ -95,22 +95,41 @@ public async Task<ActionResult<ChatMessage>> Ask(InferenceRequest chatConversati


var conversationId = chatConversation.Id;
var cancellationTokenSource = new CancellationTokenSource();
ChatCancelation.CancelationTokens[conversationId] = cancellationTokenSource;

var question = chatConversation.ChatMessages.LastOrDefault()?.Message;
if (question == null)
{
return BadRequest("No question provided.");
}

var answer = await LlmMemory.Ask(conversationId.ToString(), question);

var chatMessageAnswer = new ChatMessage()
try
{
Role = ChatMessageRole.Assistant,
Message = answer.Result,
AttachmentIds = answer.RelevantSources.Select(s => s.SourceName).ToList()
};
var answer = await LlmMemory.Ask(conversationId.ToString(), question, cancellationTokenSource.Token);

var chatMessageAnswer = new ChatMessage()
{
Role = ChatMessageRole.Assistant,
Message = answer.Result,
AttachmentIds = answer.RelevantSources.Select(s => s.SourceName).ToList()
};

return chatMessageAnswer;
if (cancellationTokenSource.Token.IsCancellationRequested)
{
throw new OperationCanceledException(cancellationTokenSource.Token);
}

return chatMessageAnswer;
}
catch (OperationCanceledException ex)
{
return StatusCode(444, ex.ToString());
}
catch (Exception ex)
{
return StatusCode(500, ex.ToString());
}
}

[HttpDelete("cancel/{conversationId}", Name = "CancelChat")]
@@ -132,23 +151,30 @@ public async Task<bool> CancelChat(Guid conversationId)
/// </summary>
/// <param name="conversation">The chat conversation for which to perform inference.</param>
/// <returns>Returns the inference result as a string.</returns>
private async Task<string> DoInference(InferenceRequest conversation)
private async Task<string> DoInference(InferenceRequest conversation, CancellationToken cancellationToken)
{
LLamaContext modelContext = InjectedModel.Model.CreateContext(InjectedModel.ModelParams);
var session = modelContext.CreateChatSession(conversation);
var inferenceParams = conversation.GetInferenceParams(InjectedModel.DefaultAntiPrompts);

var cancelGeneration = new CancellationTokenSource();
var fullResponse = "";
var totalTokens = 0;
var inferenceStopwatch = new Stopwatch();

inferenceStopwatch.Start();
var asyncResponse = session.ChatAsync(session.History,
inferenceParams,
cancelGeneration.Token);
cancellationToken);
await foreach (var text in asyncResponse)
{
if (cancellationToken.IsCancellationRequested)
{
modelContext.Dispose();
inferenceStopwatch.Stop();

throw new OperationCanceledException(cancellationToken);
}

totalTokens++;
fullResponse += text;
}
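Both the Chat and Ask actions now create a CancellationTokenSource per request and register it in ChatCancelation.CancelationTokens under the conversation id, which is what the CancelChat endpoint above can use to stop a running inference. The registry type itself is not part of this diff; a plausible minimal shape, offered only as an assumption:

    using System;
    using System.Collections.Concurrent;
    using System.Threading;

    // Sketch only: the real ChatCancelation class is not shown in this commit.
    public static class ChatCancelation
    {
        // One CancellationTokenSource per conversation id.
        public static ConcurrentDictionary<Guid, CancellationTokenSource> CancelationTokens { get; } = new();

        // What a cancel endpoint would likely do with an entry: signal the token and drop it.
        public static bool Cancel(Guid conversationId)
        {
            if (CancelationTokens.TryRemove(conversationId, out var cts))
            {
                cts.Cancel();
                cts.Dispose();
                return true;
            }

            return false;
        }
    }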
29 changes: 25 additions & 4 deletions PalmHill.LlmMemory/ServerlessLlmMemory.cs
@@ -1,4 +1,5 @@
using Microsoft.KernelMemory;
using Azure.Core;
using Microsoft.KernelMemory;
using PalmHill.BlazorChat.Shared.Models;
using PalmHill.Llama;
using System.Collections.Concurrent;
@@ -108,12 +109,32 @@ public async Task<SearchResult> SearchAsync(string conversationId, string query)
return results;
}

public async Task<MemoryAnswer> Ask(string conversationId, string query)
public async Task<MemoryAnswer> Ask(string conversationId, string query, CancellationToken cancellationToken)
{
var processedQuery = processQuery(query);
var results = await KernelMemory.AskAsync(processedQuery, conversationId);
Exception? exception;
try
{
await Llama.ThreadLock.InferenceLock.WaitAsync(cancellationToken);
var results = await KernelMemory.AskAsync(processedQuery, conversationId, cancellationToken: cancellationToken);
return results;
}
catch (OperationCanceledException ex)
{
exception = ex;
}
catch (Exception ex)
{
exception = ex;
}
finally
{
Llama.ThreadLock.InferenceLock.Release();
}

return results;
throw exception;


}

private string processQuery(string query)
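Ask now queues behind Llama.ThreadLock.InferenceLock and passes the caller's token through to KernelMemory.AskAsync. A minimal sketch of the same semaphore-gated, cancellable call pattern, assuming the lock is a SemaphoreSlim (the WaitAsync/Release calls suggest as much); the names below are placeholders, not code from the repository:

    using System.Threading;
    using System.Threading.Tasks;

    // Sketch only: InferenceLock stands in for Llama.ThreadLock.InferenceLock, and
    // RunInferenceAsync for the KernelMemory.AskAsync call; neither is from this commit.
    public class GuardedInference
    {
        private static readonly SemaphoreSlim InferenceLock = new(1, 1);

        public static async Task<string> AskGuardedAsync(string query, CancellationToken cancellationToken)
        {
            // WaitAsync throws OperationCanceledException if the token fires while queued;
            // in that case the semaphore was never acquired, so Release must not run.
            await InferenceLock.WaitAsync(cancellationToken);
            try
            {
                return await RunInferenceAsync(query, cancellationToken);
            }
            finally
            {
                InferenceLock.Release();
            }
        }

        private static Task<string> RunInferenceAsync(string query, CancellationToken cancellationToken)
        {
            // Placeholder for the real inference call.
            return Task.FromResult($"(answer to: {query})");
        }
    }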
