Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions examples/measurement_pattern_simplification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Basic example of simplifying a measurement pattern via a ZX-diagram simplification.

By using the `prune_non_cliffords` method,
we can remove certain Clifford nodes and non-Clifford nodes from the ZX-diagram,
which can simplify the resulting measurement pattern.
"""

# %%
from copy import deepcopy

import numpy as np

from graphix_zx.circuit import circuit2graph
from graphix_zx.random_objects import get_random_gflow_circ
from graphix_zx.visualizer import visualize
from graphix_zx.zxgraphstate import ZXGraphState

# %%
circ = get_random_gflow_circ(4, 4, angle_list=[0, np.pi / 3, 2 * np.pi / 3, np.pi])
graph, flow = circuit2graph(circ)
zx_graph = ZXGraphState()
zx_graph.append(graph)

visualize(zx_graph)
print("node | plane | angle (/pi)")
for node in zx_graph.input_nodes:
print(f"{node} (input)", zx_graph.meas_bases[node].plane, zx_graph.meas_bases[node].angle / np.pi)
for node in zx_graph.physical_nodes - zx_graph.input_nodes - zx_graph.output_nodes:
print(node, zx_graph.meas_bases[node].plane, zx_graph.meas_bases[node].angle / np.pi)

# %%
zx_graph_smp = deepcopy(zx_graph)
zx_graph_smp.prune_non_cliffords()

visualize(zx_graph_smp)
print("node | plane | angle (/pi)")
for node in zx_graph.input_nodes:
print(f"{node} (input)", zx_graph.meas_bases[node].plane, zx_graph.meas_bases[node].angle / np.pi)
for node in zx_graph_smp.physical_nodes - zx_graph.input_nodes - zx_graph_smp.output_nodes:
print(node, zx_graph_smp.meas_bases[node].plane, zx_graph_smp.meas_bases[node].angle / np.pi)

# %%
161 changes: 161 additions & 0 deletions graphix_zx/circuit_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from typing import Set, List, Tuple
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deprecated.


import numpy as np

from graphix_zx.zxgraphstate import ZXGraphState
from graphix_zx.circuit import MBQCCircuit, circuit2graph
from graphix_zx.gates import CZ, CNOT, H, PhaseGadget


def extract_circuit_graph_state(d: ZXGraphState) -> ZXGraphState:
"""
Extract a circuit from a ZXGraphState in MBQC+LC form with gflow and return
the corresponding circuit-like graph state (as ZXGraphState).
Implements Algorithm 2 (Circuit Extraction) from "There and back again" (Appendix D).
"""
# Step 0: Bring into phase-gadget form
d.convert_to_phase_gadget()

# Initialize circuit and frontier
inputs = sorted(d.input_nodes)
n_qubits = len(inputs)
circ = MBQCCircuit(n_qubits)
frontier: Set[int] = set(d.output_nodes)

# Process initial frontier
_process_frontier(d, frontier, circ)

# Main extraction loop
while True:
# Remaining unextracted vertices
remaining = set(d.physical_nodes) - frontier
if not remaining:
break
_update_frontier(d, frontier, circ)

# Final SWAP/H corrections if needed
_finalize_extraction(d, frontier, circ)
_revert_circuit(d, circ)

# Convert MBQCCircuit back to ZXGraphState
# graph, _ = circuit2graph(circ)
# return graph

return circ


def _process_frontier(d: ZXGraphState, frontier: Set[int], circ: MBQCCircuit) -> None:
"""
Process the frontier: extract local Cliffords and CZ between frontier vertices.
"""
lc = d.local_clifford
for v in sorted(frontier):
# Extract any local Clifford on v
if v in lc.keys():
# to be implemented: add local Clifford gates
pass
# Extract any CZ edges between frontier vertices
for w in list(d.get_neighbors(v) & frontier):
circ.cz(v, w)
d.remove_physical_edge(v, w)


def _update_frontier(d: ZXGraphState, frontier: Set[int], circ: MBQCCircuit) -> None:
"""
Update the frontier by Gaussian elimination or pivots, then extract new frontier vertices.
"""
# Build bipartite adjacency: frontier vs neighbors
neigh = sorted(set().union(*(d.get_neighbors(v) for v in frontier)))
M = np.zeros((len(frontier), len(neigh)), dtype=int)
for i, v in enumerate(sorted(frontier)):
for j, u in enumerate(neigh):
if u in d.get_neighbors(v):
M[i, j] = 1
# Gaussian eliminate over GF(2)
M_red, row_ops = _gauss_elim(M)
# Identify rows with single 1
vs: List[int] = []
for i, row in enumerate(M_red):
if row.sum() == 1:
col = int(np.nonzero(row)[0][0])
vs.append(neigh[col])

if not vs:
# Step 4: pivot YZ vertices adjacent to frontier
# to be implemented
pass
# for u in list(d.physical_nodes - frontier):
# if d.meas_bases[u].plane.name == 'YZ' and d.get_neighbors(u) & frontier:
# w = next(iter(d.get_neighbors(u) & frontier))
# d.pivot(u, w)
# _process_frontier(d, frontier, circ)
# return

# Apply recorded CNOT row operations
for r1, r2 in row_ops:
circ.cnot(sorted(frontier)[r1], sorted(frontier)[r2]) # CNOT is not implemented in MBQCCircuit
# Update graph accordingly: add edge or local complement as needed
d.apply_cnot(sorted(frontier)[r1], sorted(frontier)[r2])

# Extract new frontier vertices
for v in vs:
# unique neighbor in frontier
w = next(iter(d.get_neighbors(v) & frontier))
circ.add_gate(H(), [w])
circ.add_gate(PhaseGadget(d.meas_bases[v].angle), [w])
frontier.remove(w)
frontier.add(v)
_process_frontier(d, frontier, circ)


def _gauss_elim(M: np.ndarray) -> Tuple[np.ndarray, List[Tuple[int,int]]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please respect PEP8.

"""
Perform Gaussian elimination over GF(2), returning reduced matrix and list of row ops.
"""
M = M.copy() % 2
n, m = M.shape
row_ops: List[Tuple[int,int]] = []
pivot_row = 0
for col in range(m):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too complex.

# find pivot
for r in range(pivot_row, n):
if M[r, col] == 1:
M[[pivot_row, r]] = M[[r, pivot_row]]
if r != pivot_row:
row_ops.append((pivot_row, r))
break
else:
continue
# eliminate other rows
for r in range(n):
if r != pivot_row and M[r, col] == 1:
M[r] ^= M[pivot_row]
row_ops.append((r, pivot_row))
pivot_row += 1
if pivot_row == n:
break
return M, row_ops


def _finalize_extraction(d: ZXGraphState, frontier: Set[int], circ: MBQCCircuit) -> None:
"""
Extract final Hadamards or SWAPs to align frontier to inputs.
"""
# to be implemented

# # Handle any remaining Hadamard on outputs
# for v in sorted(frontier):
# if d.has_hadamard_on_output(v):
# circ.add_gate(H(), [v])
# # Permute frontier to match inputs via SWAPs
# perm = d.compute_permutation(list(frontier), list(d.input_nodes))
# for (q1, q2) in perm:
# circ.add_gate(CNOT(), [q1, q2]) # SWAP as three CNOTs omitted for brevity

def _revert_circuit(d: ZXGraphState, circ: MBQCCircuit) -> None:
"""
Revert the circuit.
"""
# to be implemented
pass
22 changes: 22 additions & 0 deletions graphix_zx/graphstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,28 @@ def local_cliffords(self) -> dict[int, LocalClifford]:
"""
return self.__local_cliffords

