Skip to content

Commit

Permalink
Cancel running inference from ui
Browse files Browse the repository at this point in the history
  • Loading branch information
edgett committed Dec 12, 2023
1 parent e738739 commit a828b2e
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 80 deletions.
6 changes: 5 additions & 1 deletion PalmHill.BlazorChat.ApiClient/WebApiInterface/IChat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ public interface IChat
[Post("/api/chat")]
Task<ApiResponse<string>> Chat(InferenceRequest conversation);

[Post("/api/attachment/ask")]
[Post("/api/chat/docs")]
Task<ApiResponse<ChatMessage>> Ask(InferenceRequest chatConversation);

[Delete("/api/chat/cancel/{conversationId}")]
public Task<ApiResponse<bool>> CancelChat(Guid conversationId);

}
}
68 changes: 28 additions & 40 deletions PalmHill.BlazorChat/Client/Components/Chat/ChatInput.razor
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,16 @@
<FluentFooter Class="w-100" Style="position:fixed; bottom:15px;">
<FluentStack Class="chat-messages input-area">
<div class="file-input-area" @onclick="ShowAttachmentPanel">
@*Not wrapping this in an if block. Use css to hide. FluentInputFile won't work if the element disappears from the DOM. *@

@{
var isProgressCssClass = ((progressPercent == 0 || progressPercent == 100) ? string.Empty : "hidden");
}

@if (progressPercent > 0 && progressPercent < 100)
{
<FluentProgressRing Style="height:32px; margin-top:6px" />
}


<FluentCounterBadge
Count="attachments.Count"
Count="Controller?.UploadedFiles.Count ?? 0"
ShowZero="true"
BottomPosition="-10"
Class="@isProgressCssClass"

>
<FluentButton
Id="uploadButton"

Class="@isProgressCssClass"
IconStart="(new Icons.Regular.Size24.Attach())"
Style="margin-top:6px;"
></FluentButton>
Expand All @@ -51,12 +40,25 @@
></textarea>
</div>
<div>
<FluentButton @ref="sendButton"
OnClick="Send"
Loading="Loading"
Appearance="Appearance.Accent"
IconStart="(new Icons.Regular.Size24.Send())"
Style="margin-top:6px;"></FluentButton>
@if (Controller?.CanSend == true)
{
<FluentButton OnClick="Send"
Disabled="string.IsNullOrWhiteSpace(Controller!.UserInput)"
Appearance="Appearance.Accent"
IconStart="(new Icons.Regular.Size24.Send())"
Style="margin-top:6px;"></FluentButton>

}

@if (Controller?.CanStop == true)
{
<FluentButton OnClick="CancelTextGeneration"
Appearance="Appearance.Accent"
IconStart="(new Icons.Regular.Size24.Stop())"
Style="margin-top:6px;"></FluentButton>

}

</div>
</FluentStack>
</FluentFooter>
Expand All @@ -68,16 +70,7 @@


private ElementReference textAreaElement;
private string messageInput = string.Empty;
private FluentButton? sendButton;
private int progressPercent = 0;
private string progressTitle = string.Empty;
private ConcurrentDictionary<string, AttachmentInfo> attachments = new ConcurrentDictionary<string, AttachmentInfo>();
public UISettings UISettings { get; set; } = new UISettings();
public List<AttachmentInfo> SelectedFiles = new List<AttachmentInfo>();
public List<AttachmentInfo> UploadedFiles = new List<AttachmentInfo>();
public bool Loading { get; set; } = false;
public bool IsAttachmentMode { get; set; } = false;




Expand All @@ -89,6 +82,11 @@

}

/// <summary>
/// Forwards a cancellation request for the in-flight text generation to the chat controller.
/// The controller owns the actual API call; this handler only delegates.
/// </summary>
public async Task CancelTextGeneration()
{
await Controller!.CancelTextGeneration();
}

private async Task HandleKeyPress(KeyboardEventArgs e)
{
if (e.Key == "Enter" && !e.ShiftKey)
Expand All @@ -109,15 +107,5 @@
Controller!.ShowAttachments();
}

public void SetReady()
{
Loading = false;
StateHasChanged();
}

public void SetAttachmentMode(bool enabled)
{
IsAttachmentMode = enabled;
StateHasChanged();
}
}
28 changes: 27 additions & 1 deletion PalmHill.BlazorChat/Client/Services/ChatService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ BlazorChatApi blazorChatApi
}

public string UserInput { get; set; } = string.Empty;
public bool CanSend { get; set; } = true;
public bool CanStop { get; set; } = false;
public bool AttachmentsEnabled { get; set; } = false;
public bool AttachmentsVisible { get; private set; } = false;
public List<AttachmentInfo> SelectedFiles = new List<AttachmentInfo>();
Expand Down Expand Up @@ -73,6 +75,10 @@ public async Task SendPrompt()

