-
Notifications
You must be signed in to change notification settings - Fork 14
/
pipeline_flux_controlnet_regional.py
632 lines (530 loc) · 27 KB
/
pipeline_flux_controlnet_regional.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from diffusers.image_processor import PipelineImageInput
from diffusers.utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from diffusers.pipelines.flux.pipeline_flux_controlnet import *
# torch_xla is optional: it is imported only when `is_torch_xla_available()` reports it,
# and the outcome is recorded in XLA_AVAILABLE so XLA-specific calls can be guarded.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger obtained through the diffusers logging helper.
logger = logging.get_logger(__name__)

# Usage example text.
# NOTE(review): the example below demonstrates `FluxImg2ImgPipeline`, not the regional
# ControlNet pipeline defined in this file, and nothing in the visible part of this file
# references EXAMPLE_DOC_STRING - confirm whether it is still needed or should be rewritten.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import FluxImg2ImgPipeline
>>> from diffusers.utils import load_image
>>> device = "cuda"
>>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
>>> pipe = pipe.to(device)
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
>>> init_image = load_image(url).resize((1024, 1024))
>>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
>>> images = pipe(
... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0
... ).images[0]
```
"""
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly map an image sequence length to a shift value.

    The result lies on the straight line through the points
    ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``. Values of
    ``image_seq_len`` outside that interval are extrapolated, not clamped.

    Args:
        image_seq_len: Sequence length to evaluate the line at.
        base_seq_len: Sequence length that maps to ``base_shift``.
        max_seq_len: Sequence length that maps to ``max_shift``.
        base_shift: Shift value at ``base_seq_len``.
        max_shift: Shift value at ``max_seq_len``.

    Returns:
        The interpolated (or extrapolated) shift value.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure the scheduler via `set_timesteps` and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or (when neither
    is given) a plain `num_inference_steps` count. Extra keyword arguments are forwarded to
    `scheduler.set_timesteps` unchanged.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after configuration, and the number of inference
        steps (the length of the schedule when a custom schedule was supplied, otherwise
        `num_inference_steps` as passed in).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps`
            signature has no parameter for the custom schedule that was requested.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler derive its own spacing from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: the scheduler must expose a matching keyword parameter.
    accepted_params = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