@property
def inner2nodes(self) -> dict[int, int]:
"""Return inner index to node index mapping.

Returns
-------
dict[int, int]
inner index to node index mapping.
"""
return self.__inner2nodes
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it OK to expose internal states?


@property
def nodes2inner(self) -> dict[int, int]:
"""Return node index to inner index mapping.

Returns
-------
dict[int, int]
node index to inner index mapping.
"""
return self.__nodes2inner

def check_meas_basis(self) -> None:
"""Check if the measurement basis is set for all physical nodes except output nodes.

Expand Down
48 changes: 48 additions & 0 deletions graphix_zx/random_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np

from graphix_zx.circuit import MBQCCircuit
from graphix_zx.common import default_meas_basis
from graphix_zx.graphstate import GraphState

Expand Down Expand Up @@ -82,3 +83,50 @@ def get_random_flow_graph(
num_nodes += 1

return graph, flow


def get_random_gflow_circ(
width: int,
depth: int,
rng: np.random.Generator | None = None,
edge_p: float = 0.5,
angle_list: list | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use an appropriate collections.abc stuff rather than list .

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_list is unnecessary because it's obvious that it should be list from the annotation.

) -> MBQCCircuit:
"""Generate a random MBQC circuit which has gflow.

Parameters
----------
width : int
circuit width
depth : int
circuit depth
rng : np.random.Generator, optional
random number generator, by default np.random.default_rng()
edge_p : float, optional
probability of adding CZ gate, by default 0.5
angle_list : list, optional
list of angles, by default [0, np.pi / 3, 2 * np.pi / 3, np.pi]

Returns
-------
MBQCCircuit
generated MBQC circuit
"""
if rng is None:
rng = np.random.default_rng()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MEMO: it's better to cache rng as in graphix, because PRNG is designed to reduce correlations between samples drawn from the same generator instance.

if angle_list is None:
angle_list = [0, np.pi / 3, 2 * np.pi / 3, np.pi]
circ = MBQCCircuit(width)
for d in range(depth):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit complex.

for j in range(width):
circ.j(j, rng.choice(angle_list))
if d < depth - 1:
for j in range(width):
if rng.random() < edge_p:
circ.cz(j, (j + 1) % width)
num = rng.integers(0, width)
if num > 0:
target = set(rng.choice(list(range(width)), num))
circ.phase_gadget(target, rng.choice(angle_list))

return circ
Loading