RuntimeError: Expected query.size(0) == key.size(0) to be true, but got false #31

Open
adammenges opened this issue Feb 8, 2024 · 0 comments

I got the following error when trying to use the notebook as is, with no modifications. It happens in the 5th cell, the one that runs pipe(...):

RuntimeError: Expected query.size(0) == key.size(0) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

Any ideas?
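
For what it's worth, the failure comes from xformers.ops.memory_efficient_attention (see the bottom of the trace), which requires query and key to share the same batch dimension (batch_size × heads after head_to_batch_dim). A minimal sketch of that mismatch, with made-up shapes just for illustration, not taken from the notebook:

# Illustrative only: shapes below are hypothetical, not the notebook's actual tensors.
import torch
import xformers.ops

heads = 8
q = torch.randn(2 * heads, 4096, 40, device="cuda", dtype=torch.float16)  # latent batch of 2
k = torch.randn(4 * heads, 77, 40, device="cuda", dtype=torch.float16)    # text-embedding batch of 4
v = torch.randn(4 * heads, 77, 40, device="cuda", dtype=torch.float16)

# query.size(0) != key.size(0), so this raises the same kind of batch-mismatch error
out = xformers.ops.memory_efficient_attention(q, k, v)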

Full trace below:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[9], line 1
----> 1 images = pipe(prompt,negative_prompt,
      2               batch_size = 2, #batch size
      3               num_inference_steps=30, # sampling step
      4               height = 896, 
      5               width = 640, 
      6               end_steps = 1, # The number of steps to end the attention double version (specified in a ratio of 0-1. If it is 1, attention double version will be applied in all steps, with 0 being the normal generation)
      7               base_ratio=0.2, # Base ratio, the weight of base prompt, if 0, all are regional prompts, if 1, all are base prompts
      8               seed = 4396, # random seed
      9 )

Cell In[1], line 108, in RegionalGenerator.__call__(self, prompts, negative_prompt, batch_size, height, width, guidance_scale, num_inference_steps, seed, base_ratio, end_steps)
    106 #predict noise
    107 with torch.no_grad():
--> 108     noise_pred = self.unet(sample = latent_model_input,timestep = t,encoder_hidden_states=text_embs).sample
    110 #negative CFG
    111 noise_pred_text, noise_pred_negative= noise_pred.chunk(2)

File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py:905, in UNet2DConditionModel.forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, encoder_attention_mask, return_dict)
    903 for downsample_block in self.down_blocks:
    904     if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
--> 905         sample, res_samples = downsample_block(
    906             hidden_states=sample,
    907             temb=emb,
    908             encoder_hidden_states=encoder_hidden_states,
    909             attention_mask=attention_mask,
    910             cross_attention_kwargs=cross_attention_kwargs,
    911             encoder_attention_mask=encoder_attention_mask,
    912         )
    913     else:
    914         sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py:993, in CrossAttnDownBlock2D.forward(self, hidden_states, temb, encoder_hidden_states, attention_mask, cross_attention_kwargs, encoder_attention_mask)
    991     else:
    992         hidden_states = resnet(hidden_states, temb)
--> 993         hidden_states = attn(
    994             hidden_states,
    995             encoder_hidden_states=encoder_hidden_states,
    996             cross_attention_kwargs=cross_attention_kwargs,
    997             attention_mask=attention_mask,
    998             encoder_attention_mask=encoder_attention_mask,
    999             return_dict=False,
   1000         )[0]
   1002     output_states = output_states + (hidden_states,)
   1004 if self.downsamplers is not None:

File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/transformer_2d.py:291, in Transformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, class_labels, cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict)
    289 # 2. Blocks
    290 for block in self.transformer_blocks:
--> 291     hidden_states = block(
    292         hidden_states,
    293         attention_mask=attention_mask,
    294         encoder_hidden_states=encoder_hidden_states,
    295         encoder_attention_mask=encoder_attention_mask,
    296         timestep=timestep,
    297         cross_attention_kwargs=cross_attention_kwargs,
    298         class_labels=class_labels,
    299     )
    301 # 3. Output
    302 if self.is_input_continuous:

File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/attention.py:170, in BasicTransformerBlock.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels)
    165 if self.attn2 is not None:
    166     norm_hidden_states = (
    167         self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
    168     )
--> 170     attn_output = self.attn2(
    171         norm_hidden_states,
    172         encoder_hidden_states=encoder_hidden_states,
    173         attention_mask=encoder_attention_mask,
    174         **cross_attention_kwargs,
    175     )
    176     hidden_states = attn_output + hidden_states
    178 # 3. Feed-forward

File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/attention_processor.py:321, in Attention.forward(self, hidden_states, encoder_hidden_states, attention_mask, **cross_attention_kwargs)
    317 def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
    318     # The `Attention` class can call different attention processors / attention functions
    319     # here we simply pass along all tensors to the selected processor class
    320     # For standard processors that are defined here, `**cross_attention_kwargs` is empty
--> 321     return self.processor(
    322         self,
    323         hidden_states,
    324         encoder_hidden_states=encoder_hidden_states,
    325         attention_mask=attention_mask,
    326         **cross_attention_kwargs,
    327     )

File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/attention_processor.py:1046, in XFormersAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb)
   1043 key = attn.head_to_batch_dim(key).contiguous()
   1044 value = attn.head_to_batch_dim(value).contiguous()
-> 1046 hidden_states = xformers.ops.memory_efficient_attention(
   1047     query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
   1048 )
   1049 hidden_states = hidden_states.to(query.dtype)
   1050 hidden_states = attn.batch_to_head_dim(hidden_states)

File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:197, in memory_efficient_attention(query, key, value, attn_bias, p, scale, op)
    117 def memory_efficient_attention(
    118     query: torch.Tensor,
    119     key: torch.Tensor,
   (...)
    125     op: Optional[AttentionOp] = None,
    126 ) -> torch.Tensor:
    127     """Implements the memory-efficient attention mechanism following
     128     `"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.
    129 
   (...)
    195     :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
    196     """
--> 197     return _memory_efficient_attention(
    198         Inputs(
    199             query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
    200         ),
    201         op=op,
    202     )

File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:293, in _memory_efficient_attention(inp, op)
    288 def _memory_efficient_attention(
    289     inp: Inputs, op: Optional[AttentionOp] = None
    290 ) -> torch.Tensor:
    291     # fast-path that doesn't require computing the logsumexp for backward computation
    292     if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):
--> 293         return _memory_efficient_attention_forward(
    294             inp, op=op[0] if op is not None else None
    295         )
    297     output_shape = inp.normalize_bmhk()
    298     return _fMHA.apply(
    299         op, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale
    300     ).reshape(output_shape)

File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:313, in _memory_efficient_attention_forward(inp, op)
    310 else:
    311     _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
--> 313 out, *_ = op.apply(inp, needs_gradient=False)
    314 return out.reshape(output_shape)

File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/cutlass.py:106, in FwOp.apply(cls, inp, needs_gradient)
    104 causal = isinstance(inp.attn_bias, LowerTriangularMask)
    105 cu_seqlen_k, cu_seqlen_q, max_seqlen_q = _get_seqlen_info(inp)
--> 106 out, lse = cls.OPERATOR(
    107     query=inp.query,
    108     key=inp.key,
    109     value=inp.value,
    110     cu_seqlens_q=cu_seqlen_q,
    111     cu_seqlens_k=cu_seqlen_k,
    112     max_seqlen_q=max_seqlen_q,
    113     compute_logsumexp=needs_gradient,
    114     causal=causal,
    115     scale=inp.scale,
    116 )
    117 ctx: Optional[Context] = None
    118 if needs_gradient:

File ~/jupyter/.venv/lib/python3.10/site-packages/torch/_ops.py:442, in OpOverloadPacket.__call__(self, *args, **kwargs)
    437 def __call__(self, *args, **kwargs):
    438     # overloading __call__ to ensure torch.ops.foo.bar()
    439     # is still callable from JIT
    440     # We save the function ptr as the `op` attribute on
    441     # OpOverloadPacket to access it here.
--> 442     return self._op(*args, **kwargs or {})

RuntimeError: Expected query.size(0) == key.size(0) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
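
In case it helps with triage: the size check lives in the xformers cutlass kernel, so one way to see whether the mismatch comes from the regional text embeddings rather than from xformers itself might be to switch the UNet back to the stock diffusers attention processor before calling pipe(...). Rough sketch, assuming pipe.unet exposes the UNet2DConditionModel shown in the trace:

from diffusers.models.attention_processor import AttnProcessor

# Assumption: `pipe.unet` is the UNet2DConditionModel from the trace above.
# Replacing every attention processor disables the xformers path, so the plain
# PyTorch attention math runs instead (slower, but it should surface a clearer
# shape error if the text embeddings and latents really are out of sync).
pipe.unet.set_attn_processor(AttnProcessor())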