-
Notifications
You must be signed in to change notification settings - Fork 2
/
Utils.cs
435 lines (394 loc) · 17.7 KB
/
Utils.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
using FluentAssertions;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using TorchSharp.Modules;
using TorchSharp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using SD;
/// <summary>
/// Static helpers shared by the model code: rotary position embeddings, sinusoidal timestep embeddings,
/// noise-schedule rescaling, and factories that build UNet down/mid/up blocks from their type names.
/// </summary>
public static class Utils
{
    // Optional prefix on down/up block type names; stripped before matching against the block class names.
    private const string UNetResPrefix = "UNetRes";

    /// <summary>
    /// Applies rotary position embeddings by viewing consecutive pairs of the last dimension as complex
    /// numbers and multiplying them by precomputed complex rotation factors.
    /// </summary>
    /// <param name="input">Tensor laid out as (B, Seq_Len, H, Head_Dim); Head_Dim must be even so it can be split into pairs.</param>
    /// <param name="freqsComplex">
    /// Complex rotation factors of shape (Seq_Len, Head_Dim / 2), e.g. from <see cref="PrecomputeThetaPosFrequencies"/>.
    /// Moved to <paramref name="input"/>'s device before use.
    /// </param>
    /// <returns>The rotated tensor, with the same shape and dtype as <paramref name="input"/>.</returns>
    public static Tensor ApplyRotaryEmbeddings(Tensor input, Tensor freqsComplex)
    {
        // Work in float32 and pair up the last dimension: two consecutive values become one complex number.
        // (B, Seq_Len, H, Head_Dim) -> (B, Seq_Len, H, Head_Dim/2)
        var input_complex = input.to_type(ScalarType.Float32).reshape(input.shape[0], input.shape[1], input.shape[2], -1, 2).view_as_complex();
        freqsComplex = freqsComplex.to(input.device);

        // Add broadcast dimensions for batch and heads.
        // (Seq_Len, Head_Dim/2) -> (1, Seq_Len, 1, Head_Dim/2)
        var freqs_complex_reshaped = freqsComplex.unsqueeze(0).unsqueeze(2);

        // Complex multiplication rotates each pair (Figure 1 of the RoPE paper).
        // (B, Seq_Len, H, Head_Dim/2) * (1, Seq_Len, 1, Head_Dim/2) = (B, Seq_Len, H, Head_Dim/2)
        var rotated_complex = input_complex * freqs_complex_reshaped;

        // Back to real pairs, then flatten the pairs into the last dimension.
        // (B, Seq_Len, H, Head_Dim/2) -> (B, Seq_Len, H, Head_Dim/2, 2) -> (B, Seq_Len, H, Head_Dim)
        var rotated = rotated_complex.view_as_real();
        var rotated_reshaped = rotated.reshape(rotated.shape[0], rotated.shape[1], rotated.shape[2], -1);

        // Sanity check: the round trip through complex form must preserve the input shape.
        input.shape.Should().BeEquivalentTo(rotated_reshaped.shape);

        // Restore the caller's dtype (computation above was done in float32).
        return rotated_reshaped.type_as(input);
    }

    /// <summary>
    /// Creates the activation module named by <paramref name="act_fn"/>.
    /// </summary>
    /// <param name="act_fn">One of "silu", "swish" (alias of SiLU), "relu", "gelu" or "tanh"; matching is case-sensitive.</param>
    /// <returns>A new activation module instance.</returns>
    /// <exception cref="ArgumentException">The name is not recognized.</exception>
    public static Module<Tensor, Tensor> GetActivation(string act_fn) => act_fn switch
    {
        "silu" => nn.SiLU(),
        "relu" => nn.ReLU(),
        "gelu" => nn.GELU(),
        "tanh" => nn.Tanh(),
        "swish" => nn.SiLU(),
        _ => throw new ArgumentException("Invalid activation function", nameof(act_fn)),
    };

