Skip to content

Commit

Permalink
feat: optimize performance for simulating control modifiers (#153)
Browse files Browse the repository at this point in the history
Co-authored-by: Cody Wang <speller26@gmail.com>
  • Loading branch information
ajberdy and speller26 authored Apr 18, 2023
1 parent da005c9 commit 046e343
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 32 deletions.
70 changes: 46 additions & 24 deletions src/braket/default_simulator/linalg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,53 @@

import numpy as np

_SLICES = (
_NEG_CONTROL_SLICE := slice(None, 1),
_CONTROL_SLICE := slice(1, None),
_NO_CONTROL_SLICE := slice(None, None),
)


def multiply_matrix(
state: np.ndarray,
matrix: np.ndarray,
targets: Tuple[int, ...],
controls: Optional[Tuple[int]] = (),
control_state: Optional[Tuple[int]] = (),
) -> np.ndarray:
"""Multiplies the given matrix by the given state, applying the matrix on the target qubits,
controlling the operation as specified.
Args:
state (np.ndarray): The state to multiply the matrix by.
matrix (np.ndarray): The matrix to apply to the state.
targets (Tuple[int]): The qubits to apply the state on.
controls (Optional[Tuple[int]]): The qubits to control the operation on. Default ().
control_state (Optional[Tuple[int]]): A tuple of same length as `controls` with either
a 0 or 1 in each index, corresponding to whether to control on the |0⟩ or |1⟩ state.
Default (1,) * len(controls).
Returns:
np.ndarray: The state after the matrix has been applied.
"""
if not controls:
return _multiply_matrix(state, matrix, targets)

control_state = control_state or (1,) * len(controls)
num_qubits = len(state.shape)
control_slices = {i: _SLICES[state] for i, state in zip(controls, control_state)}
ctrl_index = tuple(
control_slices[i] if i in controls else _NO_CONTROL_SLICE for i in range(num_qubits)
)
state[ctrl_index] = _multiply_matrix(state[ctrl_index], matrix, targets)
return state

def multiply_matrix(state: np.ndarray, matrix: np.ndarray, targets: Tuple[int, ...]) -> np.ndarray:

def _multiply_matrix(
state: np.ndarray,
matrix: np.ndarray,
targets: Tuple[int, ...],
) -> np.ndarray:
"""Multiplies the given matrix by the given state, applying the matrix on the target qubits.
Args:
Expand Down Expand Up @@ -119,26 +164,3 @@ def _get_target_permutation(targets: Sequence[int]) -> Sequence[int]:
return np.ravel_multi_index(
basis_states[:, np.argsort(np.argsort(targets))].T, [2] * len(targets)
)


def controlled_unitary(unitary: np.ndarray, negctrl: bool = False) -> np.ndarray:
"""
Transform unitary matrix into a controlled unitary matrix.
Args:
unitary (np.ndarray): Unitary matrix operation.
negctrl (bool): Whether to control the operation on the |0⟩ state,
instead of the |1⟩ state. Default: False.
Returns:
np.ndarray: A controlled version of the provided unitary matrix.
"""
upper_left, bottom_right = np.eye(unitary.shape[0]), unitary
if negctrl:
upper_left, bottom_right = bottom_right, upper_left
return np.block(
[
[upper_left, np.zeros_like(unitary)],
[np.zeros_like(unitary), bottom_right],
]
)
6 changes: 1 addition & 5 deletions src/braket/default_simulator/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import numpy as np
from scipy.linalg import fractional_matrix_power

from braket.default_simulator.linalg_utils import controlled_unitary


class Operation(ABC):
"""
Expand Down Expand Up @@ -50,6 +48,7 @@ def __init__(self, targets, *params, ctrl_modifiers=(), power=1):
self._ctrl_modifiers = ctrl_modifiers
self._power = power

@property
@abstractmethod
def _base_matrix(self) -> np.ndarray:
"""np.ndarray: The matrix representation of the operation."""
Expand All @@ -61,9 +60,6 @@ def matrix(self) -> np.ndarray:
unitary = np.linalg.matrix_power(unitary, int(self._power))
else:
unitary = fractional_matrix_power(unitary, self._power)

for mod in self._ctrl_modifiers:
unitary = controlled_unitary(unitary, negctrl=mod)
return unitary

def __eq__(self, other):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def apply_operations(
"""
for operation in operations:
matrix = operation.matrix
targets = operation.targets
state = multiply_matrix(state, matrix, targets)
all_targets = operation.targets
num_ctrl = len(operation._ctrl_modifiers)
control_state = tuple(np.logical_not(operation._ctrl_modifiers).astype(int))
controls = all_targets[:num_ctrl]
targets = all_targets[num_ctrl:]
state = multiply_matrix(state, matrix, targets, controls, control_state)
return state
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,6 @@ def test_gphase():
circuit = Interpreter().build_circuit(qasm)
simulation = StateVectorSimulation(2, 1, 1)
simulation.evolve(circuit.instructions)
print(simulation.state_vector)
assert np.allclose(simulation.state_vector, [-1 / np.sqrt(2), 0, 0, 1 / np.sqrt(2)])


Expand Down

0 comments on commit 046e343

Please sign in to comment.