Skip to content

Commit

Permalink
fix samples shape when num_channels > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Oct 22, 2024
1 parent fb4554c commit 24f87b9
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions cirkit/backend/torch/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,18 @@ def __call__(self, num_samples: int = 1) -> tuple[Tensor, list[Tensor]]:
raise ValueError("The number of samples must be a positive number")

mixture_samples: list[Tensor] = []
# samples: (O, C, K, num_samples, D)
samples = self._circuit.evaluate(
module_fn=functools.partial(
self._layer_fn,
num_samples=num_samples,
mixture_samples=mixture_samples,
),
)
samples = samples[0, 0] # (C, N, D)
samples = samples.permute(1, 0, 2) # (N, C, D)
# samples: (num_samples, O, K, C, D)
samples = samples.permute(3, 0, 2, 1, 4)
# TODO: fix for the case of multi-output circuits, i.e., O != 1 or K != 1
samples = samples[:, 0, 0] # (num_samples, C, D)
return samples, mixture_samples

def _layer_fn(
Expand Down Expand Up @@ -166,6 +169,7 @@ def _pad_samples(self, samples: Tensor, scope_idx: Tensor) -> Tensor:
if scope_idx.shape[1] != 1:
raise NotImplementedError("Padding is only implemented for univariate samples")

# padded_samples: (F, C, K, num_samples, D)
padded_samples = torch.zeros(
(*samples.shape, len(self._circuit.scope)), device=samples.device, dtype=samples.dtype
)
Expand Down

0 comments on commit 24f87b9

Please sign in to comment.