class RegionalFluxAttnProcessor2_0:
    """Attention processor that blends a region-masked attention pass with an unmasked base pass.

    When ``base_ratio`` is given, attention is computed twice with the same ``attn`` module:
    once on the regional inputs (optionally restricted by
    ``additional_kwargs["regional_attention_mask"]``) and once on the base inputs without a mask.
    The image-token outputs of the two passes are merged as
    ``regional * (1 - base_ratio) + base * base_ratio``. When ``base_ratio`` is ``None`` a single
    unmasked pass is performed.
    """

    @staticmethod
    def _to_heads(tensor, batch_size, heads, head_dim):
        """Reshape ``(batch, seq, heads * head_dim)`` into ``(batch, heads, seq, head_dim)``."""
        return tensor.view(batch_size, -1, heads, head_dim).transpose(1, 2)

    def FluxAttnProcessor2_0_call(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        image_rotary_emb=None,
    ) -> Union[torch.Tensor, tuple]:
        """Run one attention pass using the projections and norms stored on ``attn``.

        Args:
            attn: Attention module providing ``to_q``/``to_k``/``to_v``, ``heads``, ``norm_q``/``norm_k``
                and, for the joint path, ``add_*_proj``, ``norm_added_*``, ``to_out`` and ``to_add_out``.
            hidden_states: 3-D input tensor; its first dimension is used as the batch size.
            encoder_hidden_states: Optional context tokens. When given, their projections are
                concatenated in front of the ``hidden_states`` projections along the sequence axis.
            attention_mask: Forwarded to ``scaled_dot_product_attention`` as ``attn_mask``.
            image_rotary_emb: Optional rotary embedding applied to query and key.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` after their output projections when
            ``encoder_hidden_states`` was given; otherwise the raw attention output tensor
            (no output projection is applied on that path).
        """
        batch_size, _, _ = hidden_states.shape
        heads = attn.heads

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        head_dim = key.shape[-1] // heads
        query = self._to_heads(query, batch_size, heads, head_dim)
        key = self._to_heads(key, batch_size, heads, head_dim)
        value = self._to_heads(value, batch_size, heads, head_dim)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # The single-stream blocks call this without `encoder_hidden_states`.
        if encoder_hidden_states is not None:
            # `context` projections.
            context_query = attn.add_q_proj(encoder_hidden_states)
            context_key = attn.add_k_proj(encoder_hidden_states)
            context_value = attn.add_v_proj(encoder_hidden_states)

            context_query = self._to_heads(context_query, batch_size, heads, head_dim)
            context_key = self._to_heads(context_key, batch_size, heads, head_dim)
            context_value = self._to_heads(context_value, batch_size, heads, head_dim)

            if attn.norm_added_q is not None:
                context_query = attn.norm_added_q(context_query)
            if attn.norm_added_k is not None:
                context_key = attn.norm_added_k(context_key)

            # Context tokens come first in the joint sequence.
            query = torch.cat([context_query, query], dim=2)
            key = torch.cat([context_key, key], dim=2)
            value = torch.cat([context_value, value], dim=2)

        if image_rotary_emb is not None:
            # Imported lazily, as in the original code, so the class can be used without
            # touching diffusers when no rotary embedding is supplied.
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is None:
            return hidden_states

        # Split the joint sequence back into context tokens and sample tokens.
        context_len = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :context_len]
        hidden_states = hidden_states[:, context_len:]

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
        return hidden_states, encoder_hidden_states

    def __call__(
        self,
        attn,
        hidden_states,
        hidden_states_base=None,
        encoder_hidden_states=None,
        encoder_hidden_states_base=None,
        attention_mask=None,
        image_rotary_emb=None,
        image_rotary_emb_base=None,
        additional_kwargs=None,
        base_ratio=None,
    ) -> tuple:
        """Compute regional attention and, when ``base_ratio`` is set, blend it with base attention.

        Args:
            attn: Attention module, see :meth:`FluxAttnProcessor2_0_call`.
            hidden_states: Input of the regional pass.
            hidden_states_base: Input of the base pass; ``hidden_states`` is used when ``None``.
            encoder_hidden_states: Context tokens of the regional pass (``None`` on the single-stream path).
            encoder_hidden_states_base: Context tokens of the base pass.
            attention_mask: Accepted for interface compatibility; not used by this method.
            image_rotary_emb: Rotary embedding for the regional pass.
            image_rotary_emb_base: Rotary embedding for the base pass.
            additional_kwargs: Optional dict. ``"regional_attention_mask"`` (if present and not ``None``)
                masks the regional pass. On the single-stream path with ``base_ratio`` set,
                ``"encoder_seq_len"`` and ``"encoder_seq_len_base"`` are required and give the number of
                leading context tokens in ``hidden_states`` / ``hidden_states_base``.
            base_ratio: Blend weight of the base pass; ``None`` disables the base pass and the mask.

        Returns:
            With ``encoder_hidden_states``: ``(hidden_states, encoder_hidden_states,
            encoder_hidden_states_base)``; the last two are the same object when ``base_ratio`` is ``None``.
            Without it: ``(hidden_states, hidden_states_base)``; both are the same object when
            ``base_ratio`` is ``None``.
        """
        # Fix: the default `additional_kwargs=None` used to raise TypeError as soon as
        # `base_ratio` was set, because membership was tested on None.
        additional_kwargs = {} if additional_kwargs is None else additional_kwargs
        blend = base_ratio is not None

        if blend:
            # Base pass: never masked.
            base_output = self.FluxAttnProcessor2_0_call(
                attn=attn,
                hidden_states=hidden_states_base if hidden_states_base is not None else hidden_states,
                encoder_hidden_states=encoder_hidden_states_base,
                attention_mask=None,
                image_rotary_emb=image_rotary_emb_base,
            )
            if encoder_hidden_states_base is not None:
                hidden_states_base, encoder_hidden_states_base = base_output
            else:
                hidden_states_base = base_output

        # The regional mask only applies while blending; a missing or None mask means "unmasked".
        regional_mask = additional_kwargs.get("regional_attention_mask") if blend else None
        if regional_mask is not None:
            regional_mask = regional_mask.to(hidden_states.device)

        regional_output = self.FluxAttnProcessor2_0_call(
            attn=attn,
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=regional_mask,
            image_rotary_emb=image_rotary_emb,
        )

        if encoder_hidden_states is not None:
            hidden_states, encoder_hidden_states = regional_output
            if not blend:
                # Both regional and base input are the base prompt, skip the merge.
                return hidden_states, encoder_hidden_states, encoder_hidden_states
            hidden_states = hidden_states * (1 - base_ratio) + hidden_states_base * base_ratio
            return hidden_states, encoder_hidden_states, encoder_hidden_states_base

        hidden_states = regional_output
        if not blend:
            # Both regional and base input are the base prompt, skip the merge.
            return hidden_states, hidden_states

        # Single-stream path: context and image tokens share one sequence, so split off the
        # context part of each pass before merging only the image tokens.
        regional_len = additional_kwargs["encoder_seq_len"]
        base_len = additional_kwargs["encoder_seq_len_base"]
        encoder_hidden_states = hidden_states[:, :regional_len]
        hidden_states = hidden_states[:, regional_len:]
        encoder_hidden_states_base = hidden_states_base[:, :base_len]
        hidden_states_base = hidden_states_base[:, base_len:]

        hidden_states = hidden_states * (1 - base_ratio) + hidden_states_base * base_ratio

        # Concatenate the context tokens back in front.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        hidden_states_base = torch.cat([encoder_hidden_states_base, hidden_states_base], dim=1)
        return hidden_states, hidden_states_base