public async Task SendToWebSocketChat()
{
//Set the UI state.
CanSend = false;
CanStop = true;

var prompt = new WebSocketChatMessage();
prompt.Prompt = UserInput;
WebsocketChatMessages.Add(prompt);
Expand Down Expand Up @@ -165,13 +171,33 @@ public void HideAttachments()
}


/// <summary>
/// Returns the chat UI to its idle state: the Send button is enabled
/// and the Stop button is hidden, then a re-render is requested.
/// </summary>
public void SetReady()
{
// Flip both flags together so the Send/Stop buttons never show simultaneously.
(CanSend, CanStop) = (true, false);
StateHasChanged();
}

/// <summary>
/// Asks the server (via the REST cancel endpoint) to stop the in-flight
/// inference for the current conversation. On success the UI is reset to
/// the ready state; on failure the response status is logged.
/// </summary>
public async Task CancelTextGeneration()
{
    var canceled = await _blazorChatApi!.Chat.CancelChat(WebSocketChatConnection.ConversationId);

    if (canceled.Content)
    {
        SetReady();
        return;
    }

    // Bug fix: previously this "failed" line was logged unconditionally,
    // even when the cancellation succeeded. Only log on actual failure.
    Console.WriteLine($"CancelTextGeneration failed ({canceled.StatusCode}): {canceled.ReasonPhrase}");
}

