Skip to content

Commit

Permalink
Merge pull request #993 from Lyrcaxis/distribution-seed-fix
Browse files Browse the repository at this point in the history
Non-deterministic default seed (+minor sampling parameters names & comments update)
  • Loading branch information
martindevans authored Nov 24, 2024
2 parents 9aff11c + 4772eb7 commit 5bce923
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 22 deletions.
8 changes: 4 additions & 4 deletions LLama.KernelMemory/LlamaSharpTextGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ private static InferenceParams OptionsToParams(TextGenerationOptions options, In
SamplingPipeline = new DefaultSamplingPipeline()
{
Temperature = (float)options.Temperature,
AlphaFrequency = (float)options.FrequencyPenalty,
AlphaPresence = (float)options.PresencePenalty,
FrequencyPenalty = (float)options.FrequencyPenalty,
PresencePenalty = (float)options.PresencePenalty,
TopP = (float)options.NucleusSampling,
}
};
Expand All @@ -107,8 +107,8 @@ private static InferenceParams OptionsToParams(TextGenerationOptions options, In
SamplingPipeline = new DefaultSamplingPipeline()
{
Temperature = (float)options.Temperature,
AlphaFrequency = (float)options.FrequencyPenalty,
AlphaPresence = (float)options.PresencePenalty,
FrequencyPenalty = (float)options.FrequencyPenalty,
PresencePenalty = (float)options.PresencePenalty,
TopP = (float)options.NucleusSampling,
}
};
Expand Down
4 changes: 2 additions & 2 deletions LLama.SemanticKernel/ExtensionMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ internal static LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LL
{
Temperature = (float)requestSettings.Temperature,
TopP = (float)requestSettings.TopP,
AlphaPresence = (float)requestSettings.PresencePenalty,
AlphaFrequency = (float)requestSettings.FrequencyPenalty,
PresencePenalty = (float)requestSettings.PresencePenalty,
FrequencyPenalty = (float)requestSettings.FrequencyPenalty,
}
};
}
Expand Down
6 changes: 3 additions & 3 deletions LLama/Extensions/LLamaExecutorExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ private string CreatePrompt(IList<ChatMessage> messages)
MaxTokens = options?.MaxOutputTokens ?? 256, // arbitrary upper limit
SamplingPipeline = new DefaultSamplingPipeline()
{
AlphaFrequency = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaFrequency), out float af) is true ? af : s_defaultPipeline.AlphaFrequency,
AlphaPresence = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaPresence), out float ap) is true ? ap : s_defaultPipeline.AlphaPresence,
PenalizeEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeEOS), out bool eos) is true ? eos : s_defaultPipeline.PenalizeEOS,
FrequencyPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.FrequencyPenalty), out float af) is true ? af : s_defaultPipeline.FrequencyPenalty,
PresencePenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PresencePenalty), out float ap) is true ? ap : s_defaultPipeline.PresencePenalty,
PreventEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PreventEOS), out bool eos) is true ? eos : s_defaultPipeline.PreventEOS,
PenalizeNewline = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeNewline), out bool pnl) is true ? pnl : s_defaultPipeline.PenalizeNewline,
RepeatPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenalty), out float rp) is true ? rp : s_defaultPipeline.RepeatPenalty,
RepeatPenaltyCount = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenaltyCount), out int rpc) is true ? rpc : s_defaultPipeline.RepeatPenaltyCount,
Expand Down
79 changes: 66 additions & 13 deletions LLama/Sampling/DefaultSamplingPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,38 +25,76 @@ public sealed class DefaultSamplingPipeline
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
/// so far, decreasing the model's likelihood to repeat the same line verbatim.
/// </summary>
[Obsolete($"Use {nameof(FrequencyPenalty)} instead.")]
public float AlphaFrequency
{
get => _alphaFreq;
get => _frequencyPenalty;
init
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(AlphaFrequency)} must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
_alphaFreq = value;
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(AlphaFrequency)} must be less than 2");
_frequencyPenalty = value;
}
}
private readonly float _alphaFreq;

/// <summary>
/// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
/// text so far, increasing the model's likelihood to talk about new topics.
/// </summary>
[Obsolete($"Use {nameof(PresencePenalty)} instead.")]
public float AlphaPresence
{
get => _alphaPresence;
get => _presencePenalty;
init
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(AlphaPresence)} must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
_alphaPresence = value;
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(AlphaPresence)} must be less than 2");
_presencePenalty = value;
}
}
private readonly float _alphaPresence;

/// <summary>
/// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
/// so far, decreasing the model's likelihood to repeat the same line verbatim.
/// </summary>
public float FrequencyPenalty
{
get => _frequencyPenalty;
init
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(FrequencyPenalty)} must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(FrequencyPenalty)} must be less than 2");
_frequencyPenalty = value;
}
}
private readonly float _frequencyPenalty;

/// <summary>
/// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
/// text so far, increasing the model's likelihood to talk about new topics.
/// </summary>
public float PresencePenalty
{
get => _presencePenalty;
init
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(PresencePenalty)} must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(PresencePenalty)} must be less than 2");
_presencePenalty = value;
}
}
private readonly float _presencePenalty;

/// <summary>
/// How many tokens should be considered for penalizing repetition
Expand All @@ -71,8 +109,14 @@ public float AlphaPresence
/// <summary>
/// Whether the EOS token should be protected from being modified by penalty
/// </summary>
[Obsolete($"This doesn't do what the name implies. If you're sure you want to use it, use {nameof(PreventEOS)}.")]
public bool PenalizeEOS { get; init; } = false;

/// <summary>
/// Whether the EOS token should be suppressed. Setting this to 'true' prevents EOS from being sampled
/// </summary>
public bool PreventEOS { get; init; } = false;

/// <summary>
/// Temperature to apply (higher temperature is more "creative")
/// </summary>
Expand Down Expand Up @@ -111,7 +155,16 @@ public float AlphaPresence
/// <summary>
/// Seed to use for random sampling
/// </summary>
public uint Seed { get; set; } = 42;
public uint Seed { get; set; } = GetRandomSeed();


private static Random RandomSeedGenerator = new();
private static uint GetRandomSeed()
{
lock (RandomSeedGenerator)
return (uint) RandomSeedGenerator.Next(0, int.MaxValue) + (uint) RandomSeedGenerator.Next(0, int.MaxValue);
}


/// <inheritdoc />
protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context)
Expand Down Expand Up @@ -147,8 +200,8 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl
context.VocabCount,
context.ModelHandle.Tokens.EOS, context.ModelHandle.Tokens.Newline ?? 0,
RepeatPenaltyCount, RepeatPenalty,
AlphaFrequency, AlphaPresence,
PenalizeNewline, PenalizeEOS
FrequencyPenalty, PresencePenalty,
PenalizeNewline, PreventEOS
);

chain.AddTopK(TopK);
Expand Down

0 comments on commit 5bce923

Please sign in to comment.