class RegionalFluxControlNetPipeline(FluxControlNetPipeline):
    """Flux ControlNet pipeline that can condition image regions on separate prompts.

    For the first ``mask_inject_steps`` denoising steps the transformer is given the concatenated
    regional prompt embeddings together with a regional attention mask and a ``base_ratio``; for
    the remaining steps (or when no regional prompts are supplied) only the base prompt is used.

    NOTE(review): ``self.transformer`` is called with ``encoder_hidden_states_base`` and
    ``base_ratio`` keyword arguments, so it is assumed to be a regional-aware transformer rather
    than the stock Flux transformer - confirm against the code that builds the pipeline.
    """

    @staticmethod
    def _build_regional_attention_mask(regional_inputs, latent_height, latent_width):
        """Concatenate regional prompt embeddings and build the joint text/image attention mask.

        Args:
            regional_inputs: Non-empty list of ``(mask, prompt_embeds)`` pairs. ``mask`` is a 2-D
                tensor (it is resized to ``(latent_height, latent_width)`` with nearest-exact
                interpolation) or ``None``, which selects the whole image.
            latent_height: Height of the image-token grid.
            latent_width: Width of the image-token grid.

        Returns:
            ``(regional_embeds, regional_attention_mask)``: the embeddings concatenated along
            dim 1, and a square boolean mask of side ``encoder_seq_len + latent_height * latent_width``
            with text tokens first.
        """
        hidden_seq_len = latent_height * latent_width

        conds = []
        masks = []
        for mask, cond in regional_inputs:
            if mask is not None:
                # Resize to the token grid, flatten to one value per image token, then repeat
                # across this prompt's text tokens.
                mask = (
                    torch.nn.functional.interpolate(
                        mask[None, None, :, :], (latent_height, latent_width), mode="nearest-exact"
                    )
                    .flatten()
                    .unsqueeze(1)
                    .repeat(1, cond.size(1))
                )
            else:
                mask = torch.ones((hidden_seq_len, cond.size(1))).to(device=cond.device)
            masks.append(mask)
            conds.append(cond)

        regional_embeds = torch.cat(conds, dim=1)
        encoder_seq_len = regional_embeds.shape[1]
        mask_device = masks[0].device

        regional_attention_mask = torch.zeros(
            (encoder_seq_len + hidden_seq_len, encoder_seq_len + hidden_seq_len),
            device=mask_device,
            dtype=torch.bool,
        )
        # Every regional prompt is assumed to occupy an equally sized slice of the text sequence.
        each_prompt_seq_len = encoder_seq_len // len(masks)

        self_attend_masks = torch.zeros((hidden_seq_len, hidden_seq_len), device=mask_device, dtype=torch.bool)
        union_masks = torch.zeros((hidden_seq_len, hidden_seq_len), device=mask_device, dtype=torch.bool)

        for region_idx, region_mask in enumerate(masks):
            txt = slice(region_idx * each_prompt_seq_len, (region_idx + 1) * each_prompt_seq_len)
            # txt attends to itself
            regional_attention_mask[txt, txt] = True
            # txt attends to corresponding regional img
            regional_attention_mask[txt, encoder_seq_len:] = region_mask.transpose(-1, -2)
            # regional img attends to corresponding txt
            regional_attention_mask[encoder_seq_len:, txt] = region_mask

            # regional img attends to corresponding regional img
            img_size_masks = region_mask[:, :1].repeat(1, hidden_seq_len)
            img_size_masks_transpose = img_size_masks.transpose(-1, -2)
            self_attend_masks = torch.logical_or(
                self_attend_masks, torch.logical_and(img_size_masks, img_size_masks_transpose)
            )
            union_masks = torch.logical_or(
                union_masks, torch.logical_or(img_size_masks, img_size_masks_transpose)
            )

        # Image tokens outside every region (background) may attend to each other freely.
        background_masks = torch.logical_not(union_masks)
        regional_attention_mask[encoder_seq_len:, encoder_seq_len:] = torch.logical_or(
            background_masks, self_attend_masks
        )
        return regional_embeds, regional_attention_mask

    @torch.inference_mode()
    def __call__(
        self,
        initial_latent: Optional[torch.FloatTensor] = None,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        num_samples: int = 1,
        width: int = 1024,
        height: int = 1024,
        strength: float = 1.0,
        num_inference_steps: int = 25,
        timesteps: List[int] = None,
        mask_inject_steps: int = 5,
        guidance_scale: float = 5.0,
        control_image: PipelineImageInput = None,
        control_mode: Optional[Union[int, List[int]]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """Generate images with ControlNet guidance and optional regional prompts.

        Args:
            initial_latent: Passed as the last argument of ``self.prepare_latents``
                (presumably pre-generated latents - confirm against ``prepare_latents``).
            prompt: Base prompt, encoded with ``self.encode_prompt``.
            prompt_2: Second base prompt forwarded to ``self.encode_prompt``.
            num_samples: Batch size. When falsy, ``prompt_embeds.shape[0]`` is used instead.
            width: Output width; a falsy value selects ``default_sample_size * vae_scale_factor``.
            height: Output height; same default rule as ``width``.
            strength: Accepted for interface compatibility; not used by this method.
            num_inference_steps: Number of denoising steps; also defines the default sigma schedule.
            timesteps: Optional custom timesteps. When given, no sigma schedule is built and the
                timesteps are handed to ``retrieve_timesteps``.
            mask_inject_steps: Number of initial denoising steps that use the regional prompts.
            guidance_scale: Guidance value, used when ``transformer.config.guidance_embeds`` is set.
            control_image: ControlNet conditioning image, or an iterable of images for a
                ``FluxMultiControlNetModel``.
            control_mode: ControlNet mode; for the multi-ControlNet case a list in which ``None``
                entries are replaced by ``-1``.
            controlnet_conditioning_scale: Forwarded to the ControlNet as ``conditioning_scale``.
            num_images_per_prompt: Images per prompt; multiplies the batch size.
            generator: Random generator(s) forwarded to ``self.prepare_latents``.
            prompt_embeds: Optional pre-computed base prompt embeddings.
            pooled_prompt_embeds: Optional pre-computed pooled base prompt embeddings.
            joint_attention_kwargs: Optional dict. Keys read here: ``regional_prompts`` and
                ``regional_masks`` (zipped pairwise), ``base_ratio`` (required when regional
                prompts are used), ``single_inject_blocks_interval`` and
                ``double_inject_blocks_interval`` (default to the number of single / double
                transformer blocks).
            output_type: ``"latent"`` returns the latents as they are; any other value decodes
                them with the VAE and post-processes with ``self.image_processor``.
            return_dict: When ``False`` a one-element tuple is returned.

        Returns:
            ``FluxPipelineOutput`` holding the images, or ``(images,)`` when ``return_dict`` is ``False``.
        """
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        self._guidance_scale = guidance_scale
        device = self.transformer.device
        # Fix: the default `None` used to crash on the first membership test below.
        joint_attention_kwargs = {} if joint_attention_kwargs is None else joint_attention_kwargs

        # 3. Define call parameters
        batch_size = num_samples if num_samples else prompt_embeds.shape[0]

        # encode base prompt
        prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=512,
            lora_scale=None,
        )

        # encode regional prompts, define regional inputs
        regional_inputs = []
        if "regional_prompts" in joint_attention_kwargs and "regional_masks" in joint_attention_kwargs:
            for regional_prompt, regional_mask in zip(
                joint_attention_kwargs["regional_prompts"], joint_attention_kwargs["regional_masks"]
            ):
                # Only the sequence embeddings are needed; pooled embeddings and text ids of the
                # regional prompts are discarded.
                regional_prompt_embeds, _, _ = self.encode_prompt(
                    prompt=regional_prompt,
                    prompt_2=regional_prompt,
                    prompt_embeds=None,
                    pooled_prompt_embeds=None,
                    device=device,
                    num_images_per_prompt=num_images_per_prompt,
                    max_sequence_length=512,
                    lora_scale=None,
                )
                regional_inputs.append((regional_mask, regional_prompt_embeds))

        # prepare masks for regional control
        if regional_inputs:
            regional_embeds, regional_attention_mask = self._build_regional_attention_mask(
                regional_inputs,
                height // self.vae_scale_factor,
                width // self.vae_scale_factor,
            )
        else:
            # Fix: without regional inputs the mask construction used to fail on an empty
            # `torch.cat`. Fall back to the base prompt for every step, which is what the
            # steps after `mask_inject_steps` already do.
            logger.warning(
                "No regional prompts/masks were found in `joint_attention_kwargs`; "
                "running with the base prompt only."
            )
            regional_embeds = None
            regional_attention_mask = None

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            self.transformer.dtype,
            device,
            generator,
            initial_latent,
        )

        # prepare control image
        if isinstance(self.controlnet, FluxControlNetModel):
            control_image = self.prepare_image(
                image=control_image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=self.transformer.dtype,
            )
            # From here on height/width follow the prepared control image.
            height, width = control_image.shape[-2:]

            control_image = self.vae.encode(control_image).latent_dist.sample()
            control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor

            height_control_image, width_control_image = control_image.shape[2:]
            control_image = self._pack_latents(
                control_image,
                batch_size * num_images_per_prompt,
                num_channels_latents,
                height_control_image,
                width_control_image,
            )

            if control_mode is not None:
                control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
                control_mode = control_mode.reshape([-1, 1])

        elif isinstance(self.controlnet, FluxMultiControlNetModel):
            control_images = []
            for control_image_ in control_image:
                control_image_ = self.prepare_image(
                    image=control_image_,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=self.transformer.dtype,
                )
                height, width = control_image_.shape[-2:]

                control_image_ = self.vae.encode(control_image_).latent_dist.sample()
                control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor

                height_control_image, width_control_image = control_image_.shape[2:]
                control_image_ = self._pack_latents(
                    control_image_,
                    batch_size * num_images_per_prompt,
                    num_channels_latents,
                    height_control_image,
                    width_control_image,
                )
                control_images.append(control_image_)
            control_image = control_images

            # A `None` entry is encoded as -1; a non-list `control_mode` yields an empty tensor.
            control_mode_ = []
            if isinstance(control_mode, list):
                control_mode_ = [-1 if cmode is None else cmode for cmode in control_mode]
            control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
            control_mode = control_mode.reshape([-1, 1])

        # 4. Prepare timesteps
        # Fix: `retrieve_timesteps` rejects `timesteps` and `sigmas` together, so the default
        # sigma schedule is only built when no custom timesteps were requested.
        sigmas = None if timesteps is not None else np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # 5. Handle guidance (loop-invariant, so it is built once here)
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # Block-injection intervals do not change between steps.
        inject_intervals = {
            "single_inject_blocks_interval": joint_attention_kwargs.get(
                "single_inject_blocks_interval", len(self.transformer.single_transformer_blocks)
            ),
            "double_inject_blocks_interval": joint_attention_kwargs.get(
                "double_inject_blocks_interval", len(self.transformer.transformer_blocks)
            ),
        }

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Regional prompts are only injected during the first `mask_inject_steps` steps.
                if regional_embeds is not None and i < mask_inject_steps:
                    chosen_prompt_embeds = regional_embeds
                    base_ratio = joint_attention_kwargs["base_ratio"]
                else:
                    chosen_prompt_embeds = prompt_embeds
                    base_ratio = None

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)
                scaled_timestep = timestep / 1000

                # The ControlNet always sees the base prompt.
                controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
                    hidden_states=latents,
                    controlnet_cond=control_image,
                    controlnet_mode=control_mode,
                    conditioning_scale=controlnet_conditioning_scale,
                    timestep=scaled_timestep,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=None,
                    return_dict=False,
                )

                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=scaled_timestep,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=chosen_prompt_embeds,
                    encoder_hidden_states_base=prompt_embeds,
                    controlnet_block_samples=controlnet_block_samples,
                    controlnet_single_block_samples=controlnet_single_block_samples,
                    base_ratio=base_ratio,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    # A fresh dict per step, in case the transformer mutates it.
                    joint_attention_kwargs={
                        **inject_intervals,
                        "regional_attention_mask": regional_attention_mask if base_ratio is not None else None,
                    },
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                if latents.dtype != latents_dtype and torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)
        return FluxPipelineOutput(images=image)