    /// <summary>
    /// Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
    /// sqrt(cumprod(1 - betas)) is shifted so its last value is 0 and rescaled so its first value is unchanged,
    /// then converted back into per-step betas (so the final beta becomes 1).
    /// </summary>
    /// <param name="betas">The 1-D beta schedule the scheduler is being initialized with.</param>
    /// <returns>The rescaled betas, with the same length as <paramref name="betas"/>.</returns>
    public static Tensor RescaleZeroTerminalSnr(Tensor betas)
    {
        var alphas = 1.0f - betas;
        var alphas_cumprod = torch.cumprod(alphas, 0);
        var alphas_bar_sqrt = alphas_cumprod.sqrt();

        // Store the original endpoints before modifying the curve.
        var alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone();
        var alphas_bar_sqrt_T = alphas_bar_sqrt[^1].clone();

        // Shift so the last timestep is zero.
        alphas_bar_sqrt = alphas_bar_sqrt - alphas_bar_sqrt_T;

        // Scale so the first timestep is back to its original value.
        alphas_bar_sqrt = alphas_bar_sqrt * alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T);

        // Convert alphas_bar_sqrt back to betas: alpha_t = alpha_bar_t / alpha_bar_{t-1}, with alpha_0 = alpha_bar_0.
        var alphas_bar = alphas_bar_sqrt.pow(2);
        alphas = alphas_bar[1..] / alphas_bar[..^1];
        alphas = torch.cat([alphas_bar[0..1], alphas]);
        betas = 1.0f - alphas;
        return betas;
    }

    /// <summary>
    /// Creates sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.
    /// </summary>
    /// <param name="timesteps">A 1-D Tensor of N indices, one per batch element. These may be fractional.</param>
    /// <param name="embedding_dim">The dimension of the output. When odd, a zero column is appended at the end.</param>
    /// <param name="flip_sin_to_cos">When true the output is ordered [cos, sin] instead of [sin, cos].</param>
    /// <param name="downscale_freq_shift">Subtracted from embedding_dim / 2 in the frequency exponent's denominator.</param>
    /// <param name="scale">Multiplier applied to timestep * frequency before taking sin/cos.</param>
    /// <param name="max_period">Controls the minimum frequency of the embeddings.</param>
    /// <returns>An [N x embedding_dim] float32 Tensor of positional embeddings.</returns>
    public static Tensor GetTimestepEmbedding(
        Tensor timesteps,
        int embedding_dim,
        bool flip_sin_to_cos = false,
        float downscale_freq_shift = 1,
        float scale = 1f,
        int max_period = 10000)
    {
        var half_dim = embedding_dim / 2;

        // freq_i = max_period^(-i / (half_dim - downscale_freq_shift)), i = 0..half_dim-1
        var exponent = -Math.Log(max_period) * torch.arange(0, half_dim, device: timesteps.device, dtype: ScalarType.Float32);
        exponent = exponent / (half_dim - downscale_freq_shift);
        var emb = torch.exp(exponent);

        // (N, 1) * (1, half_dim) -> (N, half_dim)
        emb = timesteps.unsqueeze(1).to(ScalarType.Float32) * emb.unsqueeze(0);
        emb = scale * emb;
        emb = torch.cat([emb.sin(), emb.cos()], dim: 1);

        if (flip_sin_to_cos)
        {
            // Swap the halves: [sin, cos] -> [cos, sin].
            emb = torch.cat([emb[.., half_dim..], emb[.., ..half_dim]], dim: -1);
        }

        if (embedding_dim % 2 == 1)
        {
            // 2 * half_dim is one short of an odd embedding_dim; pad one zero column on the right.
            emb = nn.functional.pad(emb, [0, 1, 0, 0]);
        }

        return emb;
    }

    /// <summary>
    /// Builds a UNet down block from its type name.
    /// </summary>
    /// <param name="down_block_type">"DownBlock2D" or "CrossAttnDownBlock2D"; a leading "UNetRes" prefix is ignored.</param>
    /// <param name="cross_attention_dim">Required (non-null) for "CrossAttnDownBlock2D".</param>
    /// <returns>The constructed down block.</returns>
    /// <remarks>
    /// <paramref name="transformer_layers_per_block"/>, <paramref name="resnet_skip_time_act"/>,
    /// <paramref name="resnet_out_scale_factor"/>, <paramref name="cross_attention_norm"/>,
    /// <paramref name="attention_head_dim"/> and <paramref name="downsample_type"/> are accepted but
    /// not currently forwarded to the constructed block.
    /// </remarks>
    /// <exception cref="ArgumentException">
    /// The block type is unknown, or <paramref name="cross_attention_dim"/> is null for a cross-attention block.
    /// </exception>
    public static Module<DownBlock2DInput, DownBlock2DOutput> GetDownBlock(
        string down_block_type,
        int num_layers,
        int in_channels,
        int out_channels,
        int temb_channels,
        bool add_downsample,
        float resnet_eps,
        string resnet_act_fn,
        int transformer_layers_per_block = 1,
        int? num_attention_heads = null,
        int? resnet_groups = null,
        int? cross_attention_dim = null,
        int? downsample_padding = null,
        bool dual_cross_attention = false,
        bool use_linear_projection = false,
        bool only_cross_attention = false,
        bool upcast_attention = false,
        string resnet_time_scale_shift = "default",
        string attention_type = "default",
        bool resnet_skip_time_act = false,
        float resnet_out_scale_factor = 1.0f,
        string? cross_attention_norm = null,
        int? attention_head_dim = null,
        string? downsample_type = null,
        float dropout = 0.0f,
        ScalarType dtype = ScalarType.Float32)
    {
        down_block_type = StripUNetResPrefix(down_block_type);

        if (down_block_type == nameof(DownBlock2D))
        {
            return new DownBlock2D(
                num_layers: num_layers,
                in_channels: in_channels,
                out_channels: out_channels,
                temb_channels: temb_channels,
                dropout: dropout,
                add_downsample: add_downsample,
                resnet_eps: resnet_eps,
                resnet_act_fn: resnet_act_fn,
                resnet_groups: resnet_groups,
                downsample_padding: downsample_padding,
                resnet_time_scale_shift: resnet_time_scale_shift,
                dtype: dtype);
        }

        if (down_block_type == nameof(CrossAttnDownBlock2D))
        {
            if (cross_attention_dim is null)
            {
                throw new ArgumentException("Cross attention dimension must be defined for CrossAttnDownBlock2D", nameof(cross_attention_dim));
            }

            // NOTE(review): transformer_layers_per_block is not forwarded here (GetUpBlock forwards it to
            // CrossAttnUpBlock2D) — confirm whether CrossAttnDownBlock2D's constructor accepts it.
            return new CrossAttnDownBlock2D(
                num_layers: num_layers,
                in_channels: in_channels,
                out_channels: out_channels,
                temb_channels: temb_channels,
                dropout: dropout,
                add_downsample: add_downsample,
                resnet_eps: resnet_eps,
                resnet_act_fn: resnet_act_fn,
                resnet_groups: resnet_groups,
                downsample_padding: downsample_padding,
                cross_attention_dim: cross_attention_dim,
                num_attention_heads: num_attention_heads,
                dual_cross_attention: dual_cross_attention,
                use_linear_projection: use_linear_projection,
                only_cross_attention: only_cross_attention,
                upcast_attention: upcast_attention,
                resnet_time_scale_shift: resnet_time_scale_shift,
                attention_type: attention_type,
                dtype: dtype);
        }

        throw new ArgumentException("Invalid down block type", nameof(down_block_type));
    }

    /// <summary>
    /// Builds a UNet mid block from its type name.
    /// </summary>
    /// <param name="mid_block_type">"UNetMidBlock2D" or "UNetMidBlock2DCrossAttn" (no prefix stripping is applied).</param>
    /// <param name="transformer_layers_per_block">Transformer depth of the single cross-attention layer.</param>
    /// <param name="num_attention_heads">Attention heads for the cross-attention variant; 1 when null.</param>
    /// <returns>The constructed mid block.</returns>
    /// <remarks>
    /// "UNetMidBlock2D" is built with zero extra layers and without attention.
    /// <paramref name="mid_block_only_cross_attention"/>, <paramref name="resnet_skip_time_act"/>,
    /// <paramref name="cross_attention_norm"/> and <paramref name="attention_head_dim"/> are accepted but
    /// not currently forwarded to the constructed block.
    /// </remarks>
    /// <exception cref="ArgumentException">The block type is unknown.</exception>
    public static Module<UNetMidBlock2DInput, Tensor> GetMidBlock(
        string mid_block_type,
        int temb_channels,
        int in_channels,
        float resnet_eps,
        string resnet_act_fn,
        int resnet_groups,
        float output_scale_factor = 1.0f,
        int transformer_layers_per_block = 1,
        int? num_attention_heads = null,
        int? cross_attention_dim = null,
        bool dual_cross_attention = false,
        bool use_linear_projection = false,
        bool mid_block_only_cross_attention = false,
        bool upcast_attention = false,
        string resnet_time_scale_shift = "default",
        string attention_type = "default",
        bool resnet_skip_time_act = false,
        string? cross_attention_norm = null,
        int? attention_head_dim = 1,
        float dropout = 0.0f,
        ScalarType dtype = ScalarType.Float32)
    {
        if (mid_block_type == nameof(UNetMidBlock2D))
        {
            return new UNetMidBlock2D(
                in_channels: in_channels,
                temb_channels: temb_channels,
                dropout: dropout,
                num_layers: 0,
                resnet_eps: resnet_eps,
                resnet_act_fn: resnet_act_fn,
                output_scale_factor: output_scale_factor,
                resnet_groups: resnet_groups,
                resnet_time_scale_shift: resnet_time_scale_shift,
                add_attention: false,
                dtype: dtype);
        }

        if (mid_block_type == nameof(UNetMidBlock2DCrossAttn))
        {
            // The mid block has a single layer, so the per-layer depth list has one entry.
            var transformer_layers_per_block_list = Enumerable.Repeat(transformer_layers_per_block, 1).ToArray();
            return new UNetMidBlock2DCrossAttn(
                transformer_layers_per_block: transformer_layers_per_block_list,
                in_channels: in_channels,
                temb_channels: temb_channels,
                dropout: dropout,
                resnet_eps: resnet_eps,
                resnet_act_fn: resnet_act_fn,
                output_scale_factor: output_scale_factor,
                resnet_time_scale_shift: resnet_time_scale_shift,
                cross_attention_dim: cross_attention_dim,
                num_attention_heads: num_attention_heads ?? 1,
                resnet_groups: resnet_groups,
                dual_cross_attention: dual_cross_attention,
                use_linear_projection: use_linear_projection,
                upcast_attention: upcast_attention,
                attention_type: attention_type,
                dtype: dtype);
        }

        throw new ArgumentException("Invalid mid block type", nameof(mid_block_type));
    }

    /// <summary>
    /// Builds a UNet up block from its type name.
    /// </summary>
    /// <param name="up_block_type">"UpBlock2D" or "CrossAttnUpBlock2D"; a leading "UNetRes" prefix is ignored.</param>
    /// <param name="transformer_layers_per_block">
    /// Transformer depth used for every one of the <paramref name="num_layers"/> layers of "CrossAttnUpBlock2D".
    /// </param>
    /// <returns>The constructed up block.</returns>
    /// <remarks>
    /// <paramref name="resnet_skip_time_act"/>, <paramref name="resnet_out_scale_factor"/>,
    /// <paramref name="cross_attention_norm"/>, <paramref name="attention_head_dim"/> and
    /// <paramref name="upsample_type"/> are accepted but not currently forwarded to the constructed block.
    /// </remarks>
    /// <exception cref="ArgumentException">The block type is unknown.</exception>
    public static Module<UpBlock2DInput, Tensor> GetUpBlock(
        string up_block_type,
        int num_layers,
        int in_channels,
        int out_channels,
        int prev_output_channel,
        int temb_channels,
        bool add_upsample,
        float resnet_eps,
        string resnet_act_fn,
        int? resolution_idx = null,
        int transformer_layers_per_block = 1,
        int num_attention_heads = 1,
        int resnet_groups = 32,
        int cross_attention_dim = 1280,
        bool dual_cross_attention = false,
        bool use_linear_projection = false,
        bool only_cross_attention = false,
        bool upcast_attention = false,
        string resnet_time_scale_shift = "default",
        string attention_type = "default",
        bool resnet_skip_time_act = false,
        float resnet_out_scale_factor = 1.0f,
        string? cross_attention_norm = null,
        int? attention_head_dim = null,
        string? upsample_type = null,
        float dropout = 0.0f,
        ScalarType dtype = ScalarType.Float32)
    {
        up_block_type = StripUNetResPrefix(up_block_type);

        if (up_block_type == nameof(UpBlock2D))
        {
            return new UpBlock2D(
                num_layers: num_layers,
                in_channels: in_channels,
                out_channels: out_channels,
                prev_output_channel: prev_output_channel,
                temb_channels: temb_channels,
                resolution_idx: resolution_idx,
                dropout: dropout,
                add_upsample: add_upsample,
                resnet_eps: resnet_eps,
                resnet_act_fn: resnet_act_fn,
                resnet_groups: resnet_groups,
                resnet_time_scale_shift: resnet_time_scale_shift,
                dtype: dtype);
        }

        if (up_block_type == nameof(CrossAttnUpBlock2D))
        {
            // One transformer-depth entry per layer. Previously this was hard-coded to 1, silently ignoring
            // the caller's transformer_layers_per_block; the default of 1 keeps existing callers unchanged.
            var transformer_layers_per_block_list = Enumerable.Repeat(transformer_layers_per_block, num_layers).ToArray();
            return new CrossAttnUpBlock2D(
                num_layers: num_layers,
                transformer_layers_per_block: transformer_layers_per_block_list,
                in_channels: in_channels,
                out_channels: out_channels,
                prev_output_channel: prev_output_channel,
                temb_channels: temb_channels,
                resolution_idx: resolution_idx,
                dropout: dropout,
                add_upsample: add_upsample,
                resnet_eps: resnet_eps,
                resnet_act_fn: resnet_act_fn,
                resnet_groups: resnet_groups,
                cross_attention_dim: cross_attention_dim,
                num_attention_heads: num_attention_heads,
                dual_cross_attention: dual_cross_attention,
                use_linear_projection: use_linear_projection,
                only_cross_attention: only_cross_attention,
                upcast_attention: upcast_attention,
                resnet_time_scale_shift: resnet_time_scale_shift,
                attention_type: attention_type,
                dtype: dtype);
        }

        throw new ArgumentException("Invalid up block type", nameof(up_block_type));
    }

    /// <summary>
    /// Precomputes the complex rotary factors exp(i * m * theta_j) for positions m = 0..seqLen-1 and
    /// frequencies theta_j = theta^(-2j / headDim), j = 0..headDim/2-1.
    /// </summary>
    /// <param name="headDim">Per-head dimension; must be even.</param>
    /// <param name="seqLen">Number of positions to precompute.</param>
    /// <param name="device">Device on which the result is created.</param>
    /// <param name="theta">Base of the frequency schedule.</param>
    /// <returns>A unit-magnitude complex tensor of shape (seqLen, headDim / 2).</returns>
    /// <exception cref="ArgumentException"><paramref name="headDim"/> is odd.</exception>
    public static Tensor PrecomputeThetaPosFrequencies(int headDim, int seqLen, string device, float theta = 10000.0f)
    {
        // As written in section 3.2.2 of the RoPE paper:
        // >> In order to generalize our results in 2D to any xi ∈ Rd where **d is even**, [...]
        if (headDim % 2 != 0)
        {
            throw new ArgumentException("Dimension must be divisible by 2", nameof(headDim));
        }

        // theta_i = theta^(-2(i-1)/dim) for i = [1, 2, ... dim/2]
        // Shape: (Head_Dim / 2)
        var thetaNumerator = torch.arange(0, headDim, 2).to(torch.float32).to(device);
        var thetaInput = torch.pow(theta, -1.0f * (thetaNumerator / headDim)).to(device);

        // Positions (the "m" parameter). Shape: (Seq_Len)
        var m = torch.arange(seqLen, device: device);

        // Outer product of positions and frequencies: (Seq_Len) x (Head_Dim / 2) -> (Seq_Len, Head_Dim / 2)
        var freqs = torch.outer(m, thetaInput).to(torch.float32).to(device);

        // Polar form with R = 1: c = exp(i * m * theta). (Seq_Len, Head_Dim / 2) -> (Seq_Len, Head_Dim / 2)
        var freqsComplex = torch.polar(torch.ones_like(freqs), freqs);
        return freqsComplex;
    }

    /// <summary>
    /// Splits the last dimension into halves x1 (first floor(n/2) entries) and x2 (the rest) and returns
    /// cat(-x2, x1) along that dimension. Works for tensors of any rank.
    /// </summary>
    /// <param name="x">Input tensor; the rotation is applied over its last dimension.</param>
    /// <returns>A tensor with the same shape as <paramref name="x"/>.</returns>
    public static Tensor RotateHalf(Tensor x)
    {
        var lastDim = x.shape[^1];
        var half = lastDim / 2;

        // narrow(-1, ...) slices the last dimension regardless of rank (the previous indexer assumed rank 4).
        var x1 = x.narrow(-1, 0, half);
        var x2 = x.narrow(-1, half, lastDim - half);
        return torch.cat([-x2, x1], dim: -1);
    }

    /// <summary>
    /// Applies rotary position embeddings to queries and keys using precomputed cos/sin tables:
    /// embed(t) = t * cos + RotateHalf(t) * sin.
    /// </summary>
    /// <param name="q">Query tensor.</param>
    /// <param name="k">Key tensor.</param>
    /// <param name="cos">Cosine table, indexed by <paramref name="positionIds"/>.</param>
    /// <param name="sin">Sine table, indexed by <paramref name="positionIds"/>.</param>
    /// <param name="positionIds">Positions used to gather rows from <paramref name="cos"/> and <paramref name="sin"/>.</param>
    /// <param name="unsqueezeDim">Dimension inserted into the gathered cos/sin so they broadcast against q and k.</param>
    /// <returns>The rotated (query, key) pair.</returns>
    public static (Tensor, Tensor) ApplyRotaryPosEmb(Tensor q, Tensor k, Tensor cos, Tensor sin, Tensor positionIds, int unsqueezeDim = 1)
    {
        // cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. If q and k have
        // the shape [batch_size, heads, seq_len, head_dim], unsqueeze_dim=1 makes them broadcastable; if q and k
        // have the shape [batch_size, seq_len, heads, head_dim], use unsqueeze_dim=2.
        cos = cos[positionIds].unsqueeze(unsqueezeDim);
        sin = sin[positionIds].unsqueeze(unsqueezeDim);

        var qEmbed = (q * cos) + (RotateHalf(q) * sin);
        var kEmbed = (k * cos) + (RotateHalf(k) * sin);
        return (qEmbed, kEmbed);
    }

    /// <summary>
    /// Repeats each key/value head <paramref name="nRep"/> times consecutively along dimension 2
    /// (equivalent to repeat_interleave with dim 2), for grouped-query attention.
    /// </summary>
    /// <param name="x">Tensor laid out as (B, Seq_Len, N_KV_Heads, Head_Dim).</param>
    /// <param name="nRep">Number of copies of each head.</param>
    /// <returns>
    /// <paramref name="x"/> itself when <paramref name="nRep"/> is 1; otherwise a tensor of shape
    /// (B, Seq_Len, N_KV_Heads * nRep, Head_Dim).
    /// </returns>
    public static Tensor RepeatKV(Tensor x, int nRep)
    {
        if (nRep == 1)
        {
            return x;
        }

        var batchSize = x.shape[0];
        var seqLen = x.shape[1];
        var nKVHeads = x.shape[2];
        var headDim = x.shape[3];

        // expand() gives the new axis stride 0, so merging it with the heads axis is not stride-compatible:
        // view() would throw for every nRep > 1. reshape() copies when needed.
        return x.unsqueeze(3)
            .expand(batchSize, seqLen, nKVHeads, nRep, headDim)
            .reshape(batchSize, seqLen, nKVHeads * nRep, headDim);
    }

    /// <summary>Removes a leading "UNetRes" from a block type name, if present (ordinal comparison).</summary>
    private static string StripUNetResPrefix(string blockType) =>
        blockType.StartsWith(UNetResPrefix, StringComparison.Ordinal) ? blockType[UNetResPrefix.Length..] : blockType;
}