Batched marginalisation mask #318

Merged
merged 9 commits on Nov 25, 2024
129 changes: 115 additions & 14 deletions cirkit/backend/torch/queries.py
@@ -1,5 +1,6 @@
import functools
from abc import ABC
from collections.abc import Iterable

import torch
from torch import Tensor
@@ -17,7 +18,16 @@ def __init__(self) -> None:


class IntegrateQuery(Query):
"""The integration query object."""
"""The integration query object allows marginalising out variables.

Computes output in two forward passes:
a) The normal circuit forward pass for input x
b) The integration forward pass where all variables are marginalised

A mask over random variables is computed based on the scopes passed as
input. This determines whether the integrated or normal circuit result
is returned for each variable.
"""

def __init__(self, circuit: TorchCircuit) -> None:
"""Initialize an integration query object.
@@ -36,33 +46,72 @@ def __init__(self, circuit: TorchCircuit) -> None:
super().__init__()
self._circuit = circuit

def __call__(self, x: Tensor, *, integrate_vars: Scope) -> Tensor:
def __call__(self, x: Tensor, *, integrate_vars: Tensor | Scope | Iterable[Scope]) -> Tensor:
"""Solve an integration query, given an input batch and the variables to integrate.

Args:
x: An input batch of shape (B, C, D), where B is the batch size, C is the number of
channels per variable, and D is the number of variables.
integrate_vars: The variables to integrate. It must be a subset of the variables
the circuit given in the constructor is defined on.

The format can be one of the following three:
1. A tensor of shape (B, D), where B is the batch size and D is the number of
variables in the scope of the circuit. Its dtype must be torch.bool, with True
in the positions of the random variables to marginalise out and False elsewhere.
2. A Scope, in which case the same integration mask is applied to every entry
of the batch.
3. A list of Scopes, whose length must be either 1 or B. If the list has
length 1, it is broadcast across the batch as in case 2.

Returns:
The result of the integration query, given as a tensor of shape (B, O, K),
where B is the batch size, O is the number of output vectors of the circuit, and
K is the number of units in each output vector.
"""
if not integrate_vars <= self._circuit.scope:
raise ValueError("The variables to marginalize must be a subset of the circuit scope")
integrate_vars_idx = torch.tensor(tuple(integrate_vars), device=self._circuit.device)
if isinstance(integrate_vars, Tensor):
# Check type of tensor is boolean
if integrate_vars.dtype != torch.bool:
raise ValueError(
"Expected dtype of tensor to be torch.bool, got %s" % integrate_vars.dtype
)
# If single dimensional tensor, assume batch size = 1
if len(integrate_vars.shape) == 1:
integrate_vars = torch.unsqueeze(integrate_vars, 0)
# If the scope is correct, proceed, otherwise error
num_vars = max(self._circuit.scope) + 1
if integrate_vars.shape[1] == num_vars:
integrate_vars_mask = integrate_vars
else:
raise ValueError(
"Circuit scope has %d variables but integrate_vars"
" was defined over %d != %d variables."
% (num_vars, integrate_vars.shape[1], num_vars)
)
else:
# Convert list of scopes to a boolean mask of dimension (B, N) where
# N is the number of variables in the circuit's scope.
integrate_vars_mask = IntegrateQuery.scopes_to_mask(self._circuit, integrate_vars)

# Check batch sizes of input x and mask are compatible
if integrate_vars_mask.shape[0] not in (1, x.shape[0]):
raise ValueError(
"The number of scopes to integrate over must"
" either match the batch size of x, or be 1 if you"
" want to broadcast."
" Found #inputs = %d != %d = len(integrate_vars)"
% (x.shape[0], integrate_vars_mask.shape[0])
)

output = self._circuit.evaluate(
x,
module_fn=functools.partial(
IntegrateQuery._layer_fn, integrate_vars_idx=integrate_vars_idx
IntegrateQuery._layer_fn, integrate_vars_mask=integrate_vars_mask
),
) # (O, B, K)
return output.transpose(0, 1) # (B, O, K)

@staticmethod
def _layer_fn(layer: TorchLayer, x: Tensor, *, integrate_vars_idx: Tensor) -> Tensor:
def _layer_fn(layer: TorchLayer, x: Tensor, *, integrate_vars_mask: Tensor) -> Tensor:
# Evaluate a layer: if it is not an input layer, then evaluate it in the usual
# feed-forward way. Otherwise, use the variables to integrate to solve the marginal
# queries on the input layers.
@@ -71,21 +120,73 @@ def _layer_fn(layer: TorchLayer, x: Tensor, *, integrate_vars_idx: Tensor) -> Tensor:
return output
if layer.num_variables > 1:
raise NotImplementedError("Integration of multivariate input layers is not supported")
# integration_mask: Boolean mask of shape (F, 1)
integration_mask = torch.isin(layer.scope_idx, integrate_vars_idx)
# integrate_vars_mask is a boolean tensor of dim (B, N)
# where N is the number of variables in the scope of the whole circuit.
#
# layer.scope_idx contains a subset of the variable indices of the scope,
# but it may be a reshaped tensor, so the shape and order of the variables may differ.
#
# As such, we use the indices in layer.scope_idx to look up the values in
# integrate_vars_mask; this returns the correct shape and values.
#
# If integrate_vars_mask were a vector, we could write integrate_vars_mask[layer.scope_idx];
# the vmap below applies this lookup across the B dimension.

# integration_mask has dimension (B, F, Ko)
integration_mask = torch.vmap(lambda x: x[layer.scope_idx])(integrate_vars_mask)
# permute to match integration_output: integration_mask has dimension (F, B, Ko)
integration_mask = integration_mask.permute([1, 0, 2])

if not torch.any(integration_mask).item():
return output
# output: output of the layer of shape (F, B, Ko)
# integration_mask: Boolean mask of shape (F, 1, 1)
# integration_output: result of the integration of the layer of shape (F, 1, Ko)
integration_mask = integration_mask.unsqueeze(dim=2)

integration_output = layer.integrate()
# Use the integration mask to select which output should be the result of
# an integration operation, and which should not be
# This is done in parallel for all folds, and regardless of whether the
# circuit is folded or unfolded
return torch.where(integration_mask, integration_output, output)
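
A minimal standalone sketch of the lookup-and-select step above, with hypothetical sizes (B=2 batch entries, N=5 circuit variables, F=3 folds, K=4 units); the tensors below are illustrative stand-ins:

import torch

B, N, F, K = 2, 5, 3, 4                                    # hypothetical sizes
integrate_vars_mask = torch.zeros(B, N, dtype=torch.bool)  # (B, N) mask over circuit variables
integrate_vars_mask[0, 4] = True                           # marginalise X_4 for the first entry only
scope_idx = torch.tensor([[0], [2], [4]])                  # (F, 1): one variable per fold

# Per-batch lookup: index the (N,) mask of each batch entry with scope_idx -> (B, F, 1),
# then permute to (F, B, 1) so it aligns with the layer output.
integration_mask = torch.vmap(lambda m: m[scope_idx])(integrate_vars_mask)
integration_mask = integration_mask.permute([1, 0, 2])

output = torch.randn(F, B, K)              # stand-in for the normal layer output
integration_output = torch.randn(F, 1, K)  # stand-in for the integrated layer output
# Select the integrated value where the mask is True, the normal value elsewhere.
result = torch.where(integration_mask, integration_output, output)  # (F, B, K)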

@staticmethod
def scopes_to_mask(circuit: TorchCircuit, batch_integrate_vars: Scope | Iterable[Scope]) -> Tensor:
"""Accept a batch of scopes and return a boolean mask of shape (B, N), with
True at the indices of the variables in each scope and False elsewhere.
"""
# If we passed a single scope, assume B = 1
if isinstance(batch_integrate_vars, Scope):
batch_integrate_vars = [batch_integrate_vars]

batch_size = len(tuple(batch_integrate_vars))
# The circuit scope may contain gaps, e.g. X_1 may already have been marginalised
# out so the scope has fewer variables, but the remaining variable indices are
# not shifted; hence the mask is sized by the maximum variable index.
num_rvs = max(circuit.scope) + 1
num_idxs = sum(len(s) for s in batch_integrate_vars)

# TODO: Maybe consider using a sparse tensor
mask = torch.zeros((batch_size, num_rvs), dtype=torch.bool, device=circuit.device)

# Handle the case where all scopes are empty, as the zip below would fail
if num_idxs == 0:
return mask

batch_idxs, rv_idxs = zip(
*((i, idx) for i, idxs in enumerate(batch_integrate_vars) for idx in idxs if idxs)
)

# Check that we have not asked to marginalise variables that are not defined
invalid_idxs = Scope(rv_idxs) - circuit.scope
if invalid_idxs:
raise ValueError(
"The variables to marginalize must be a subset of "
" the circuit scope. Invalid variables"
" not in scope: %s." % list(invalid_idxs)
)

mask[batch_idxs, rv_idxs] = True

return mask
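
A small worked illustration of the mask this helper produces, assuming a circuit over variables {0, ..., 4} and that Scope is importable from cirkit.utils.scope (an assumed path):

import torch
from cirkit.utils.scope import Scope  # assumed import path

batch_scopes = [Scope([0, 2]), Scope([4]), Scope([])]
# For a circuit over variables {0, 1, 2, 3, 4}, scopes_to_mask(circuit, batch_scopes)
# would return the equivalent of this (3, 5) boolean mask:
expected = torch.tensor(
    [
        [True, False, True, False, False],    # marginalise X_0 and X_2
        [False, False, False, False, True],   # marginalise X_4
        [False, False, False, False, False],  # empty scope: marginalise nothing
    ]
)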


class SamplingQuery(Query):
"""The sampling query object."""
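
For context, a rough usage sketch of the interface added above; the compiled circuit tc and the Scope import path are illustrative assumptions rather than code from this diff:

import torch
from cirkit.backend.torch.queries import IntegrateQuery
from cirkit.utils.scope import Scope  # assumed import path

# tc: a compiled TorchCircuit over variables {0, ..., 4}, obtained elsewhere.
query = IntegrateQuery(tc)
x = torch.randint(0, 2, (2, 1, 5))  # batch of 2 inputs, 1 channel, 5 variables

# 1. Boolean tensor of shape (B, D): a different marginalisation per batch entry.
mask = torch.tensor(
    [[True, True, True, True, True], [False, False, False, False, True]], dtype=torch.bool
)
out = query(x, integrate_vars=mask)  # (B, O, K)

# 2. A single Scope: broadcast to every entry of the batch.
out = query(x, integrate_vars=Scope([4]))

# 3. A list of Scopes, one per batch entry (or a singleton list to broadcast).
out = query(x, integrate_vars=[Scope([0, 1, 2, 3, 4]), Scope([4])])
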
197 changes: 197 additions & 0 deletions tests/backend/torch/test_queries/test_integration.py
@@ -18,6 +18,8 @@
)
def test_query_marginalize_monotonic_pc_categorical(semiring: str, fold: bool, optimize: bool):
compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
# The following function builds a circuit for which we have computed the
# partition function and a marginal by hand.
sc, gt_outputs, gt_partition_func = build_monotonic_structured_categorical_cpt_pc(
return_ground_truth=True
)
@@ -44,3 +46,198 @@ def test_query_marginalize_monotonic_pc_categorical(semiring: str, fold: bool, optimize: bool):
mar_scores2 = mar_query(mar_worlds, integrate_vars=Scope([4]))
assert mar_scores1.shape == mar_scores2.shape
assert allclose(mar_scores1, mar_scores2)


@pytest.mark.parametrize(
"semiring,fold,optimize,input_tensor",
itertools.product(["lse-sum", "sum-product"], [False, True], [False, True], [False, True]),
)
def test_batch_query_marginalize_monotonic_pc_categorical(
semiring: str, fold: bool, optimize: bool, input_tensor: bool
):
# Check that using a mask with batching works
compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
# The following function builds a circuit for which we have computed the
# partition function and a marginal by hand.
sc, gt_outputs, gt_partition_func = build_monotonic_structured_categorical_cpt_pc(
return_ground_truth=True
)

tc: TorchCircuit = compiler.compile(sc)

# The marginal has been computed for (1, 0, 1, 1, None) -- so marginalising var 4.
inputs = torch.tensor([[[1, 0, 1, 1, 1], [1, 0, 1, 1, 1]]], dtype=torch.int64).view(2, 1, 5)

mar_query = IntegrateQuery(tc)
if input_tensor:
mask = torch.tensor(
[[True, True, True, True, True], [False, False, False, False, True]], dtype=torch.bool
)
else:
# Create two masks: one marginalising out everything
# and another marginalising out only the last variable
mask = [Scope([0, 1, 2, 3, 4]), Scope([4])]
# The first score should be the partition function, as we marginalised out all vars.
# The second score should be our precomputed marginal.
mar_scores = mar_query(inputs, integrate_vars=mask)

if semiring == "sum-product":
assert torch.isclose(mar_scores[0], torch.tensor(gt_partition_func))
assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs["mar"][(1, 0, 1, 1, None)]))
elif semiring == "lse-sum":
mar_scores = torch.exp(mar_scores)
assert torch.isclose(mar_scores[0], torch.tensor(gt_partition_func))
assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs["mar"][(1, 0, 1, 1, None)]))
else:
raise ValueError('Unexpected semiring: "%s"' % semiring)


@pytest.mark.parametrize(
"semiring,fold,optimize,input_tensor",
itertools.product(["lse-sum", "sum-product"], [False, True], [False, True], [False, True]),
)
def test_batch_broadcast_query_marginalize_monotonic_pc_categorical(
semiring: str, fold: bool, optimize: bool, input_tensor: bool
):
# Check that passing a single mask results in broadcasting
compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
# The following function builds a circuit for which we have computed the
# partition function and a marginal by hand.
sc, gt_outputs, gt_partition_func = build_monotonic_structured_categorical_cpt_pc(
return_ground_truth=True
)

tc: TorchCircuit = compiler.compile(sc)

# The marginal has been computed for (1, 0, 1, 1, None) -- so marginalising var 4.
inputs = torch.tensor([[[1, 0, 1, 1, 0], [1, 0, 1, 1, 1]]], dtype=torch.int64).view(2, 1, 5)

mar_query = IntegrateQuery(tc)
if input_tensor:
mask = torch.tensor([False, False, False, False, True], dtype=torch.bool)
else:
# Create a single mask - this should be broadcast along the batch dim.
mask = Scope([4])
# Both scores should be our precomputed marginal, since the single mask
# is broadcast along the batch dimension.
mar_scores = mar_query(inputs, integrate_vars=mask)

if semiring == "sum-product":
assert torch.isclose(mar_scores[0], torch.tensor(gt_outputs["mar"][(1, 0, 1, 1, None)]))
assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs["mar"][(1, 0, 1, 1, None)]))
elif semiring == "lse-sum":
mar_scores = torch.exp(mar_scores)
assert torch.isclose(mar_scores[0], torch.tensor(gt_outputs["mar"][(1, 0, 1, 1, None)]))
assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs["mar"][(1, 0, 1, 1, None)]))
else:
raise ValueError('Unexpected semiring: "%s"' % semiring)


@pytest.mark.parametrize(
"input_tensor",
[False, True],
)
def test_batch_fails_on_out_of_scope(
input_tensor, semiring="sum-product", fold=True, optimize=True
):
# Check that marginalising variables outside the circuit scope raises an error
compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
# The following function builds a circuit for which we have computed the
# partition function and a marginal by hand.
sc, gt_outputs, gt_partition_func = build_monotonic_structured_categorical_cpt_pc(
return_ground_truth=True
)

tc: TorchCircuit = compiler.compile(sc)

# The marginal has been computed for (1, 0, 1, 1, None) -- so marginalising var 4.
inputs = torch.tensor([[[1, 0, 1, 1, 0], [1, 0, 1, 1, 1]]], dtype=torch.int64).view(2, 1, 5)

mar_query = IntegrateQuery(tc)
if input_tensor:
# The mask has 6 columns but the circuit scope has only 5 variables
mask = torch.tensor(
[[False, False, False, False, True, True], [False, False, False, False, True, True]],
dtype=torch.bool,
)
# The query should reject a mask that is wider than the circuit scope
with pytest.raises(ValueError, match="was defined over %d != 5 variables" % mask.shape[1]):
mar_scores = mar_query(inputs, integrate_vars=mask)
else:
# Variable 5 is not in the circuit scope
mask = [Scope([0]), Scope([5])]
# The query should reject out-of-scope variables
with pytest.raises(ValueError, match="not in scope:.*?5"):
mar_scores = mar_query(inputs, integrate_vars=mask)


@pytest.mark.parametrize(
"input_tensor",
[False, True],
)
def test_batch_fails_on_wrong_batch_size(
input_tensor, semiring="sum-product", fold=True, optimize=True
):
# Check that a mask batch size that does not match the input batch size raises an error
compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
# The following function builds a circuit for which we have computed the
# partition function and a marginal by hand.
sc, gt_outputs, gt_partition_func = build_monotonic_structured_categorical_cpt_pc(
return_ground_truth=True
)

tc: TorchCircuit = compiler.compile(sc)

# The marginal has been computed for (1, 0, 1, 1, None) -- so marginalising var 4.
inputs = torch.tensor([[[1, 0, 1, 1, 0], [1, 0, 1, 1, 1]]], dtype=torch.int64).view(2, 1, 5)

mar_query = IntegrateQuery(tc)
if input_tensor:
# Input batch size is 2, passing 3 masks
mask = torch.tensor(
[
[False, False, False, False, True],
[False, False, False, False, True],
[False, False, False, False, True],
],
dtype=torch.bool,
)
# The query should reject a mask whose batch size does not match the input batch size
with pytest.raises(ValueError, match="Found #inputs = 2 != 3"):
mar_scores = mar_query(inputs, integrate_vars=mask)
else:
# Input batch size is 2, passing 3 masks
mask = [Scope([0]), Scope([1]), Scope([2])]
# Same check for the list-of-scopes format
with pytest.raises(ValueError, match="Found #inputs = 2 != 3"):
mar_scores = mar_query(inputs, integrate_vars=mask)


def test_batch_fails_on_wrong_tensor_dtype(semiring="sum-product", fold=True, optimize=True):
# Check that a mask with a non-boolean dtype raises an error
compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
# The following function builds a circuit for which we have computed the
# partition function and a marginal by hand.
sc, gt_outputs, gt_partition_func = build_monotonic_structured_categorical_cpt_pc(
return_ground_truth=True
)

tc: TorchCircuit = compiler.compile(sc)

# The marginal has been computed for (1, 0, 1, 1, None) -- so marginalising var 4.
inputs = torch.tensor([[[1, 0, 1, 1, 0], [1, 0, 1, 1, 1]]], dtype=torch.int64).view(2, 1, 5)

mar_query = IntegrateQuery(tc)

# The mask has the right shape but the wrong dtype (torch.int32 instead of torch.bool)
mask = torch.tensor(
[[False, False, False, False, True], [False, False, False, False, True]], dtype=torch.int32
)
# The query should reject a mask whose dtype is not torch.bool
with pytest.raises(ValueError, match="Expected dtype of tensor to be torch.bool"):
mar_scores = mar_query(inputs, integrate_vars=mask)