private void setupWebSocketChatConnection()
{
WebSocketChatConnection.OnInferenceStatusUpdate += (sender, inferenceStatusUpdate) =>
{
StateHasChanged();
if (inferenceStatusUpdate.IsComplete == true)
{
SetReady();
}
};

WebSocketChatConnection.OnAttachmentStatusUpdate += (sender, attachmentInfo) =>
Expand Down
8 changes: 7 additions & 1 deletion PalmHill.BlazorChat/Client/Services/WebSocketChatService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public WebSocketChatService(Uri chatHubUri, List<WebSocketChatMessage> webSocket
public string SystemMessage = string.Empty;
public List<WebSocketChatMessage> WebSocketChatMessages { get; }
public HubConnection HubConnection { get; }
public bool CanSend { get; set; } = true;

public event EventHandler<WebSocketInferenceString>? OnReceiveInferenceString;
public event EventHandler<WebSocketInferenceStatusUpdate>? OnInferenceStatusUpdate;
Expand All @@ -29,13 +30,18 @@ public WebSocketChatService(Uri chatHubUri, List<WebSocketChatMessage> webSocket
/// <summary>
/// Opens the SignalR connection to the chat hub.
/// </summary>
public async Task StartAsync()
{
    await HubConnection.StartAsync();
}

/// <summary>
/// Closes the SignalR connection to the chat hub.
/// </summary>
public async Task StopAsync()
{
await HubConnection.StopAsync();
}

/// <summary>
/// Sends a "CancelInference" message over the hub for the current conversation.
/// NOTE(review): a hub-side "CancelInference" handler is not visible in this change —
/// the visible cancellation path uses the REST endpoint instead; confirm the hub
/// method exists, otherwise this call is a no-op on the server.
/// </summary>
public async Task CancelInferenceAsync()
{
await HubConnection.SendAsync("CancelInference", ConversationId);
}

public async Task SendInferenceRequestAsync()
{
var inferenceRequest = GetInferenceRequestFromWebsocketMessages();
Expand Down
9 changes: 9 additions & 0 deletions PalmHill.BlazorChat/Server/ChatCancelation.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using System.Collections.Concurrent;

namespace PalmHill.BlazorChat.Server
{
    /// <summary>
    /// Process-wide registry mapping a conversation id to the
    /// <see cref="CancellationTokenSource"/> controlling its in-flight inference.
    /// Entries are added when inference starts and removed when it completes
    /// or is canceled.
    /// NOTE(review): "Cancelation" (sic) — renaming to "Cancellation" would
    /// touch every call site, so the existing spelling is preserved.
    /// </summary>
    public static class ChatCancelation
    {
        // readonly: callers only index/TryRemove on the instance; it is never replaced.
        public static readonly ConcurrentDictionary<Guid, CancellationTokenSource> CancelationTokens =
            new ConcurrentDictionary<Guid, CancellationTokenSource>();
    }
}
50 changes: 37 additions & 13 deletions PalmHill.BlazorChat/Server/SignalR/WebSocketChat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using PalmHill.Llama.Models;
using PalmHill.LlmMemory;
using PalmHill.BlazorChat.Shared.Models.WebSocket;
using System.Collections.Concurrent;

namespace PalmHill.BlazorChat.Server.SignalR
{
Expand All @@ -25,6 +26,7 @@ public WebSocketChat(InjectedModel injectedModel, LlmMemory.ServerlessLlmMemory?
}
private InjectedModel InjectedModel { get; }
private ServerlessLlmMemory? LlmMemory { get; }


/// <summary>
/// Sends a chat prompt to the client and waits for a response. The method performs inference on the chat conversation and sends the result back to the client.
Expand All @@ -35,11 +37,31 @@ public WebSocketChat(InjectedModel injectedModel, LlmMemory.ServerlessLlmMemory?
/// <exception cref="Exception">Thrown when an error occurs during the inference process.</exception>
public async Task InferenceRequest(InferenceRequest chatConversation)
{
var conversationId = chatConversation.Id;
var cancellationTokenSource = new CancellationTokenSource();
ChatCancelation.CancelationTokens[conversationId] = cancellationTokenSource;

await ThreadLock.InferenceLock.WaitAsync();
try
{
await DoInferenceAndRespondToClient(Clients.Caller, chatConversation);
await ThreadLock.InferenceLock.WaitAsync(cancellationTokenSource.Token);
await DoInferenceAndRespondToClient(Clients.Caller, chatConversation, cancellationTokenSource.Token);

var inferenceStatusUpdate = new WebSocketInferenceStatusUpdate();
inferenceStatusUpdate.MessageId = chatConversation.ChatMessages.LastOrDefault()?.Id;
inferenceStatusUpdate.IsComplete = true;
inferenceStatusUpdate.Success = true;
await Clients.Caller.SendAsync("InferenceStatusUpdate", inferenceStatusUpdate);
}

catch (OperationCanceledException)
{
var inferenceStatusUpdate = new WebSocketInferenceStatusUpdate();
inferenceStatusUpdate.MessageId = chatConversation.ChatMessages.LastOrDefault()?.Id;
inferenceStatusUpdate.IsComplete = true;
inferenceStatusUpdate.Success = false;
await Clients.Caller.SendAsync("InferenceStatusUpdate", inferenceStatusUpdate);
// Handle the cancellation operation
Console.WriteLine($"Inference for {conversationId} was canceled.");
}
catch (Exception ex)
{
Expand All @@ -48,11 +70,8 @@ public async Task InferenceRequest(InferenceRequest chatConversation)
finally
{
ThreadLock.InferenceLock.Release();
var inferenceStatusUpdate = new WebSocketInferenceStatusUpdate();
inferenceStatusUpdate.MessageId = chatConversation.ChatMessages.LastOrDefault()?.Id;
inferenceStatusUpdate.IsComplete = true;
inferenceStatusUpdate.Success = true;
await Clients.Caller.SendAsync("InferenceStatusUpdate", inferenceStatusUpdate);
ChatCancelation.CancelationTokens.TryRemove(conversationId, out _);

}
}

Expand All @@ -64,16 +83,15 @@ public async Task InferenceRequest(InferenceRequest chatConversation)
/// <param name="messageId">The unique identifier for the message.</param>
/// <param name="chatConversation">The chat conversation to use for inference.</param>
/// <returns>A Task that represents the asynchronous operation.</returns>
private async Task DoInferenceAndRespondToClient(ISingleClientProxy respondToClient, InferenceRequest chatConversation)
private async Task DoInferenceAndRespondToClient(ISingleClientProxy respondToClient, InferenceRequest chatConversation, CancellationToken cancellationToken)
{

// Create a context for the model and a chat session for the conversation
LLamaContext modelContext = InjectedModel.Model.CreateContext(InjectedModel.ModelParams);
var session = modelContext.CreateChatSession(chatConversation);
var inferenceParams = chatConversation.GetInferenceParams(InjectedModel.DefaultAntiPrompts);

var messageId = chatConversation.ChatMessages.LastOrDefault()?.Id;
var cancelGeneration = new CancellationTokenSource();

var textBuffer = "";
var fullResponse = "";
var totalTokens = 0;
Expand All @@ -82,10 +100,17 @@ private async Task DoInferenceAndRespondToClient(ISingleClientProxy respondToCli
inferenceStopwatch.Start();
var asyncResponse = session.ChatAsync(session.History,
inferenceParams,
cancelGeneration.Token);
cancellationToken);
// Perform inference and send the response to the client
await foreach (var text in asyncResponse)
{
if (cancellationToken.IsCancellationRequested)
{
modelContext.Dispose();
inferenceStopwatch.Stop();

throw new OperationCanceledException();
}

totalTokens++;
fullResponse += text;
Expand All @@ -104,7 +129,7 @@ private async Task DoInferenceAndRespondToClient(ISingleClientProxy respondToCli


}
modelContext.Dispose();
modelContext.Dispose();

inferenceStopwatch.Stop();

Expand All @@ -113,7 +138,6 @@ private async Task DoInferenceAndRespondToClient(ISingleClientProxy respondToCli
await respondToClient.SendAsync("ReceiveInferenceString", chatConversation.Id, textBuffer);
}

await respondToClient.SendAsync("MessageComplete", chatConversation.Id, "success");
Console.WriteLine($"Inference took {inferenceStopwatch.ElapsedMilliseconds}ms and generated {totalTokens} tokens. {(totalTokens / (inferenceStopwatch.ElapsedMilliseconds / (float)1000)).ToString("F2")} tokens/second.");
Console.WriteLine(fullResponse);
}
Expand Down
Loading

0 comments on commit a828b2e

Please sign in to comment.