add sanity checks to the symbolic representation
loreloc committed Nov 16, 2024
1 parent 5dbfeca commit f096e9a
Showing 10 changed files with 192 additions and 73 deletions.
62 changes: 43 additions & 19 deletions cirkit/backend/torch/parameters/nodes.py
```diff
@@ -663,26 +663,38 @@ def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
 class TorchGaussianProductMean(TorchParameterOp):
     def __init__(
         self,
-        in_gaussian1_shape: tuple[int, ...],
-        in_gaussian2_shape: tuple[int, ...],
+        in_mean1_shape: tuple[int, ...],
+        in_stddev1_shape: tuple[int, ...],
+        in_mean2_shape: tuple[int, ...],
+        in_stddev2_shape: tuple[int, ...],
         *,
         num_folds: int = 1,
     ) -> None:
-        assert in_gaussian1_shape[1] == in_gaussian2_shape[1]
-        super().__init__(in_gaussian1_shape, in_gaussian2_shape, num_folds=num_folds)
+        assert in_mean1_shape == in_stddev1_shape
+        assert in_mean2_shape == in_stddev2_shape
+        assert in_mean1_shape[1] == in_mean2_shape[1]
+        assert in_stddev1_shape[1] == in_stddev2_shape[1]
+        super().__init__(
+            in_mean1_shape, in_stddev1_shape, in_mean2_shape, in_stddev2_shape, num_folds=num_folds
+        )
 
     @property
     def shape(self) -> tuple[int, ...]:
         return (
-            self.in_shapes[0][0] * self.in_shapes[1][0],
+            self.in_shapes[0][0] * self.in_shapes[2][0],
             self.in_shapes[0][1],
         )
 
     @property
     def config(self) -> dict[str, Any]:
-        return {"in_gaussian1_shape": self.in_shapes[0], "in_gaussian2_shape": self.in_shapes[1]}
+        return {
+            "in_mean1_shape": self.in_shapes[0],
+            "in_stddev1_shape": self.in_shapes[1],
+            "in_mean2_shape": self.in_shapes[2],
+            "in_stddev2_shape": self.in_shapes[3],
+        }
 
-    def forward(self, mean1: Tensor, mean2: Tensor, stddev1: Tensor, stddev2: Tensor) -> Tensor:
+    def forward(self, mean1: Tensor, stddev1: Tensor, mean2: Tensor, stddev2: Tensor) -> Tensor:
         var1 = torch.square(stddev1)  # (F, K1, C)
         var2 = torch.square(stddev2)  # (F, K2, C)
         inv_var12 = torch.reciprocal(
```
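For context (a standard identity, not spelled out in the diff): the `forward` truncated above computes the mean of a product of two Gaussians,

$$
\mu_{12} \;=\; \frac{\sigma_2^2\,\mu_1 + \sigma_1^2\,\mu_2}{\sigma_1^2 + \sigma_2^2},
$$

evaluated for every pair of the $K_1 \times K_2$ components. This is why `shape` multiplies `in_shapes[0][0]` ($K_1$) by what is now `in_shapes[2][0]` ($K_2$, the second *mean* shape in the new four-input layout, rather than the old second "Gaussian" shape).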
```diff
@@ -697,13 +709,13 @@ def forward(self, mean1: Tensor, mean2: Tensor, stddev1: Tensor, stddev2: Tensor
 class TorchGaussianProductStddev(TorchBinaryParameterOp):
     def __init__(
         self,
-        in_gaussian1_shape: tuple[int, ...],
-        in_gaussian2_shape: tuple[int, ...],
+        in_stddev1_shape: tuple[int, ...],
+        in_stddev2_shape: tuple[int, ...],
         *,
         num_folds: int = 1,
     ) -> None:
-        assert in_gaussian1_shape[1] == in_gaussian2_shape[1]
-        super().__init__(in_gaussian1_shape, in_gaussian2_shape, num_folds=num_folds)
+        assert in_stddev1_shape[1] == in_stddev2_shape[1]
+        super().__init__(in_stddev1_shape, in_stddev2_shape, num_folds=num_folds)
 
     @property
     def shape(self) -> tuple[int, ...]:
@@ -714,7 +726,7 @@ def shape(self) -> tuple[int, ...]:
 
     @property
     def config(self) -> dict[str, Any]:
-        return {"in_gaussian1_shape": self.in_shapes[0], "in_gaussian2_shape": self.in_shapes[1]}
+        return {"in_stddev1_shape": self.in_shapes[0], "in_stddev2_shape": self.in_shapes[1]}
 
     def forward(self, stddev1: Tensor, stddev2: Tensor) -> Tensor:
         var1 = torch.square(stddev1)  # (F, K1, C)
```
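`TorchGaussianProductStddev`, by contrast, needs only the two standard deviations, since the product variance is mean-free:

$$
\sigma_{12}^2 \;=\; \frac{\sigma_1^2\,\sigma_2^2}{\sigma_1^2 + \sigma_2^2};
$$

this is why the refactor narrows its constructor from two whole "Gaussian" shapes to the two stddev shapes, and why it remains a `TorchBinaryParameterOp`.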
```diff
@@ -728,31 +740,43 @@ def forward(self, stddev1: Tensor, stddev2: Tensor) -> Tensor:
 class TorchGaussianProductLogPartition(TorchParameterOp):
     def __init__(
         self,
-        in_gaussian1_shape: tuple[int, ...],
-        in_gaussian2_shape: tuple[int, ...],
+        in_mean1_shape: tuple[int, ...],
+        in_stddev1_shape: tuple[int, ...],
+        in_mean2_shape: tuple[int, ...],
+        in_stddev2_shape: tuple[int, ...],
         *,
         num_folds: int = 1,
     ) -> None:
-        assert in_gaussian1_shape[1] == in_gaussian2_shape[1]
-        super().__init__(in_gaussian1_shape, in_gaussian2_shape, num_folds=num_folds)
+        assert in_mean1_shape == in_stddev1_shape
+        assert in_mean2_shape == in_stddev2_shape
+        assert in_mean1_shape[1] == in_mean2_shape[1]
+        assert in_stddev1_shape[1] == in_stddev2_shape[1]
+        super().__init__(
+            in_mean1_shape, in_stddev1_shape, in_mean2_shape, in_stddev2_shape, num_folds=num_folds
+        )
         self._log_two_pi = np.log(2.0 * np.pi)
 
     @property
     def shape(self) -> tuple[int, ...]:
         return (
-            self.in_shapes[0][0] * self.in_shapes[1][0],
+            self.in_shapes[0][0] * self.in_shapes[2][0],
             self.in_shapes[0][1],
         )
 
     @property
     def config(self) -> dict[str, Any]:
-        return {"in_gaussian1_shape": self.in_shapes[0], "in_gaussian2_shape": self.in_shapes[1]}
+        return {
+            "in_mean1_shape": self.in_shapes[0],
+            "in_stddev1_shape": self.in_shapes[1],
+            "in_mean2_shape": self.in_shapes[2],
+            "in_stddev2_shape": self.in_shapes[3],
+        }
 
     def forward(
         self,
         mean1: Tensor,
-        mean2: Tensor,
         stddev1: Tensor,
+        mean2: Tensor,
         stddev2: Tensor,
     ) -> Tensor:
         var1 = torch.square(stddev1)  # (F, K1, C)
```
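Finally, `TorchGaussianProductLogPartition` computes the log-normalizer $Z$ of the (unnormalized) product, $\mathcal{N}(x;\mu_1,\sigma_1^2)\,\mathcal{N}(x;\mu_2,\sigma_2^2) = Z \cdot \mathcal{N}(x;\mu_{12},\sigma_{12}^2)$ with

$$
\log Z \;=\; -\frac{1}{2}\left(\log 2\pi + \log\!\left(\sigma_1^2 + \sigma_2^2\right) + \frac{(\mu_1 - \mu_2)^2}{\sigma_1^2 + \sigma_2^2}\right),
$$

which accounts for the cached `self._log_two_pi` term.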
18 changes: 14 additions & 4 deletions cirkit/symbolic/circuit.py
```diff
@@ -299,15 +299,25 @@ def __init__(
         self.num_channels = num_channels
         self.operation = operation
 
-        # Build scopes bottom-up
+        # Build scopes bottom-up, and check the consistency of the layers, w.r.t.
+        # the arity and the number of input and output units
         self._scopes: dict[Layer, Scope] = {}
         for sl in self.topological_ordering():
             if isinstance(sl, InputLayer):
                 self._scopes[sl] = sl.scope
                 continue
-            self._scopes[sl] = Scope.union(
-                *tuple(self._scopes[sli] for sli in self.layer_inputs(sl))
-            )
+            sl_ins = self.layer_inputs(sl)
+            self._scopes[sl] = Scope.union(*tuple(self._scopes[sli] for sli in sl_ins))
+            if sl.arity != len(sl_ins):
+                raise ValueError(
+                    f"{sl}: expected arity {sl.arity}, but found {len(sl_ins)} input layers"
+                )
+            sl_ins_units = [sli.num_output_units for sli in sl_ins]
+            if any(sl.num_input_units != num_units for num_units in sl_ins_units):
+                raise ValueError(
+                    f"{sl}: expected number of input units {sl.num_input_units}, "
+                    f"but found input layers {sl_ins}"
+                )
         self.scope = Scope.union(*tuple(self._scopes[sl] for sl in self.outputs))
 
     @property
```
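A minimal standalone sketch of what the two new checks catch (toy classes below, not cirkit's actual `Layer` API):

```python
from dataclasses import dataclass


# Toy stand-in for a symbolic layer; the field names mirror the attributes
# the checks read, but this is not cirkit's actual Layer class.
@dataclass
class ToyLayer:
    name: str
    arity: int
    num_input_units: int
    num_output_units: int

    def __repr__(self) -> str:
        return self.name


def check_layer(sl: ToyLayer, sl_ins: list[ToyLayer]) -> None:
    # Check 1: the declared arity must match the number of wired inputs.
    if sl.arity != len(sl_ins):
        raise ValueError(f"{sl}: expected arity {sl.arity}, but found {len(sl_ins)} input layers")
    # Check 2: every input layer must produce as many units as this layer consumes.
    if any(sl.num_input_units != sli.num_output_units for sli in sl_ins):
        raise ValueError(f"{sl}: expected number of input units {sl.num_input_units}, "
                         f"but found input layers {sl_ins}")


# A product layer declared over 8-unit inputs, wired to two 4-unit inputs:
prod = ToyLayer("prod", arity=2, num_input_units=8, num_output_units=8)
ins = [ToyLayer("in1", arity=1, num_input_units=1, num_output_units=4),
       ToyLayer("in2", arity=1, num_input_units=1, num_output_units=4)]
try:
    check_layer(prod, ins)
except ValueError as err:
    print(err)  # prod: expected number of input units 8, but found input layers [in1, in2]
```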
7 changes: 4 additions & 3 deletions cirkit/symbolic/functional.py
```diff
@@ -107,6 +107,7 @@ def evidence(
     Raises:
         ValueError: If the observation contains variables not defined in the scope of the circuit.
         NotImplementedError: If the evidence of a multivariate input layer needs to be constructed.
+    """
     if not all(
         (isinstance(value, Number) or len(value) == 1)
@@ -221,7 +222,7 @@ def integrate(
     scope = sc.scope
     if not scope:
         raise ValueError("There are no variables to integrate over")
-    elif not scope <= sc.scope:
+    if not scope <= sc.scope:
         raise ValueError(
             "The variables scope to integrate must be a subset of the scope of the circuit"
         )
@@ -328,7 +329,7 @@ def multiply(sc1: Circuit, sc2: Circuit, *, registry: OperatorRegistry | None =
 
     # Check whether we are multiplying layers over disjoint scope
     # If that is the case, then we just need to introduce a Kronecker product layer
-    if len(sc1.layer_scope(l1) & sc2.layer_scope(l2)) == 0:
+    if not sc1.layer_scope(l1) & sc2.layer_scope(l2):
         if l1.num_output_units != l2.num_output_units:
             raise NotImplementedError(
                 f"Layers over disjoint scopes can be multiplied if they have the same size, "
@@ -608,7 +609,7 @@ def differentiate(
         operation=CircuitOperation(
             operator=CircuitOperator.DIFFERENTIATION,
             operands=(sc,),
-            metadata=dict(order=order),
+            metadata={"order": order},
         ),
     )
 
```
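The `functional.py` edits are behavior-preserving cleanups: drop the redundant `elif` after a `raise`, test scope emptiness by truthiness, and prefer a dict literal. A quick plain-Python check of the equivalences relied on (this assumes cirkit's `Scope` is set-like under `&` and truthiness, as the surrounding code suggests):

```python
order = 2
assert dict(order=order) == {"order": order}     # dict() call vs. dict literal
s1, s2 = {1, 2}, {3}
assert (len(s1 & s2) == 0) == (not (s1 & s2))    # an empty intersection is falsy
```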
6 changes: 3 additions & 3 deletions cirkit/symbolic/initializers.py
```diff
@@ -174,19 +174,19 @@ def __init__(self, initializer: Initializer, fill_value: float = 0.0):
         self._fill_value = fill_value
 
     @property
-    def initializer() -> Initializer:
+    def initializer(self) -> Initializer:
         return self._initializer
 
     @property
-    def fill_value() -> float:
+    def fill_value(self) -> float:
         return self._fill_value
 
     @property
     def config(self) -> dict[str, Any]:
         return {"initializer": self._initializer, "fill_value": self._fill_value}
 
     def allows_shape(self, shape: tuple[int, ...]) -> bool:
-        if len(shape) != 2 or shape[1] % shape[0] != 0:
+        if len(shape) != 2 or shape[1] % shape[0]:
             return False
         mixing_weights_shape = (shape[0], shape[1] // shape[0])
         return self._initializer.allows_shape(mixing_weights_shape)
```
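Here the `self` fixes are the substantive change: as zero-argument properties, accessing `obj.initializer` would previously have raised a `TypeError`, since the descriptor protocol passes the instance as the first argument. The modulo change is purely stylistic, because a nonzero remainder is already truthy — a quick check:

```python
for shape in [(3, 12), (3, 10)]:
    # bool(shape[1] % shape[0]) agrees with (shape[1] % shape[0] != 0) everywhere.
    assert bool(shape[1] % shape[0]) == (shape[1] % shape[0] != 0)
# (3, 12): divisible, so allows_shape goes on to test the (3, 4) mixing shape.
# (3, 10): remainder 1 is truthy, so allows_shape returns False early.
```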
25 changes: 25 additions & 0 deletions cirkit/symbolic/layers.py
```diff
@@ -101,6 +101,18 @@ def copyref(self) -> "Layer":
         ref_params = {pname: pgraph.ref() for pname, pgraph in self.params.items()}
         return type(self)(**self.config, **ref_params)
 
+    def __repr__(self) -> str:
+        config_repr = ", ".join(f"{k}={v}" for k, v in self.config.items())
+        params_repr = ", ".join(f"{k}={v}" for k, v in self.params.items())
+        return (
+            f"{self.__class__.__name__}("
+            f"num_input_units={self.num_input_units}, "
+            f"num_output_units={self.num_output_units}, "
+            f"arity={self.arity}, "
+            f"config=({config_repr}), "
+            f"params=({params_repr}))"
+        )
+
 
 class InputLayer(Layer, ABC):
     """The symbolic input layer class."""
```
Expand Down Expand Up @@ -141,6 +153,19 @@ def num_channels(self) -> int:
"""
return self.arity

def __repr__(self) -> str:
config_repr = ", ".join(f"{k}={v}" for k, v in self.config.items())
params_repr = ", ".join(f"{k}={v}" for k, v in self.params.items())
return (
f"{self.__class__.__name__}("
f"scope={self.scope}, "
f"num_channels={self.arity}, "
f"num_output_units={self.num_output_units}, "
f"config=({config_repr})"
f"params=({params_repr})"
")"
)


class ConstantLayer(InputLayer, ABC):
"""The symbolic layer computing a constant vector, i.e., it does not depend on any variable."""
Expand Down
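These `__repr__` overrides are what make the new `ValueError`s raised in `Circuit.__init__` readable, since the offending layer is interpolated straight into the message. A rough sketch of the resulting format (the class name and field values below are illustrative, not from a real run):

```python
config = {"num_input_units": 8, "num_output_units": 3}
params = {"weight": "Parameter(shape=(3, 8))"}
config_repr = ", ".join(f"{k}={v}" for k, v in config.items())
params_repr = ", ".join(f"{k}={v}" for k, v in params.items())
print(f"SumLayer(num_input_units=8, num_output_units=3, arity=1, "
      f"config=({config_repr}), params=({params_repr}))")
# SumLayer(num_input_units=8, num_output_units=3, arity=1,
#     config=(num_input_units=8, num_output_units=3), params=(weight=Parameter(shape=(3, 8))))
```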
13 changes: 7 additions & 6 deletions cirkit/symbolic/operators.py
```diff
@@ -161,24 +161,25 @@ def multiply_gaussian_layers(sl1: GaussianLayer, sl2: GaussianLayer) -> CircuitB
             f"but found '{sl1.num_channels}' and '{sl2.num_channels}'"
         )
 
-    gaussian1_shape, gaussian2_shape = sl1.mean.shape, sl2.mean.shape
     mean = Parameter.from_nary(
-        GaussianProductMean(gaussian1_shape, gaussian2_shape),
+        GaussianProductMean(sl1.mean.shape, sl1.stddev.shape, sl2.mean.shape, sl2.stddev.shape),
         sl1.mean.ref(),
-        sl2.mean.ref(),
         sl1.stddev.ref(),
+        sl2.mean.ref(),
         sl2.stddev.ref(),
     )
     stddev = Parameter.from_binary(
-        GaussianProductStddev(gaussian1_shape, gaussian2_shape),
+        GaussianProductStddev(sl1.stddev.shape, sl2.stddev.shape),
         sl1.stddev.ref(),
         sl2.stddev.ref(),
     )
     log_partition = Parameter.from_nary(
-        GaussianProductLogPartition(gaussian1_shape, gaussian2_shape),
+        GaussianProductLogPartition(
+            sl1.mean.shape, sl1.stddev.shape, sl2.mean.shape, sl2.stddev.shape
+        ),
         sl1.mean.ref(),
-        sl2.mean.ref(),
         sl1.stddev.ref(),
+        sl2.mean.ref(),
         sl2.stddev.ref(),
     )
 
```
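The operand reordering here matches the new `forward` signatures (`mean1, stddev1, mean2, stddev2`). As an end-to-end sanity check, here is a standalone numerical verification of the three product identities used above (plain torch, no cirkit classes; the test values are arbitrary):

```python
import math

import torch

# Two univariate Gaussian components (arbitrary test values).
mean1, stddev1 = torch.tensor(0.5), torch.tensor(1.2)
mean2, stddev2 = torch.tensor(-1.0), torch.tensor(0.8)

var1, var2 = stddev1**2, stddev2**2
inv = torch.reciprocal(var1 + var2)
mean12 = (var2 * mean1 + var1 * mean2) * inv  # product mean
stddev12 = torch.sqrt(var1 * var2 * inv)      # product stddev
log_z = -0.5 * (torch.log(2 * math.pi * (var1 + var2)) + (mean1 - mean2) ** 2 * inv)

def log_pdf(x, mean, stddev):
    return torch.distributions.Normal(mean, stddev).log_prob(x)

x = torch.tensor(0.3)
lhs = log_pdf(x, mean1, stddev1) + log_pdf(x, mean2, stddev2)
rhs = log_z + log_pdf(x, mean12, stddev12)
assert torch.allclose(lhs, rhs)  # the product-of-Gaussians identities hold pointwise
```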