-
Notifications
You must be signed in to change notification settings - Fork 1
Circuit extraction from GraphState #46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: mf
Are you sure you want to change the base?
Changes from all commits
abb9fa4
5ef6d90
ecd81a0
d15f954
912cc8a
eb6f70e
260577a
0777029
735b32d
85a79bb
d8d07e6
31fc41d
6d196b5
7e7a73d
fdf0aa4
9a7c3a7
60593b5
d8520fe
eab5787
0c25e60
39cb744
6cb9263
0e7357a
c638d4f
1ec34c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| """Basic example of simplifying a measurement pattern via a ZX-diagram simplification. | ||
|
|
||
| By using the `prune_non_cliffords` method, | ||
| we can remove certain Clifford nodes and non-Clifford nodes from the ZX-diagram, | ||
| which can simplify the resulting measurement pattern. | ||
| """ | ||
|
|
||
| # %% | ||
| from copy import deepcopy | ||
|
|
||
| import numpy as np | ||
|
|
||
| from graphix_zx.circuit import circuit2graph | ||
| from graphix_zx.random_objects import get_random_gflow_circ | ||
| from graphix_zx.visualizer import visualize | ||
| from graphix_zx.zxgraphstate import ZXGraphState | ||
|
|
||
| # %% | ||
| circ = get_random_gflow_circ(4, 4, angle_list=[0, np.pi / 3, 2 * np.pi / 3, np.pi]) | ||
| graph, flow = circuit2graph(circ) | ||
| zx_graph = ZXGraphState() | ||
| zx_graph.append(graph) | ||
|
|
||
| visualize(zx_graph) | ||
| print("node | plane | angle (/pi)") | ||
| for node in zx_graph.input_nodes: | ||
| print(f"{node} (input)", zx_graph.meas_bases[node].plane, zx_graph.meas_bases[node].angle / np.pi) | ||
| for node in zx_graph.physical_nodes - zx_graph.input_nodes - zx_graph.output_nodes: | ||
| print(node, zx_graph.meas_bases[node].plane, zx_graph.meas_bases[node].angle / np.pi) | ||
|
|
||
| # %% | ||
| zx_graph_smp = deepcopy(zx_graph) | ||
| zx_graph_smp.prune_non_cliffords() | ||
|
|
||
| visualize(zx_graph_smp) | ||
| print("node | plane | angle (/pi)") | ||
| for node in zx_graph.input_nodes: | ||
| print(f"{node} (input)", zx_graph.meas_bases[node].plane, zx_graph.meas_bases[node].angle / np.pi) | ||
| for node in zx_graph_smp.physical_nodes - zx_graph.input_nodes - zx_graph_smp.output_nodes: | ||
| print(node, zx_graph_smp.meas_bases[node].plane, zx_graph_smp.meas_bases[node].angle / np.pi) | ||
|
|
||
| # %% |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| from typing import Set, List, Tuple | ||
|
|
||
| import numpy as np | ||
|
|
||
| from graphix_zx.zxgraphstate import ZXGraphState | ||
| from graphix_zx.circuit import MBQCCircuit, circuit2graph | ||
| from graphix_zx.gates import CZ, CNOT, H, PhaseGadget | ||
|
|
||
|
|
||
| def extract_circuit_graph_state(d: ZXGraphState) -> ZXGraphState: | ||
| """ | ||
| Extract a circuit from a ZXGraphState in MBQC+LC form with gflow and return | ||
| the corresponding circuit-like graph state (as ZXGraphState). | ||
| Implements Algorithm 2 (Circuit Extraction) from "There and back again" (Appendix D). | ||
| """ | ||
| # Step 0: Bring into phase-gadget form | ||
| d.convert_to_phase_gadget() | ||
|
|
||
| # Initialize circuit and frontier | ||
| inputs = sorted(d.input_nodes) | ||
| n_qubits = len(inputs) | ||
| circ = MBQCCircuit(n_qubits) | ||
| frontier: Set[int] = set(d.output_nodes) | ||
|
|
||
| # Process initial frontier | ||
| _process_frontier(d, frontier, circ) | ||
|
|
||
| # Main extraction loop | ||
| while True: | ||
| # Remaining unextracted vertices | ||
| remaining = set(d.physical_nodes) - frontier | ||
| if not remaining: | ||
| break | ||
| _update_frontier(d, frontier, circ) | ||
|
|
||
| # Final SWAP/H corrections if needed | ||
| _finalize_extraction(d, frontier, circ) | ||
| _revert_circuit(d, circ) | ||
|
|
||
| # Convert MBQCCircuit back to ZXGraphState | ||
| # graph, _ = circuit2graph(circ) | ||
| # return graph | ||
|
|
||
| return circ | ||
|
|
||
|
|
||
| def _process_frontier(d: ZXGraphState, frontier: Set[int], circ: MBQCCircuit) -> None: | ||
| """ | ||
| Process the frontier: extract local Cliffords and CZ between frontier vertices. | ||
| """ | ||
| lc = d.local_clifford | ||
| for v in sorted(frontier): | ||
| # Extract any local Clifford on v | ||
| if v in lc.keys(): | ||
| # to be implemented: add local Clifford gates | ||
| pass | ||
| # Extract any CZ edges between frontier vertices | ||
| for w in list(d.get_neighbors(v) & frontier): | ||
| circ.cz(v, w) | ||
| d.remove_physical_edge(v, w) | ||
|
|
||
|
|
||
| def _update_frontier(d: ZXGraphState, frontier: Set[int], circ: MBQCCircuit) -> None: | ||
| """ | ||
| Update the frontier by Gaussian elimination or pivots, then extract new frontier vertices. | ||
| """ | ||
| # Build bipartite adjacency: frontier vs neighbors | ||
| neigh = sorted(set().union(*(d.get_neighbors(v) for v in frontier))) | ||
| M = np.zeros((len(frontier), len(neigh)), dtype=int) | ||
| for i, v in enumerate(sorted(frontier)): | ||
| for j, u in enumerate(neigh): | ||
| if u in d.get_neighbors(v): | ||
| M[i, j] = 1 | ||
| # Gaussian eliminate over GF(2) | ||
| M_red, row_ops = _gauss_elim(M) | ||
| # Identify rows with single 1 | ||
| vs: List[int] = [] | ||
| for i, row in enumerate(M_red): | ||
| if row.sum() == 1: | ||
| col = int(np.nonzero(row)[0][0]) | ||
| vs.append(neigh[col]) | ||
|
|
||
| if not vs: | ||
| # Step 4: pivot YZ vertices adjacent to frontier | ||
| # to be implemented | ||
| pass | ||
| # for u in list(d.physical_nodes - frontier): | ||
| # if d.meas_bases[u].plane.name == 'YZ' and d.get_neighbors(u) & frontier: | ||
| # w = next(iter(d.get_neighbors(u) & frontier)) | ||
| # d.pivot(u, w) | ||
| # _process_frontier(d, frontier, circ) | ||
| # return | ||
|
|
||
| # Apply recorded CNOT row operations | ||
| for r1, r2 in row_ops: | ||
| circ.cnot(sorted(frontier)[r1], sorted(frontier)[r2]) # CNOT is not implemented in MBQCCircuit | ||
| # Update graph accordingly: add edge or local complement as needed | ||
| d.apply_cnot(sorted(frontier)[r1], sorted(frontier)[r2]) | ||
|
|
||
| # Extract new frontier vertices | ||
| for v in vs: | ||
| # unique neighbor in frontier | ||
| w = next(iter(d.get_neighbors(v) & frontier)) | ||
| circ.add_gate(H(), [w]) | ||
| circ.add_gate(PhaseGadget(d.meas_bases[v].angle), [w]) | ||
| frontier.remove(w) | ||
| frontier.add(v) | ||
| _process_frontier(d, frontier, circ) | ||
|
|
||
|
|
||
| def _gauss_elim(M: np.ndarray) -> Tuple[np.ndarray, List[Tuple[int,int]]]: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please respect PEP8. |
||
| """ | ||
| Perform Gaussian elimination over GF(2), returning reduced matrix and list of row ops. | ||
| """ | ||
| M = M.copy() % 2 | ||
| n, m = M.shape | ||
| row_ops: List[Tuple[int,int]] = [] | ||
| pivot_row = 0 | ||
| for col in range(m): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Too complex. |
||
| # find pivot | ||
| for r in range(pivot_row, n): | ||
| if M[r, col] == 1: | ||
| M[[pivot_row, r]] = M[[r, pivot_row]] | ||
| if r != pivot_row: | ||
| row_ops.append((pivot_row, r)) | ||
| break | ||
| else: | ||
| continue | ||
| # eliminate other rows | ||
| for r in range(n): | ||
| if r != pivot_row and M[r, col] == 1: | ||
| M[r] ^= M[pivot_row] | ||
| row_ops.append((r, pivot_row)) | ||
| pivot_row += 1 | ||
| if pivot_row == n: | ||
| break | ||
| return M, row_ops | ||
|
|
||
|
|
||
| def _finalize_extraction(d: ZXGraphState, frontier: Set[int], circ: MBQCCircuit) -> None: | ||
| """ | ||
| Extract final Hadamards or SWAPs to align frontier to inputs. | ||
| """ | ||
| # to be implemented | ||
|
|
||
| # # Handle any remaining Hadamard on outputs | ||
| # for v in sorted(frontier): | ||
| # if d.has_hadamard_on_output(v): | ||
| # circ.add_gate(H(), [v]) | ||
| # # Permute frontier to match inputs via SWAPs | ||
| # perm = d.compute_permutation(list(frontier), list(d.input_nodes)) | ||
| # for (q1, q2) in perm: | ||
| # circ.add_gate(CNOT(), [q1, q2]) # SWAP as three CNOTs omitted for brevity | ||
|
|
||
| def _revert_circuit(d: ZXGraphState, circ: MBQCCircuit) -> None: | ||
| """ | ||
| Revert the circuit. | ||
| """ | ||
| # to be implemented | ||
| pass | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -400,6 +400,28 @@ def local_cliffords(self) -> dict[int, LocalClifford]: | |
| """ | ||
| return self.__local_cliffords | ||
|
|
||
| @property | ||
| def inner2nodes(self) -> dict[int, int]: | ||
| """Return inner index to node index mapping. | ||
|
|
||
| Returns | ||
| ------- | ||
| dict[int, int] | ||
| inner index to node index mapping. | ||
| """ | ||
| return self.__inner2nodes | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it OK to expose internal states? |
||
|
|
||
| @property | ||
| def nodes2inner(self) -> dict[int, int]: | ||
| """Return node index to inner index mapping. | ||
|
|
||
| Returns | ||
| ------- | ||
| dict[int, int] | ||
| node index to inner index mapping. | ||
| """ | ||
| return self.__nodes2inner | ||
|
|
||
| def check_meas_basis(self) -> None: | ||
| """Check if the measurement basis is set for all physical nodes except output nodes. | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
|
|
||
| import numpy as np | ||
|
|
||
| from graphix_zx.circuit import MBQCCircuit | ||
| from graphix_zx.common import default_meas_basis | ||
| from graphix_zx.graphstate import GraphState | ||
|
|
||
|
|
@@ -82,3 +83,50 @@ def get_random_flow_graph( | |
| num_nodes += 1 | ||
|
|
||
| return graph, flow | ||
|
|
||
|
|
||
| def get_random_gflow_circ( | ||
| width: int, | ||
| depth: int, | ||
| rng: np.random.Generator | None = None, | ||
| edge_p: float = 0.5, | ||
| angle_list: list | None = None, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use an appropriate
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| ) -> MBQCCircuit: | ||
| """Generate a random MBQC circuit which has gflow. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| width : int | ||
| circuit width | ||
| depth : int | ||
| circuit depth | ||
| rng : np.random.Generator, optional | ||
| random number generator, by default np.random.default_rng() | ||
| edge_p : float, optional | ||
| probability of adding CZ gate, by default 0.5 | ||
| angle_list : list, optional | ||
| list of angles, by default [0, np.pi / 3, 2 * np.pi / 3, np.pi] | ||
|
|
||
| Returns | ||
| ------- | ||
| MBQCCircuit | ||
| generated MBQC circuit | ||
| """ | ||
| if rng is None: | ||
| rng = np.random.default_rng() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. MEMO: it's better to cache |
||
| if angle_list is None: | ||
| angle_list = [0, np.pi / 3, 2 * np.pi / 3, np.pi] | ||
| circ = MBQCCircuit(width) | ||
| for d in range(depth): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A bit complex. |
||
| for j in range(width): | ||
| circ.j(j, rng.choice(angle_list)) | ||
| if d < depth - 1: | ||
| for j in range(width): | ||
| if rng.random() < edge_p: | ||
| circ.cz(j, (j + 1) % width) | ||
| num = rng.integers(0, width) | ||
| if num > 0: | ||
| target = set(rng.choice(list(range(width)), num)) | ||
| circ.phase_gadget(target, rng.choice(angle_list)) | ||
|
|
||
| return circ | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Deprecated.