Skip to content

Commit

Permalink
feature: support non-contiguous qubit indices local simulator (#262)
Browse files Browse the repository at this point in the history
* feature: support non-contiguous qubit indices local simulator

* Fix lint

* Convert to classmethod

* Fix black

* Map IR inplace

* Remove unused import

* Remove qubit set sorted assumption

* Remove test

* Support Jaqcd discontiguous qubits

* Add discontiguous qubits test

* Add discontiguous qubits test

* Refactor to fixture

* Fix discontiguous qubits targets mapping

* Fix lint

* fix lint

* Remove check for discontiguous qubit

* Remove discontiguous qubits check

* Make private

* Remove redudant basis rotation prop check

* Add OpenQASM discontiguous qubits test

* Refactor to static

* Make code understable

* Improve coverage

* Fix Jaqcd result mapping

* Handle Jaqcd result instructions target

* Handle invalid qubit

* Add test for targets

* Remove unnecessary mapping

* Apply suggestions from code review

---------

Co-authored-by: Cody Wang <speller26@gmail.com>
  • Loading branch information
WingCode and speller26 authored Jun 24, 2024
1 parent 57b2c54 commit 8f3463a
Show file tree
Hide file tree
Showing 7 changed files with 259 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/braket/default_simulator/density_matrix_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def properties(self) -> GateModelSimulatorDeviceCapabilities:
],
"supportPhysicalQubits": False,
"supportsPartialVerbatimBox": False,
"requiresContiguousQubitIndices": True,
"requiresContiguousQubitIndices": False,
"requiresAllQubitsMeasurement": False,
"supportsUnassignedMeasurements": True,
"disabledQubitRewiringSupported": False,
Expand Down
164 changes: 153 additions & 11 deletions src/braket/default_simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from braket.device_schema import DeviceActionType
from braket.ir.jaqcd import Program as JaqcdProgram
from braket.ir.jaqcd.program_v1 import Results
from braket.ir.jaqcd.shared_models import MultiTarget, OptionalMultiTarget
from braket.ir.openqasm import Program as OpenQASMProgram
from braket.task_result import (
AdditionalMetadata,
Expand Down Expand Up @@ -273,13 +274,8 @@ def _create_results_obj(
)

@staticmethod
def _validate_operation_qubits(operations: list[Operation]) -> None:
qubits_referenced = {target for operation in operations for target in operation.targets}
if qubits_referenced and max(qubits_referenced) >= len(qubits_referenced):
raise ValueError(
"Non-contiguous qubit indices supplied; "
"qubit indices in a circuit must be contiguous."
)
def _get_qubits_referenced(operations: list[Operation]) -> set[int]:
return {target for operation in operations for target in operation.targets}

@staticmethod
def _validate_result_types_qubits_exist(
Expand Down Expand Up @@ -386,6 +382,151 @@ def _observable_hash(observable: Observable) -> Union[str, dict[int, str]]:
else:
return str(observable.__class__.__name__)

@staticmethod
def _map_circuit_to_contiguous_qubits(circuit: Union[Circuit, JaqcdProgram]) -> Circuit:
"""
Maps the qubits in operations and result types to contiguous qubits.
Args:
circuit (Union[Circuit, JaqcdProgram]): The circuit containing the operations and
result types.
Returns:
Circuit: The circuit with qubits in operations and result types mapped
to contiguous qubits.
"""
circuit_qubit_set = BaseLocalSimulator._get_circuit_qubit_set(circuit)
qubit_map = BaseLocalSimulator._contiguous_qubit_mapping(circuit_qubit_set)
BaseLocalSimulator._map_instructions_to_qubits(circuit, qubit_map)
return circuit

@staticmethod
def _get_circuit_qubit_set(circuit: Union[Circuit, JaqcdProgram]) -> set:
"""
Returns the set of qubits used in the given circuit.
Args:
circuit (Union[Circuit, JaqcdProgram]): The circuit from which to extract the qubit set.
Returns:
set: The set of qubits used in the circuit.
"""
if isinstance(circuit, Circuit):
return circuit.qubit_set
else:
operations = [
from_braket_instruction(instruction) for instruction in circuit.instructions
]
if circuit.basis_rotation_instructions:
operations.extend(
from_braket_instruction(instruction)
for instruction in circuit.basis_rotation_instructions
)
return BaseLocalSimulator._get_qubits_referenced(operations)

@staticmethod
def _map_instructions_to_qubits(circuit: Union[Circuit, JaqcdProgram], qubit_map: dict):
"""
Maps the qubits in operations and result types to contiguous qubits.
Args:
circuit (Circuit): The circuit containing the operations and result types.
Returns:
Circuit: The circuit with qubits in operations and result types mapped
to contiguous qubits.
"""
if isinstance(circuit, Circuit):
BaseLocalSimulator._map_circuit_instructions(circuit, qubit_map)
BaseLocalSimulator._map_circuit_results(circuit, qubit_map)
else:
BaseLocalSimulator._map_jaqcd_instructions(circuit, qubit_map)

return circuit

@staticmethod
def _map_circuit_instructions(circuit: Circuit, qubit_map: dict):
"""
Maps the targets of each instruction in the circuit to the corresponding qubits in the
qubit_map.
Args:
circuit (Circuit): The circuit containing the instructions.
qubit_map (dict): A dictionary mapping original qubits to new qubits.
"""
for ins in circuit.instructions:
ins._targets = tuple([qubit_map[q] for q in ins.targets])

@staticmethod
def _map_circuit_results(circuit: Circuit, qubit_map: dict):
"""
Maps the targets of each result in the circuit to the corresponding qubits in the qubit_map.
Args:
circuit (Circuit): The circuit containing the results.
qubit_map (dict): A dictionary mapping original qubits to new qubits.
"""
for result in circuit.results:
if isinstance(result, (MultiTarget, OptionalMultiTarget)) and result.targets:
result.targets = [qubit_map[q] for q in result.targets]

@staticmethod
def _map_jaqcd_instructions(circuit: JaqcdProgram, qubit_map: dict):
"""
Maps the attributes of each instruction in the JaqcdProgram to the corresponding qubits in
the qubit_map.
Args:
circuit (JaqcdProgram): The JaqcdProgram containing the instructions.
qubit_map (dict): A dictionary mapping original qubits to new qubits.
"""
for ins in circuit.instructions:
BaseLocalSimulator._map_instruction_attributes(ins, qubit_map)

if hasattr(circuit, "results") and circuit.results:
for ins in circuit.results:
BaseLocalSimulator._map_instruction_attributes(ins, qubit_map)

if circuit.basis_rotation_instructions:
for ins in circuit.basis_rotation_instructions:
ins.target = qubit_map[ins.target]

@staticmethod
def _map_instruction_attributes(instruction, qubit_map: dict):
"""
Maps the qubit attributes of an instruction from JaqcdProgram to the corresponding
qubits in the qubit_map.
Args:
instruction: The Jaqcd instruction whose qubit attributes need to be mapped.
qubit_map (dict): A dictionary mapping original qubits to new qubits.
"""
if hasattr(instruction, "control"):
instruction.control = qubit_map.get(instruction.control, instruction.control)

if hasattr(instruction, "controls") and instruction.controls:
instruction.controls = [qubit_map.get(q, q) for q in instruction.controls]

if hasattr(instruction, "target"):
instruction.target = qubit_map.get(instruction.target, instruction.target)

if hasattr(instruction, "targets") and instruction.targets:
instruction.targets = [qubit_map.get(q, q) for q in instruction.targets]

@staticmethod
def _contiguous_qubit_mapping(qubit_set: list[int]) -> dict[int, int]:
"""
Maping of qubits to contiguous integers. The qubit mapping may be discontiguous or
contiguous.
Args:
qubit_set (list[int]): List of qubits to be mapped.
Returns:
dict[int, int]: Dictionary where keys are qubits and values are contiguous integers.
"""
return {q: i for i, q in enumerate(sorted(qubit_set))}

@staticmethod
def _formatted_measurements(
simulation: Simulation, measured_qubits: Union[list[int], None] = None
Expand Down Expand Up @@ -418,6 +559,7 @@ def _formatted_measurements(
measurements = np.pad(
selected_measurements, ((0, 0), (0, len(measured_qubits_not_in_circuit)))
).tolist()

else:
measurements = np.zeros(
(simulation.shots, len(measured_qubits)), dtype=int
Expand Down Expand Up @@ -465,14 +607,14 @@ def run_openqasm(
self._validate_input_provided(circuit)
BaseLocalSimulator._validate_shots_and_ir_results(shots, circuit.results, qubit_count)

operations = circuit.instructions
BaseLocalSimulator._validate_operation_qubits(operations)
circuit = BaseLocalSimulator._map_circuit_to_contiguous_qubits(circuit)

results = circuit.results

simulation = self.initialize_simulation(
qubit_count=qubit_count, shots=shots, batch_size=batch_size
)
operations = circuit.instructions
simulation.evolve(operations)

if not shots:
Expand Down Expand Up @@ -533,6 +675,8 @@ def run_jaqcd(
)
BaseLocalSimulator._validate_shots_and_ir_results(shots, circuit_ir.results, qubit_count)

circuit_ir = BaseLocalSimulator._map_circuit_to_contiguous_qubits(circuit_ir)

operations = [
from_braket_instruction(instruction) for instruction in circuit_ir.instructions
]
Expand All @@ -541,8 +685,6 @@ def run_jaqcd(
for instruction in circuit_ir.basis_rotation_instructions:
operations.append(from_braket_instruction(instruction))

BaseLocalSimulator._validate_operation_qubits(operations)

simulation = self.initialize_simulation(
qubit_count=qubit_count, shots=shots, batch_size=batch_size
)
Expand Down
2 changes: 1 addition & 1 deletion src/braket/default_simulator/state_vector_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def properties(self) -> GateModelSimulatorDeviceCapabilities:
],
"supportPhysicalQubits": False,
"supportsPartialVerbatimBox": False,
"requiresContiguousQubitIndices": True,
"requiresContiguousQubitIndices": False,
"requiresAllQubitsMeasurement": False,
"supportsUnassignedMeasurements": True,
"disabledQubitRewiringSupported": False,
Expand Down
6 changes: 6 additions & 0 deletions test/resources/discontiguous.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
OPENQASM 3.0;
bit[2] b;
qubit[5] q;
h q[2];
cnot q[2], q[3];
b = measure q;
9 changes: 9 additions & 0 deletions test/resources/discontiguous_jaqcd.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"braketSchemaHeader": {"name": "braket.ir.jaqcd.program", "version": "1"},
"instructions": [
{"target": 2, "type": "x"},
{"control": 2, "target": 9, "type": "cnot"}
],
"results": [],
"basis_rotation_instructions": []
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ def grcs_8_qubit(ir_type):
return CircuitData(OpenQASMProgram(source="test/resources/grcs_8.qasm"), 0.0007324)


@pytest.fixture
def discontiguous_jaqcd():
with open("test/resources/discontiguous_jaqcd.json") as jaqcd_definition:
data = json.load(jaqcd_definition)
return json.dumps(data)


@pytest.fixture
def discontiguous_qasm():
return OpenQASMProgram(source="test/resources/discontiguous.qasm")


@pytest.fixture
def bell_ir(ir_type):
return (
Expand Down Expand Up @@ -302,7 +314,7 @@ def test_properties():
],
"supportPhysicalQubits": False,
"supportsPartialVerbatimBox": False,
"requiresContiguousQubitIndices": True,
"requiresContiguousQubitIndices": False,
"requiresAllQubitsMeasurement": False,
"supportsUnassignedMeasurements": True,
"disabledQubitRewiringSupported": False,
Expand Down Expand Up @@ -831,3 +843,20 @@ def test_measure_with_qubits_not_used():
assert np.sum(measurements, axis=0)[3] == 0
assert len(measurements[0]) == 4
assert result.measuredQubits == [0, 1, 2, 3]


def test_discontiguous_qubits_jaqcd(discontiguous_jaqcd):
prg = JaqcdProgram.parse_raw(discontiguous_jaqcd)
result = DensityMatrixSimulator().run(prg, qubit_count=2, shots=1)

assert result.measuredQubits == [0, 1]
assert result.measurements == [["1", "1"]]


def test_discontiguous_qubits_openqasm(discontiguous_qasm):
simulator = DensityMatrixSimulator()
result = simulator.run(discontiguous_qasm, shots=1000)

measurements = np.array(result.measurements, dtype=int)
assert len(measurements[0]) == 5
assert result.measuredQubits == [0, 1, 2, 3, 4]
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ def grcs_16_qubit(ir_type):
return CircuitData(OpenQASMProgram(source="test/resources/grcs_16.qasm"), 0.0000062)


@pytest.fixture
def discontiguous_jaqcd():
with open("test/resources/discontiguous_jaqcd.json") as jaqcd_definition:
data = json.load(jaqcd_definition)
return json.dumps(data)


@pytest.fixture
def discontiguous_qasm():
return OpenQASMProgram(source="test/resources/discontiguous.qasm")


@pytest.fixture
def bell_ir(ir_type):
return (
Expand Down Expand Up @@ -240,7 +252,7 @@ def test_properties():
],
"supportPhysicalQubits": False,
"supportsPartialVerbatimBox": False,
"requiresContiguousQubitIndices": True,
"requiresContiguousQubitIndices": False,
"requiresAllQubitsMeasurement": False,
"supportsUnassignedMeasurements": True,
"disabledQubitRewiringSupported": False,
Expand Down Expand Up @@ -900,7 +912,6 @@ def test_simulator_run_result_types_shots_basis_rotation_gates_value_error():
),
],
)
@pytest.mark.xfail(raises=ValueError)
def test_simulator_run_non_contiguous_qubits(ir, qubit_count):
# not relevant for openqasm, since it handles qubit allocation
simulator = StateVectorSimulator()
Expand Down Expand Up @@ -1363,3 +1374,49 @@ def test_rotation_parameter_expressions(operation, state_vector):
result = simulator.run(OpenQASMProgram(source=qasm), shots=0)
assert result.resultTypes[0].type == StateVector()
assert np.allclose(result.resultTypes[0].value, np.array(state_vector))


def test_discontiguous_qubits_jaqcd(discontiguous_jaqcd):
prg = JaqcdProgram.parse_raw(discontiguous_jaqcd)
result = StateVectorSimulator().run(prg, qubit_count=2, shots=1)

assert result.measuredQubits == [0, 1]
assert result.measurements == [["1", "1"]]


def test_discontiguous_qubits_openqasm(discontiguous_qasm):
simulator = StateVectorSimulator()
result = simulator.run(discontiguous_qasm, shots=1000)

measurements = np.array(result.measurements, dtype=int)
assert len(measurements[0]) == 5
assert result.measuredQubits == [0, 1, 2, 3, 4]


def test_discontiguous_qubits_jaqcd_multiple_controls():
jaqcd_program = {
"braketSchemaHeader": {"name": "braket.ir.jaqcd.program", "version": "1"},
"instructions": [
{"type": "x", "target": 3},
{"type": "x", "target": 4},
{"type": "ccnot", "controls": [3, 4], "target": 5},
],
}
prg = JaqcdProgram.parse_raw(json.dumps(jaqcd_program))
result = StateVectorSimulator().run(prg, qubit_count=3, shots=1)

assert result.measuredQubits == [0, 1, 2]
assert result.measurements == [["1", "1", "1"]]


def test_discontiguous_qubits_jaqcd_multiple_targets():
jaqcd_program = {
"braketSchemaHeader": {"name": "braket.ir.jaqcd.program", "version": "1"},
"instructions": [{"type": "x", "target": 3}, {"type": "swap", "targets": [3, 4]}],
"results": [{"type": "expectation", "observable": ["z"], "targets": [4]}],
}
prg = JaqcdProgram.parse_raw(json.dumps(jaqcd_program))
result = StateVectorSimulator().run(prg, qubit_count=2, shots=0)

assert result.measuredQubits == [0, 1]
assert result.resultTypes[0].value == -1

0 comments on commit 8f3463a

Please sign in to comment.