From 870e56c8dd090f967e027e868d805256ebf03ac5 Mon Sep 17 00:00:00 2001
From: Andreas Grivas
Date: Mon, 11 Nov 2024 12:30:16 +0000
Subject: [PATCH 1/7] Allow marginalising across different scopes within a batch

---
 cirkit/backend/torch/queries.py | 88 ++++++++++++++++++++++++++++-----
 1 file changed, 75 insertions(+), 13 deletions(-)

diff --git a/cirkit/backend/torch/queries.py b/cirkit/backend/torch/queries.py
index 19420c20..70940fa1 100644
--- a/cirkit/backend/torch/queries.py
+++ b/cirkit/backend/torch/queries.py
@@ -17,7 +17,16 @@ def __init__(self) -> None:
 
 
 class IntegrateQuery(Query):
-    """The integration query object."""
+    """The integration query object allows marginalising out variables.
+
+    It computes the output in two forward passes:
+      a) the normal circuit forward pass for input x;
+      b) the integration forward pass, where all variables are marginalised.
+
+    A mask over random variables is computed based on the scopes passed as
+    input. This mask determines, for each variable, whether the integrated
+    or the normal circuit result is returned.
+    """
 
     def __init__(self, circuit: TorchCircuit) -> None:
         """Initialize an integration query object.
@@ -36,7 +45,7 @@ def __init__(self, circuit: TorchCircuit) -> None:
         super().__init__()
         self._circuit = circuit
 
-    def __call__(self, x: Tensor, *, integrate_vars: Scope) -> Tensor:
+    def __call__(self, x: Tensor, *, integrate_vars: [Scope]) -> Tensor:
         """Solve an integration query, given an input batch and the variables to integrate.
 
         Args:
@@ -50,19 +59,20 @@ def __call__(self, x: Tensor, *, integrate_vars: Scope) -> Tensor:
             where B is the batch size, O is the number of output vectors of the circuit, and
             K is the number of units in each output vector.
         """
-        if not integrate_vars <= self._circuit.scope:
-            raise ValueError("The variables to marginalize must be a subset of the circuit scope")
-        integrate_vars_idx = torch.tensor(tuple(integrate_vars), device=self._circuit.device)
+        # Convert list of scopes to a boolean mask of dimension (B, N) where
+        # N is the number of variables in the circuit's scope.
+        integrate_vars_mask = IntegrateQuery.scopes_to_mask(self._circuit, integrate_vars)
+
         output = self._circuit.evaluate(
             x,
             module_fn=functools.partial(
-                IntegrateQuery._layer_fn, integrate_vars_idx=integrate_vars_idx
+                IntegrateQuery._layer_fn, integrate_vars_mask=integrate_vars_mask
             ),
         )  # (O, B, K)
         return output.transpose(0, 1)  # (B, O, K)
 
     @staticmethod
-    def _layer_fn(layer: TorchLayer, x: Tensor, *, integrate_vars_idx: Tensor) -> Tensor:
+    def _layer_fn(layer: TorchLayer, x: Tensor, *, integrate_vars_mask: Tensor) -> Tensor:
         # Evaluate a layer: if it is not an input layer, then evaluate it in the usual
         # feed-forward way. Otherwise, use the variables to integrate to solve the marginal
         # queries on the input layers.
@@ -71,14 +81,26 @@ def _layer_fn(layer: TorchLayer, x: Tensor, *, integrate_vars_idx: Tensor) -> Te
             return output
         if layer.num_variables > 1:
             raise NotImplementedError("Integration of multivariate input layers is not supported")
-        # integration_mask: Boolean mask of shape (F, 1)
-        integration_mask = torch.isin(layer.scope_idx, integrate_vars_idx)
+        # integrate_vars_mask is a boolean tensor of dim (B, N),
+        # where N is the number of variables in the scope of the whole circuit.
+        #
+        # layer.scope_idx contains a subset of the variable idxs of the scope,
+        # but may be a reshaped tensor; the shape and order of the variables may be different.
+        #
+        # As such, we need to use the idxs in layer.scope_idx to look up the values in
+        # integrate_vars_mask; this will return the correct shape and values.
+        #
+        # If integrate_vars_mask were a vector, we could do integrate_vars_mask[layer.scope_idx];
+        # the vmap below applies the above across the B dimension.
+
+        # integration_mask has dimension (B, F, Ko)
+        integration_mask = torch.vmap(lambda x: x[layer.scope_idx])(integrate_vars_mask)
+        # permute to match integration_output: integration_mask has dimension (F, B, Ko)
+        integration_mask = integration_mask.permute([1,0,2])
+
         if not torch.any(integration_mask).item():
             return output
-        # output: output of the layer of shape (F, B, Ko)
-        # integration_mask: Boolean mask of shape (F, 1, 1)
-        # integration_output: result of the integration of the layer of shape (F, 1, Ko)
-        integration_mask = integration_mask.unsqueeze(dim=2)
+
         integration_output = layer.integrate()
         # Use the integration mask to select which output should be the result of
         # an integration operation, and which should not be
@@ -86,6 +108,46 @@ def _layer_fn(layer: TorchLayer, x: Tensor, *, integrate_vars_idx: Tensor) -> Te
         # circuit is folded or unfolded
         return torch.where(integration_mask, integration_output, output)
 
+    @staticmethod
+    def scopes_to_mask(circuit, batch_integrate_vars: [Scope]):
+        """Accepts a batch of scopes and returns a boolean mask as a tensor with
+        True in the positions of the specified scope indices and False otherwise.
+        """
+        # If we passed a single scope, assume B = 1
+        if isinstance(batch_integrate_vars, Scope):
+            batch_integrate_vars = [batch_integrate_vars]
+
+        batch_size = len(batch_integrate_vars)
+        # There are cases where the circuit.scope may change:
+        # e.g. we may marginalise out X_1, so the scope becomes smaller,
+        # but the indices of the remaining variables are not shifted.
+        num_rvs = max(circuit.scope) + 1
+        num_idxs = sum(len(s) for s in batch_integrate_vars)
+
+        # TODO: Maybe consider using a sparse tensor
+        mask = torch.zeros((batch_size, num_rvs),
+                           dtype=torch.bool,
+                           device=circuit.device)
+
+        # Catch the case of only empty scopes, where the following command would fail
+        if num_idxs == 0:
+            return mask
+
+        batch_idxs, rv_idxs = zip(*((i, idx)
+                                    for i, idxs in enumerate(batch_integrate_vars)
+                                    for idx in idxs if idxs))
+
+        # Check that we have not asked to marginalise variables that are not defined
+        invalid_idxs = Scope(rv_idxs) - circuit.scope
+        if invalid_idxs:
+            raise ValueError("The variables to marginalize must be a subset of "
+                             " the circuit scope. Invalid variables"
+                             " not in scope: %s."
+                             % list(invalid_idxs))
+
+        mask[batch_idxs, rv_idxs] = True
+
+        return mask
+
 
 class SamplingQuery(Query):
     """The sampling query object."""

From 8a6718ca693dc8795ccdd0b52a651487c0daa761 Mon Sep 17 00:00:00 2001
From: Andreas Grivas
Date: Mon, 11 Nov 2024 12:31:14 +0000
Subject: [PATCH 2/7] Add tests for batched integration masks

---
 .../torch/test_queries/test_integration.py | 75 +++++++++++++++++++
 1 file changed, 75 insertions(+)

diff --git a/tests/backend/torch/test_queries/test_integration.py b/tests/backend/torch/test_queries/test_integration.py
index e51a5aab..a2a8c5b7 100644
--- a/tests/backend/torch/test_queries/test_integration.py
+++ b/tests/backend/torch/test_queries/test_integration.py
@@ -18,6 +18,8 @@
 )
 def test_query_marginalize_monotonic_pc_categorical(semiring: str, fold: bool, optimize: bool):
     compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
+    # The following function builds a circuit for which we have computed the
+    # partition function and a marginal by hand.
     sc, gt_outputs, gt_partition_func = build_monotonic_structured_categorical_cpt_pc(
         return_ground_truth=True
     )
@@ -44,3 +46,76 @@ def test_query_marginalize_monotonic_pc_categorical(semiring: str, fold: bool, o
     mar_scores2 = mar_query(mar_worlds, integrate_vars=Scope([4]))
     assert mar_scores1.shape == mar_scores2.shape
     assert allclose(mar_scores1, mar_scores2)
+
+
+@pytest.mark.parametrize(
+    "semiring,fold,optimize",
+    itertools.product(["lse-sum", "sum-product"], [False, True], [False, True]),
+)
+def test_batch_query_marginalize_monotonic_pc_categorical(semiring: str, fold: bool, optimize: bool):
+    # Check that using a mask with batching works
+    compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
+    # The following function builds a circuit for which we have computed the
+    # partition function and a marginal by hand.
+    sc, gt_outputs, gt_partition_func = build_monotonic_structured_categorical_cpt_pc(
+        return_ground_truth=True
+    )
+
+    tc: TorchCircuit = compiler.compile(sc)
+
+    # The marginal has been computed for (1, 0, 1, 1, None) -- so marginalising var 4.
+    inputs = torch.tensor([[[1, 0, 1, 1, 1], [1, 0, 1, 1, 1]]], dtype=torch.int64).view(2, 1, 5)
+
+    mar_query = IntegrateQuery(tc)
+    # Create two masks: one marginalises out everything,
+    # the other marginalises out only the last variable
+    mask = [Scope([0, 1, 2, 3, 4]), Scope([4])]
+    # The first score should be the partition function, as we marginalised out all vars.
+    # The second score should be our precomputed marginal.
+    mar_scores = mar_query(inputs, integrate_vars=mask)
+
+    if semiring == 'sum-product':
+        assert torch.isclose(mar_scores[0], torch.tensor(gt_partition_func))
+        assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs['mar'][(1, 0, 1, 1, None)]))
+    elif semiring == 'lse-sum':
+        mar_scores = torch.exp(mar_scores)
+        assert torch.isclose(mar_scores[0], torch.tensor(gt_partition_func))
+        assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs['mar'][(1, 0, 1, 1, None)]))
+    else:
+        raise ValueError('Unexpected semiring: "%s"' % semiring)
+
+
+@pytest.mark.parametrize(
+    "semiring,fold,optimize",
+    itertools.product(["lse-sum", "sum-product"], [False, True], [False, True]),
+)
+def test_batch_broadcast_query_marginalize_monotonic_pc_categorical(semiring: str, fold: bool, optimize: bool):
+    # Check that passing a single mask results in broadcasting
+    compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
+    # The following function builds a circuit for which we have computed the
+    # partition function and a marginal by hand.
+    sc, gt_outputs, gt_partition_func = build_monotonic_structured_categorical_cpt_pc(
+        return_ground_truth=True
+    )
+
+    tc: TorchCircuit = compiler.compile(sc)
+
+    # The marginal has been computed for (1, 0, 1, 1, None) -- so marginalising var 4.
+    inputs = torch.tensor([[[1, 0, 1, 1, 0], [1, 0, 1, 1, 1]]], dtype=torch.int64).view(2, 1, 5)
+
+    mar_query = IntegrateQuery(tc)
+    # Create a single mask - this should be broadcast along the batch dim.
+    mask = Scope([4])
+    # Both scores should equal our precomputed marginal, since the single
+    # mask is broadcast along the batch dimension.
+    mar_scores = mar_query(inputs, integrate_vars=mask)
+
+    if semiring == 'sum-product':
+        assert torch.isclose(mar_scores[0], torch.tensor(gt_outputs['mar'][(1, 0, 1, 1, None)]))
+        assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs['mar'][(1, 0, 1, 1, None)]))
+    elif semiring == 'lse-sum':
+        mar_scores = torch.exp(mar_scores)
+        assert torch.isclose(mar_scores[0], torch.tensor(gt_outputs['mar'][(1, 0, 1, 1, None)]))
+        assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs['mar'][(1, 0, 1, 1, None)]))
+    else:
+        raise ValueError('Unexpected semiring: "%s"' % semiring)

From bd6062f90ddb86a85c3751350c3412ab0b9dbe0a Mon Sep 17 00:00:00 2001
From: Andreas Grivas
Date: Mon, 11 Nov 2024 12:47:12 +0000
Subject: [PATCH 3/7] Check integration errors when var out of scope

---
 cirkit/backend/torch/queries.py | 2 +-
 .../torch/test_queries/test_integration.py | 23 +++++++++++++++++++
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/cirkit/backend/torch/queries.py b/cirkit/backend/torch/queries.py
index 70940fa1..668564f3 100644
--- a/cirkit/backend/torch/queries.py
+++ b/cirkit/backend/torch/queries.py
@@ -140,7 +140,7 @@ def scopes_to_mask(circuit, batch_integrate_vars: [Scope]):
         # Check that we have not asked to marginalise variables that are not defined
         invalid_idxs = Scope(rv_idxs) - circuit.scope
         if invalid_idxs:
-            raise ValueError("The variables to marginalize must be a subset of "
+            raise ValueError("The variables to marginalize must be a subset of"
                              " the circuit scope. Invalid variables"
                              " not in scope: %s."
                              % list(invalid_idxs))

diff --git a/tests/backend/torch/test_queries/test_integration.py b/tests/backend/torch/test_queries/test_integration.py
index a2a8c5b7..5d155dbd 100644
--- a/tests/backend/torch/test_queries/test_integration.py
+++ b/tests/backend/torch/test_queries/test_integration.py
@@ -119,3 +119,26 @@ def test_batch_broadcast_query_marginalize_monotonic_pc_categorical(semiring: st
         assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs['mar'][(1, 0, 1, 1, None)]))
     else:
         raise ValueError('Unexpected semiring: "%s"' % semiring)
+
+
+def test_batch_fails_on_out_of_scope(semiring='sum-product', fold=True, optimize=True):
+    # Check that out-of-scope variables raise an error
+    compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
+    # The following function builds a circuit for which we have computed the
+    # partition function and a marginal by hand.
+    sc, gt_outputs, gt_partition_func = build_monotonic_structured_categorical_cpt_pc(
+        return_ground_truth=True
+    )
+
+    tc: TorchCircuit = compiler.compile(sc)
+
+    # The marginal has been computed for (1, 0, 1, 1, None) -- so marginalising var 4.
+    inputs = torch.tensor([[[1, 0, 1, 1, 0], [1, 0, 1, 1, 1]]], dtype=torch.int64).view(2, 1, 5)
+
+    mar_query = IntegrateQuery(tc)
+    # Variable 5 does not exist so this should error
+    mask = [Scope([0]), Scope([5])]
+    # The query should raise an error before any scores are computed,
+    # since the mask is invalid.
+    with pytest.raises(ValueError, match='not in scope:.*?5'):
+        mar_scores = mar_query(inputs, integrate_vars=mask)

From 190effa3b6429443cc98529ccb51b6b12f43947a Mon Sep 17 00:00:00 2001
From: Andreas Grivas
Date: Mon, 11 Nov 2024 13:11:07 +0000
Subject: [PATCH 4/7] Fix code style

---
 cirkit/backend/torch/queries.py | 24 +++++++-------
 .../torch/test_queries/test_integration.py | 32 +++++++++++--------
 2 files changed, 30 insertions(+), 26 deletions(-)

diff --git a/cirkit/backend/torch/queries.py b/cirkit/backend/torch/queries.py
index 668564f3..f79710e0 100644
--- a/cirkit/backend/torch/queries.py
+++ b/cirkit/backend/torch/queries.py
@@ -60,7 +60,7 @@ def __call__(self, x: Tensor, *, integrate_vars: [Scope]) -> Tensor:
             K is the number of units in each output vector.
         """
         # Convert list of scopes to a boolean mask of dimension (B, N) where
-        # N is the number of variables in the circuit's scope. 
integrate_vars_mask = IntegrateQuery.scopes_to_mask(self._circuit, integrate_vars) output = self._circuit.evaluate( @@ -96,8 +96,8 @@ def _layer_fn(layer: TorchLayer, x: Tensor, *, integrate_vars_mask: Tensor) -> T # integration_mask has dimension (B, F, Ko) integration_mask = torch.vmap(lambda x: x[layer.scope_idx])(integrate_vars_mask) # permute to match integration_output: integration_mask has dimension (F, B, Ko) - integration_mask = integration_mask.permute([1,0,2]) - + integration_mask = integration_mask.permute([1, 0, 2]) + if not torch.any(integration_mask).item(): return output @@ -125,24 +125,24 @@ def scopes_to_mask(circuit, batch_integrate_vars: [Scope]): num_idxs = sum(len(s) for s in batch_integrate_vars) # TODO: Maybe consider using a sparse tensor - mask = torch.zeros((batch_size, num_rvs), - dtype=torch.bool, - device=circuit.device) + mask = torch.zeros((batch_size, num_rvs), dtype=torch.bool, device=circuit.device) # Catch case of only empty scopes where the following command will fail if num_idxs == 0: return mask - batch_idxs, rv_idxs = zip(*((i, idx) - for i, idxs in enumerate(batch_integrate_vars) - for idx in idxs if idxs)) + batch_idxs, rv_idxs = zip( + *((i, idx) for i, idxs in enumerate(batch_integrate_vars) for idx in idxs if idxs) + ) # Check that we have not asked to marginalise variables that are not defined invalid_idxs = Scope(rv_idxs) - circuit.scope if invalid_idxs: - raise ValueError("The variables to marginalize must be a subset of" - " the circuit scope. Invalid variables" - " not in scope: %s." % list(invalid_idxs)) + raise ValueError( + "The variables to marginalize must be a subset of" + " the circuit scope. Invalid variables" + " not in scope: %s." % list(invalid_idxs) + ) mask[batch_idxs, rv_idxs] = True diff --git a/tests/backend/torch/test_queries/test_integration.py b/tests/backend/torch/test_queries/test_integration.py index 5d155dbd..f49fa72f 100644 --- a/tests/backend/torch/test_queries/test_integration.py +++ b/tests/backend/torch/test_queries/test_integration.py @@ -52,7 +52,9 @@ def test_query_marginalize_monotonic_pc_categorical(semiring: str, fold: bool, o "semiring,fold,optimize", itertools.product(["lse-sum", "sum-product"], [False, True], [False, True]), ) -def test_batch_query_marginalize_monotonic_pc_categorical(semiring: str, fold: bool, optimize: bool): +def test_batch_query_marginalize_monotonic_pc_categorical( + semiring: str, fold: bool, optimize: bool +): # Check using a mask with batching works compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize) # The following function computes a circuit where we have computed the @@ -74,13 +76,13 @@ def test_batch_query_marginalize_monotonic_pc_categorical(semiring: str, fold: b # The second score, should be our precomputed marginal. 
     mar_scores = mar_query(inputs, integrate_vars=mask)
 
-    if semiring == 'sum-product':
+    if semiring == "sum-product":
         assert torch.isclose(mar_scores[0], torch.tensor(gt_partition_func))
-        assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs['mar'][(1, 0, 1, 1, None)]))
-    elif semiring == 'lse-sum':
+        assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs["mar"][(1, 0, 1, 1, None)]))
+    elif semiring == "lse-sum":
         mar_scores = torch.exp(mar_scores)
         assert torch.isclose(mar_scores[0], torch.tensor(gt_partition_func))
-        assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs['mar'][(1, 0, 1, 1, None)]))
+        assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs["mar"][(1, 0, 1, 1, None)]))
     else:
         raise ValueError('Unexpected semiring: "%s"' % semiring)
 
@@ -89,7 +91,9 @@ def test_batch_query_marginalize_monotonic_pc_categorical(semiring: str, fold: b
     "semiring,fold,optimize",
     itertools.product(["lse-sum", "sum-product"], [False, True], [False, True]),
 )
-def test_batch_broadcast_query_marginalize_monotonic_pc_categorical(semiring: str, fold: bool, optimize: bool):
+def test_batch_broadcast_query_marginalize_monotonic_pc_categorical(
+    semiring: str, fold: bool, optimize: bool
+):
     # Check that passing a single mask results in broadcasting
     compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
     # The following function builds a circuit for which we have computed the
@@ -110,18 +114,18 @@ def test_batch_broadcast_query_marginalize_monotonic_pc_categorical(semiring: st
     # mask is broadcast along the batch dimension.
     mar_scores = mar_query(inputs, integrate_vars=mask)
 
-    if semiring == 'sum-product':
-        assert torch.isclose(mar_scores[0], torch.tensor(gt_outputs['mar'][(1, 0, 1, 1, None)]))
-        assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs['mar'][(1, 0, 1, 1, None)]))
-    elif semiring == 'lse-sum':
+    if semiring == "sum-product":
+        assert torch.isclose(mar_scores[0], torch.tensor(gt_outputs["mar"][(1, 0, 1, 1, None)]))
+        assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs["mar"][(1, 0, 1, 1, None)]))
+    elif semiring == "lse-sum":
         mar_scores = torch.exp(mar_scores)
-        assert torch.isclose(mar_scores[0], torch.tensor(gt_outputs['mar'][(1, 0, 1, 1, None)]))
-        assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs['mar'][(1, 0, 1, 1, None)]))
+        assert torch.isclose(mar_scores[0], torch.tensor(gt_outputs["mar"][(1, 0, 1, 1, None)]))
+        assert torch.isclose(mar_scores[1], torch.tensor(gt_outputs["mar"][(1, 0, 1, 1, None)]))
     else:
         raise ValueError('Unexpected semiring: "%s"' % semiring)
 
 
-def test_batch_fails_on_out_of_scope(semiring='sum-product', fold=True, optimize=True):
+def test_batch_fails_on_out_of_scope(semiring="sum-product", fold=True, optimize=True):
     # Check that out-of-scope variables raise an error
     compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
     # The following function builds a circuit for which we have computed the
     # partition function and a marginal by hand.
@@ -140,5 +144,5 @@ def test_batch_fails_on_out_of_scope(semiring='sum-product', fold=True, optimize
     inputs = torch.tensor([[[1, 0, 1, 1, 0], [1, 0, 1, 1, 1]]], dtype=torch.int64).view(2, 1, 5)
 
     mar_query = IntegrateQuery(tc)
     # Variable 5 does not exist so this should error
     mask = [Scope([0]), Scope([5])]
     # The query should raise an error before any scores are computed,
     # since the mask is invalid.
-    with pytest.raises(ValueError, match='not in scope:.*?5'):
+    with pytest.raises(ValueError, match="not in scope:.*?5"):
         mar_scores = mar_query(inputs, integrate_vars=mask)

From eed72d7dba8f0184f4c7dcb30d4b908f6352604c Mon Sep 17 00:00:00 2001
From: Andreas Grivas
Date: Mon, 25 Nov 2024 10:20:41 +0000
Subject: [PATCH 5/7] Support passing boolean tensor mask as input

---
 cirkit/backend/torch/queries.py | 68 ++++++++++++++++++++++++---------
 1 file changed, 51 insertions(+), 17 deletions(-)

diff --git a/cirkit/backend/torch/queries.py b/cirkit/backend/torch/queries.py
index f79710e0..e64bc32d 100644
--- a/cirkit/backend/torch/queries.py
+++ b/cirkit/backend/torch/queries.py
@@ -1,5 +1,6 @@
 import functools
 from abc import ABC
+from collections.abc import Iterable
 
 import torch
 from torch import Tensor
@@ -45,42 +46,76 @@ def __init__(self, circuit: TorchCircuit) -> None:
         super().__init__()
         self._circuit = circuit
 
-    def __call__(self, x: Tensor, *, integrate_vars: [Scope]) -> Tensor:
+    def __call__(self, x: Tensor, *, integrate_vars: Tensor | Scope | Iterable[Scope]) -> Tensor:
         """Solve an integration query, given an input batch and the variables to integrate.
 
         Args:
             x: An input batch of shape (B, C, D), where B is the batch size, C is the number of
                 channels per variable, and D is the number of variables.
             integrate_vars: The variables to integrate. It must be a subset of the variables on
                 which the circuit given in the constructor is defined.
+                The format can be one of the following three:
+                1. Tensor of shape (B, D) where B is the batch size and D is the number of
+                   variables in the scope of the circuit. Its dtype should be torch.bool
+                   and have True in the positions of random variables that should be
+                   marginalised out and False elsewhere.
+                2. Scope: in this case the same integration mask is applied to all
+                   entries of the batch.
+                3. List of Scopes, where the length of the list must be either 1 or B. If
+                   the list has length 1, it behaves as in case 2.
 
         Returns:
             The result of the integration query, given as a tensor of shape (B, O, K),
             where B is the batch size, O is the number of output vectors of the circuit, and
             K is the number of units in each output vector.
         """
+        if isinstance(integrate_vars, Tensor):
+            # Check that the dtype of the tensor is boolean
+            if integrate_vars.dtype != torch.bool:
+                raise ValueError('Expected dtype of tensor to be torch.bool, got %s'
+                                 % integrate_vars.dtype)
+            # If it is a one-dimensional tensor, assume batch size = 1
+            if len(integrate_vars.shape) == 1:
+                integrate_vars = torch.unsqueeze(integrate_vars, 0)
+            # If the scope is correct, proceed, otherwise error
+            num_vars = max(self._circuit.scope) + 1
+            if integrate_vars.shape[1] == num_vars:
+                integrate_vars_mask = integrate_vars
+            else:
+                raise ValueError('Circuit scope has %d variables but integrate_vars'
+                                 ' was defined over %d != %d variables.' %
+                                 (num_vars, integrate_vars.shape[1], num_vars))
+        else:
+            # Convert list of scopes to a boolean mask of dimension (B, N) where
+            # N is the number of variables in the circuit's scope.
+            integrate_vars_mask = IntegrateQuery.scopes_to_mask(self._circuit, integrate_vars)
+
+        # Check that the batch sizes of input x and the mask are compatible
+        if integrate_vars_mask.shape[0] not in (1, x.shape[0]):
+            raise ValueError('The number of scopes to integrate over must'
+                             ' either match the batch size of x, or be 1 if you'
+                             ' want to broadcast.'
+                             ' Found #inputs = %d != %d = len(integrate_vars)'
+                             % (x.shape[0], integrate_vars_mask.shape[0]))
 
         output = self._circuit.evaluate(
             x,
@@ -96,8 +130,8 @@ def _layer_fn(layer: TorchLayer, x: Tensor, *, integrate_vars_mask: Tensor) -> T
         # integration_mask has dimension (B, F, Ko)
         integration_mask = torch.vmap(lambda x: x[layer.scope_idx])(integrate_vars_mask)
         # permute to match integration_output: integration_mask has dimension (F, B, Ko)
-        integration_mask = integration_mask.permute([1, 0, 2])
-
+        integration_mask = integration_mask.permute([1,0,2])
+        
         if not torch.any(integration_mask).item():
             return output
 
@@ -146,30 +180,33 @@ def scopes_to_mask(circuit, batch_integrate_vars: [Scope]):
         # If we passed a single scope, assume B = 1
         if isinstance(batch_integrate_vars, Scope):
             batch_integrate_vars = [batch_integrate_vars]
 
-        batch_size = len(batch_integrate_vars)
+        batch_size = len(tuple(batch_integrate_vars))
         # There are cases where the circuit.scope may change:
         # e.g. we may marginalise out X_1, so the scope becomes smaller,
         # but the indices of the remaining variables are not shifted.
         num_rvs = max(circuit.scope) + 1
         num_idxs = sum(len(s) for s in batch_integrate_vars)
 
         # TODO: Maybe consider using a sparse tensor
-        mask = torch.zeros((batch_size, num_rvs), dtype=torch.bool, device=circuit.device)
+        mask = torch.zeros((batch_size, num_rvs),
+                           dtype=torch.bool,
+                           device=circuit.device)
 
         # Catch the case of only empty scopes, where the following command would fail
         if num_idxs == 0:
             return mask
 
-        batch_idxs, rv_idxs = zip(
-            *((i, idx) for i, idxs in enumerate(batch_integrate_vars) for idx in idxs if idxs)
-        )
+        batch_idxs, rv_idxs = zip(*((i, idx)
+                                    for i, idxs in enumerate(batch_integrate_vars)
+                                    for idx in idxs if idxs))
 
         # Check that we have not asked to marginalise variables that are not defined
         invalid_idxs = Scope(rv_idxs) - circuit.scope
         if invalid_idxs:
-            raise ValueError(
-                "The variables to marginalize must be a subset of"
-                " the circuit scope. Invalid variables"
-                " not in scope: %s." % list(invalid_idxs)
-            )
+            raise ValueError("The variables to marginalize must be a subset of "
+                             " the circuit scope. Invalid variables"
+                             " not in scope: %s."
+                             % list(invalid_idxs))
 
         mask[batch_idxs, rv_idxs] = True
 
         return mask

From dc552c630accaa88206e8c77bd738c4f5a23d8b8 Mon Sep 17 00:00:00 2001
From: Andreas Grivas
Date: Mon, 25 Nov 2024 10:24:59 +0000
Subject: [PATCH 6/7] Add tests for tensors and dtype checks

---
 .../torch/test_queries/test_integration.py | 116 +++++++++++++++---
 1 file changed, 101 insertions(+), 15 deletions(-)

diff --git a/tests/backend/torch/test_queries/test_integration.py b/tests/backend/torch/test_queries/test_integration.py
index f49fa72f..7dc956e7 100644
--- a/tests/backend/torch/test_queries/test_integration.py
+++ b/tests/backend/torch/test_queries/test_integration.py
@@ -49,11 +49,11 @@ def test_query_marginalize_monotonic_pc_categorical(semiring: str, fold: bool, o
 
 
 @pytest.mark.parametrize(
-    "semiring,fold,optimize",
-    itertools.product(["lse-sum", "sum-product"], [False, True], [False, True]),
+    "semiring,fold,optimize,input_tensor",
+    itertools.product(["lse-sum", "sum-product"], [False, True], [False, True], [False, True]),
 )
 def test_batch_query_marginalize_monotonic_pc_categorical(
-    semiring: str, fold: bool, optimize: bool
+    semiring: str, fold: bool, optimize: bool, input_tensor: bool
 ):
     # Check that using a mask with batching works
     compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
@@ -69,9 +69,14 @@ def test_batch_query_marginalize_monotonic_pc_categorical(
     inputs = torch.tensor([[[1, 0, 1, 1, 1], [1, 0, 1, 1, 1]]], dtype=torch.int64).view(2, 1, 5)
 
     mar_query = IntegrateQuery(tc)
-    # Create two masks: one marginalises out everything,
-    # the other marginalises out only the last variable
-    mask = [Scope([0, 1, 2, 3, 4]), Scope([4])]
+    if input_tensor:
+        mask = torch.tensor([[True, True, True, True, True],
+                             [False, False, False, False, True]],
+                            dtype=torch.bool)
+    else:
+        # Create two masks: one marginalises out everything,
+        # the other marginalises out only the last variable
+        mask = [Scope([0, 1, 2, 3, 4]), Scope([4])]
     # The first score should be the partition function, as we marginalised out all vars.
     # The second score should be our precomputed marginal.
     mar_scores = mar_query(inputs, integrate_vars=mask)
 
@@ -93,7 +98,7 @@ def test_batch_query_marginalize_monotonic_pc_categorical(
 
 
 @pytest.mark.parametrize(
-    "semiring,fold,optimize",
-    itertools.product(["lse-sum", "sum-product"], [False, True], [False, True]),
+    "semiring,fold,optimize,input_tensor",
+    itertools.product(["lse-sum", "sum-product"], [False, True], [False, True], [False, True]),
 )
 def test_batch_broadcast_query_marginalize_monotonic_pc_categorical(
-    semiring: str, fold: bool, optimize: bool
+    semiring: str, fold: bool, optimize: bool, input_tensor: bool
 ):
     # Check that passing a single mask results in broadcasting
     compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
@@ -108,8 +113,11 @@ def test_batch_broadcast_query_marginalize_monotonic_pc_categorical(
     inputs = torch.tensor([[[1, 0, 1, 1, 0], [1, 0, 1, 1, 1]]], dtype=torch.int64).view(2, 1, 5)
 
     mar_query = IntegrateQuery(tc)
-    # Create a single mask - this should be broadcast along the batch dim.
-    mask = Scope([4])
+    if input_tensor:
+        mask = torch.tensor([False, False, False, False, True], dtype=torch.bool)
+    else:
+        # Create a single mask - this should be broadcast along the batch dim.
+        mask = Scope([4])
     # Both scores should equal our precomputed marginal, since the single
     # mask is broadcast along the batch dimension.
     mar_scores = mar_query(inputs, integrate_vars=mask)
 
@@ -133,26 +141,48 @@ def test_batch_broadcast_query_marginalize_monotonic_pc_categorical(
 
 
-def test_batch_fails_on_out_of_scope(semiring="sum-product", fold=True, optimize=True):
+@pytest.mark.parametrize(
+    "input_tensor",
+    itertools.product([False, True]),
+)
+def test_batch_fails_on_out_of_scope(input_tensor, semiring="sum-product", fold=True, optimize=True):
     # Check that out-of-scope variables raise an error
     compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
     # The following function builds a circuit for which we have computed the
     # partition function and a marginal by hand.
     sc, gt_outputs, gt_partition_func = build_monotonic_structured_categorical_cpt_pc(
         return_ground_truth=True
     )
 
     tc: TorchCircuit = compiler.compile(sc)
 
     # The marginal has been computed for (1, 0, 1, 1, None) -- so marginalising var 4.
     inputs = torch.tensor([[[1, 0, 1, 1, 0], [1, 0, 1, 1, 1]]], dtype=torch.int64).view(2, 1, 5)
 
     mar_query = IntegrateQuery(tc)
-    # Variable 5 does not exist so this should error
-    mask = [Scope([0]), Scope([5])]
-    # The query should raise an error before any scores are computed,
-    # since the mask is invalid.
-    with pytest.raises(ValueError, match="not in scope:.*?5"):
-        mar_scores = mar_query(inputs, integrate_vars=mask)
+    if input_tensor:
+        # Variable 5 does not exist so this should error
+        mask = torch.tensor([[False, False, False, False, True, True],
+                             [False, False, False, False, True, True]],
+                            dtype=torch.bool)
+        # The query should raise an error before any scores are computed,
+        # since the mask is invalid.
+        with pytest.raises(ValueError, match="was defined over %d != 5 variables" % mask.shape[1]):
+            mar_scores = mar_query(inputs, integrate_vars=mask)
+    else:
+        # Variable 5 does not exist so this should error
+        mask = [Scope([0]), Scope([5])]
+        # The query should raise an error before any scores are computed,
+        # since the mask is invalid.
+        with pytest.raises(ValueError, match="not in scope:.*?5"):
+            mar_scores = mar_query(inputs, integrate_vars=mask)
+
+
+@pytest.mark.parametrize(
+    "input_tensor",
+    itertools.product([False, True]),
+)
+def test_batch_fails_on_wrong_batch_size(input_tensor, semiring="sum-product", fold=True, optimize=True):
+    # Check that a mask with the wrong batch size raises an error
+    compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
+    # The following function builds a circuit for which we have computed the
+    # partition function and a marginal by hand.
+    sc, gt_outputs, gt_partition_func = build_monotonic_structured_categorical_cpt_pc(
+        return_ground_truth=True
+    )
+
+    tc: TorchCircuit = compiler.compile(sc)
+
+    # The marginal has been computed for (1, 0, 1, 1, None) -- so marginalising var 4.
+    inputs = torch.tensor([[[1, 0, 1, 1, 0], [1, 0, 1, 1, 1]]], dtype=torch.int64).view(2, 1, 5)
+
+    mar_query = IntegrateQuery(tc)
+    if input_tensor:
+        # Input batch size is 2, passing 3 masks
+        mask = torch.tensor([[False, False, False, False, True],
+                             [False, False, False, False, True],
+                             [False, False, False, False, True]],
+                            dtype=torch.bool)
+        # The query should raise an error before any scores are computed,
+        # since the mask is invalid.
+        with pytest.raises(ValueError, match="Found #inputs = 2 != 3"):
+            mar_scores = mar_query(inputs, integrate_vars=mask)
+    else:
+        # Input batch size is 2, passing 3 masks
+        mask = [Scope([0]), Scope([1]), Scope([2])]
+        # The query should raise an error before any scores are computed,
+        # since the mask is invalid.
+        with pytest.raises(ValueError, match="Found #inputs = 2 != 3"):
+            mar_scores = mar_query(inputs, integrate_vars=mask)
+
+
+def test_batch_fails_on_wrong_tensor_dtype(semiring="sum-product", fold=True, optimize=True):
+    # Check that a non-boolean mask raises an error
+    compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
+    # The following function builds a circuit for which we have computed the
+    # partition function and a marginal by hand.
+    sc, gt_outputs, gt_partition_func = build_monotonic_structured_categorical_cpt_pc(
+        return_ground_truth=True
+    )
+
+    tc: TorchCircuit = compiler.compile(sc)
+
+    # The marginal has been computed for (1, 0, 1, 1, None) -- so marginalising var 4.
+    inputs = torch.tensor([[[1, 0, 1, 1, 0], [1, 0, 1, 1, 1]]], dtype=torch.int64).view(2, 1, 5)
+
+    mar_query = IntegrateQuery(tc)
+
+    # Pass masks of the correct shape but with the wrong dtype (int32)
+    mask = torch.tensor([[False, False, False, False, True],
+                         [False, False, False, False, True]],
+                        dtype=torch.int32)
+    # The query should raise an error before any scores are computed,
+    # since the mask is invalid.
+    with pytest.raises(ValueError, match="Expected dtype of tensor to be torch.bool"):
+        mar_scores = mar_query(inputs, integrate_vars=mask)

From 65b482f634147e9bb3b505d26f4f31ae42f9ec1b Mon Sep 17 00:00:00 2001
From: Andreas Grivas
Date: Mon, 25 Nov 2024 10:26:03 +0000
Subject: [PATCH 7/7] Run through black

---
 cirkit/backend/torch/queries.py | 49 ++++++++++---------
 .../torch/test_queries/test_integration.py | 43 +++++++++-------
 2 files changed, 53 insertions(+), 39 deletions(-)

diff --git a/cirkit/backend/torch/queries.py b/cirkit/backend/torch/queries.py
index e64bc32d..f48627f4 100644
--- a/cirkit/backend/torch/queries.py
+++ b/cirkit/backend/torch/queries.py
@@ -71,8 +72,9 @@ def __call__(self, x: Tensor, *, integrate_vars: Tensor | Scope | Iterable[Scope
         if isinstance(integrate_vars, Tensor):
             # Check that the dtype of the tensor is boolean
             if integrate_vars.dtype != torch.bool:
-                raise ValueError('Expected dtype of tensor to be torch.bool, got %s'
-                                 % integrate_vars.dtype)
+                raise ValueError(
+                    "Expected dtype of tensor to be torch.bool, got %s" % integrate_vars.dtype
+                )
             # If it is a one-dimensional tensor, assume batch size = 1
             if len(integrate_vars.shape) == 1:
                 integrate_vars = torch.unsqueeze(integrate_vars, 0)
@@ -81,21 +82,25 @@ def __call__(self, x: Tensor, *, integrate_vars: Tensor | Scope | Iterable[Scope
             if integrate_vars.shape[1] == num_vars:
                 integrate_vars_mask = integrate_vars
             else:
-                raise ValueError('Circuit scope has %d variables but integrate_vars'
-                                 ' was defined over %d != %d variables.' %
-                                 (num_vars, integrate_vars.shape[1], num_vars))
+                raise ValueError(
+                    "Circuit scope has %d variables but integrate_vars"
+                    " was defined over %d != %d variables."
+                    % (num_vars, integrate_vars.shape[1], num_vars)
+                )
         else:
             # Convert list of scopes to a boolean mask of dimension (B, N) where
-            # N is the number of variables in the circuit's scope. 
+            # N is the number of variables in the circuit's scope.
             integrate_vars_mask = IntegrateQuery.scopes_to_mask(self._circuit, integrate_vars)
 
         # Check that the batch sizes of input x and the mask are compatible
         if integrate_vars_mask.shape[0] not in (1, x.shape[0]):
-            raise ValueError('The number of scopes to integrate over must'
-                             ' either match the batch size of x, or be 1 if you'
-                             ' want to broadcast.'
-                             ' Found #inputs = %d != %d = len(integrate_vars)'
-                             % (x.shape[0], integrate_vars_mask.shape[0]))
+            raise ValueError(
+                "The number of scopes to integrate over must"
+                " either match the batch size of x, or be 1 if you"
+                " want to broadcast."
+                " Found #inputs = %d != %d = len(integrate_vars)"
+                % (x.shape[0], integrate_vars_mask.shape[0])
+            )
 
         output = self._circuit.evaluate(
             x,
@@ -135,8 +140,8 @@ def _layer_fn(layer: TorchLayer, x: Tensor, *, integrate_vars_mask: Tensor) -> T
         # integration_mask has dimension (B, F, Ko)
         integration_mask = torch.vmap(lambda x: x[layer.scope_idx])(integrate_vars_mask)
         # permute to match integration_output: integration_mask has dimension (F, B, Ko)
-        integration_mask = integration_mask.permute([1,0,2])
-        
+        integration_mask = integration_mask.permute([1, 0, 2])
+
         if not torch.any(integration_mask).item():
             return output
 
@@ -159,24 +164,24 @@ def scopes_to_mask(circuit, batch_integrate_vars: [Scope]):
         num_idxs = sum(len(s) for s in batch_integrate_vars)
 
         # TODO: Maybe consider using a sparse tensor
-        mask = torch.zeros((batch_size, num_rvs),
-                           dtype=torch.bool,
-                           device=circuit.device)
+        mask = torch.zeros((batch_size, num_rvs), dtype=torch.bool, device=circuit.device)
 
         # Catch the case of only empty scopes, where the following command would fail
         if num_idxs == 0:
             return mask
 
-        batch_idxs, rv_idxs = zip(*((i, idx)
-                                    for i, idxs in enumerate(batch_integrate_vars)
-                                    for idx in idxs if idxs))
+        batch_idxs, rv_idxs = zip(
+            *((i, idx) for i, idxs in enumerate(batch_integrate_vars) for idx in idxs if idxs)
+        )
 
         # Check that we have not asked to marginalise variables that are not defined
         invalid_idxs = Scope(rv_idxs) - circuit.scope
         if invalid_idxs:
-            raise ValueError("The variables to marginalize must be a subset of "
-                             " the circuit scope. Invalid variables"
-                             " not in scope: %s."
+            raise ValueError(
+                "The variables to marginalize must be a subset of "
+                " the circuit scope. Invalid variables"
+                " not in scope: %s."
-                             % list(invalid_idxs))
+                % list(invalid_idxs)
+            )
 
         mask[batch_idxs, rv_idxs] = True
 
diff --git a/tests/backend/torch/test_queries/test_integration.py b/tests/backend/torch/test_queries/test_integration.py
index 7dc956e7..07cfee02 100644
--- a/tests/backend/torch/test_queries/test_integration.py
+++ b/tests/backend/torch/test_queries/test_integration.py
@@ -53,7 +53,7 @@ def test_query_marginalize_monotonic_pc_categorical(semiring: str, fold: bool, o
     itertools.product(["lse-sum", "sum-product"], [False, True], [False, True], [False, True]),
 )
 def test_batch_query_marginalize_monotonic_pc_categorical(
-    semiring: str, fold: bool, optimize: bool, input_tensor: bool 
+    semiring: str, fold: bool, optimize: bool, input_tensor: bool
 ):
     # Check that using a mask with batching works
     compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
@@ -70,9 +70,9 @@ def test_batch_query_marginalize_monotonic_pc_categorical(
 
     mar_query = IntegrateQuery(tc)
     if input_tensor:
-        mask = torch.tensor([[True, True, True, True, True],
-                             [False, False, False, False, True]],
-                            dtype=torch.bool)
+        mask = torch.tensor(
+            [[True, True, True, True, True], [False, False, False, False, True]], dtype=torch.bool
+        )
     else:
         # Create two masks: one marginalises out everything,
         # the other marginalises out only the last variable
         mask = [Scope([0, 1, 2, 3, 4]), Scope([4])]
@@ -97,7 +97,7 @@ def test_batch_query_marginalize_monotonic_pc_categorical(
     itertools.product(["lse-sum", "sum-product"], [False, True], [False, True], [False, True]),
 )
 def test_batch_broadcast_query_marginalize_monotonic_pc_categorical(
-    semiring: str, fold: bool, optimize: bool, input_tensor: bool 
+    semiring: str, fold: bool, optimize: bool, input_tensor: bool
 ):
     # Check that passing a single mask results in broadcasting
     compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
@@ -137,7 +139,9 @@ def test_batch_broadcast_query_marginalize_monotonic_pc_categorical(
     "input_tensor",
     itertools.product([False, True]),
 )
-def test_batch_fails_on_out_of_scope(input_tensor, semiring="sum-product", fold=True, optimize=True):
+def test_batch_fails_on_out_of_scope(
+    input_tensor, semiring="sum-product", fold=True, optimize=True
+):
     # Check that out-of-scope variables raise an error
     compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
     # The following function builds a circuit for which we have computed the
@@ -156,10 +160,10 @@ def test_batch_fails_on_out_of_scope(
     mar_query = IntegrateQuery(tc)
     if input_tensor:
         # Variable 5 does not exist so this should error
-        mask = torch.tensor([[False, False, False, False, True, True],
-                             [False, False, False, False, True, True]],
-                            dtype=torch.bool)
+        mask = torch.tensor(
+            [[False, False, False, False, True, True], [False, False, False, False, True, True]],
+            dtype=torch.bool,
+        )
         # The query should raise an error before any scores are computed,
         # since the mask is invalid.
         with pytest.raises(ValueError, match="was defined over %d != 5 variables" % mask.shape[1]):
             mar_scores = mar_query(inputs, integrate_vars=mask)
     else:
         # Variable 5 does not exist so this should error
         mask = [Scope([0]), Scope([5])]
@@ -177,7 +183,9 @@ def test_batch_fails_on_out_of_scope(
     "input_tensor",
     itertools.product([False, True]),
 )
-def test_batch_fails_on_wrong_batch_size(input_tensor, semiring="sum-product", fold=True, optimize=True):
+def test_batch_fails_on_wrong_batch_size(
+    input_tensor, semiring="sum-product", fold=True, optimize=True
+):
     # Check that a mask with the wrong batch size raises an error
     compiler = TorchCompiler(semiring=semiring, fold=fold, optimize=optimize)
     # The following function builds a circuit for which we have computed the
@@ -196,10 +204,14 @@ def test_batch_fails_on_wrong_batch_size(
     mar_query = IntegrateQuery(tc)
     if input_tensor:
         # Input batch size is 2, passing 3 masks
-        mask = torch.tensor([[False, False, False, False, True],
-                             [False, False, False, False, True],
-                             [False, False, False, False, True]],
-                            dtype=torch.bool)
+        mask = torch.tensor(
+            [
+                [False, False, False, False, True],
+                [False, False, False, False, True],
+                [False, False, False, False, True],
+            ],
+            dtype=torch.bool,
+        )
         # The query should raise an error before any scores are computed,
         # since the mask is invalid.
         with pytest.raises(ValueError, match="Found #inputs = 2 != 3"):
@@ -225,9 +237,9 @@ def test_batch_fails_on_wrong_tensor_dtype(semiring="sum-product", fold=True, op
     mar_query = IntegrateQuery(tc)
 
     # Pass masks of the correct shape but with the wrong dtype (int32)
-    mask = torch.tensor([[False, False, False, False, True],
-                         [False, False, False, False, True]],
-                        dtype=torch.int32)
+    mask = torch.tensor(
+        [[False, False, False, False, True], [False, False, False, False, True]], dtype=torch.int32
+    )
     # The query should raise an error before any scores are computed,
     # since the mask is invalid.
     with pytest.raises(ValueError, match="Expected dtype of tensor to be torch.bool"):
         mar_scores = mar_query(inputs, integrate_vars=mask)
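
Usage note (editorial sketch, not part of the patch series): the snippet below exercises the three integrate_vars formats introduced by these patches. It assumes `tc` is a TorchCircuit over variables {0, ..., 4} compiled as in the tests above; the IntegrateQuery import path follows the file touched by the diffs, while the Scope import path is an assumption based on typical cirkit layout.

    import torch
    from cirkit.backend.torch.queries import IntegrateQuery
    from cirkit.utils.scope import Scope  # assumed import path

    # tc: a compiled TorchCircuit over 5 variables (see the tests above).
    query = IntegrateQuery(tc)

    # A batch of two categorical inputs of shape (B, C, D) = (2, 1, 5).
    inputs = torch.tensor([[[1, 0, 1, 1, 0]], [[1, 0, 1, 1, 1]]], dtype=torch.int64)

    # 1. Boolean tensor of shape (B, D): True marks variables to integrate,
    #    row i applies to batch entry i.
    mask = torch.tensor([[True, True, True, True, True],
                         [False, False, False, False, True]])
    scores = query(inputs, integrate_vars=mask)  # (B, O, K)

    # 2. A single Scope: the same mask is broadcast to every batch entry.
    scores = query(inputs, integrate_vars=Scope([4]))

    # 3. A list of Scopes: one marginalisation per batch entry.
    scores = query(inputs, integrate_vars=[Scope([0, 1, 2, 3, 4]), Scope([4])])

Under the lse-sum semiring the returned scores are log-marginals, so exponentiating them (as the tests do) recovers the probabilities.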