diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index d4c56c7520..5edd1862eb 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -10,7 +10,7 @@ /source/compiler/qsc_partial_eval @idavis @swernli /source/compiler/qsc_rca @idavis @swernli /source/compiler/qsc_rir @idavis @swernli -/source/fuzz @billti @idavis @swernli +/source/fuzz @billti @idavis @swernli /katas @billti @swernli /source/jupyterlab @billti @idavis @minestarks /source/language_service @billti @idavis @minestarks @ScottCarda-MS @@ -19,7 +19,9 @@ /library @swernli @orpuente-MS /source/npm @billti @minestarks @ScottCarda-MS /source/pip @billti @idavis @minestarks +/source/pip/qsharp/qre @msoeken @brad-lackey @jwhogabo /source/playground @billti @minestarks +/source/qre @msoeken @brad-lackey @jwhogabo /source/resource_estimator @billti @swernli /samples @minestarks @swernli /source/vscode @billti @idavis @minestarks diff --git a/Cargo.lock b/Cargo.lock index ba1206f17e..2e0b4ba071 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2074,6 +2074,17 @@ dependencies = [ "wgpu", ] +[[package]] +name = "qre" +version = "0.0.0" +dependencies = [ + "num-traits", + "probability", + "rustc-hash", + "serde", + "thiserror 2.0.18", +] + [[package]] name = "qsc" version = "0.0.0" @@ -2486,6 +2497,7 @@ dependencies = [ "num-traits", "pyo3", "qdk_simulators", + "qre", "qsc", "rand 0.8.5", "rayon", diff --git a/Cargo.toml b/Cargo.toml index b7955f915b..60b4a7d01f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ members = [ "source/language_service", "source/simulators", "source/pip", + "source/qre", "source/resource_estimator", "source/samples_test", "source/wasm", diff --git a/source/pip/Cargo.toml b/source/pip/Cargo.toml index 70c4f0888e..7d576bbcc9 100644 --- a/source/pip/Cargo.toml +++ b/source/pip/Cargo.toml @@ -18,6 +18,7 @@ num-traits = { workspace = true } qsc = { path = "../compiler/qsc" } qdk_simulators = { path = "../simulators" } resource_estimator = { path = "../resource_estimator" } +qre = { 
path = "../qre" } miette = { workspace = true, features = ["fancy"] } rustc-hash = { workspace = true } serde = { workspace = true, features = ["derive"] } diff --git a/source/pip/benchmarks/bench_qre.py b/source/pip/benchmarks/bench_qre.py new file mode 100644 index 0000000000..536669a8aa --- /dev/null +++ b/source/pip/benchmarks/bench_qre.py @@ -0,0 +1,106 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import timeit +from dataclasses import dataclass, KW_ONLY, field +from qsharp.qre import linear_function, generic_function +from qsharp.qre._architecture import _make_instruction +from qsharp.qre.models import GateBased, SurfaceCode +from qsharp.qre._enumeration import _enumerate_instances + + +def bench_enumerate_instances(): + # Measure performance of enumerating instances with a large domain + @dataclass + class LargeDomain: + _: KW_ONLY + param1: int = field(default=0, metadata={"domain": range(1000)}) + param2: bool + + number = 100 + + duration = timeit.timeit( + "list(_enumerate_instances(LargeDomain))", + globals={ + "_enumerate_instances": _enumerate_instances, + "LargeDomain": LargeDomain, + }, + number=number, + ) + + print(f"Enumerating instances took {duration / number:.6f} seconds on average.") + + +def bench_enumerate_isas(): + import os + import sys + + # Add the tests directory to sys.path to import test_qre + # TODO: Remove this once the models in test_qre are moved to a proper module + sys.path.append(os.path.join(os.path.dirname(__file__), "../tests/qre/")) + from conftest import ExampleLogicalFactory, ExampleFactory # type: ignore + + ctx = GateBased(gate_time=50, measurement_time=100).context() + + # Hierarchical factory using from_components + query = SurfaceCode.q() * ExampleLogicalFactory.q( + source=SurfaceCode.q() * ExampleFactory.q() + ) + + number = 100 + duration = timeit.timeit( + "list(query.enumerate(ctx))", + globals={ + "query": query, + "ctx": ctx, + }, + number=number, + ) + + 
print(f"Enumerating ISAs took {duration / number:.6f} seconds on average.") + + +def bench_function_evaluation_linear(): + fl = linear_function(12) + + inst = _make_instruction(42, 0, None, 1, fl, None, 1.0, {}) + number = 1000 + duration = timeit.timeit( + "inst.space(5)", + globals={ + "inst": inst, + }, + number=number, + ) + + print( + f"Evaluating linear function took {duration / number:.6f} seconds on average." + ) + + +def bench_function_evaluation_generic(): + def func(arity: int) -> int: + return 12 * arity + + fg = generic_function(func) + + inst = _make_instruction(42, 0, None, 1, fg, None, 1.0, {}) + number = 1000 + duration = timeit.timeit( + "inst.space(5)", + globals={ + "inst": inst, + }, + number=number, + ) + + print( + f"Evaluating linear function took {duration / number:.6f} seconds on average." + ) + + +if __name__ == "__main__": + bench_enumerate_instances() + bench_enumerate_isas() + bench_function_evaluation_linear() + bench_function_evaluation_generic() diff --git a/source/pip/qsharp/magnets/geometry/__init__.py b/source/pip/qsharp/magnets/geometry/__init__.py new file mode 100644 index 0000000000..4a7a380f86 --- /dev/null +++ b/source/pip/qsharp/magnets/geometry/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Geometry module for representing quantum system topologies. + +This module provides hypergraph data structures for representing the +geometric structure of quantum systems, including lattice topologies +and interaction graphs. 
+""" + +from .complete import CompleteBipartiteGraph, CompleteGraph +from .lattice1d import Chain1D, Ring1D +from .lattice2d import Patch2D, Torus2D + +__all__ = [ + "CompleteBipartiteGraph", + "CompleteGraph", + "Chain1D", + "Ring1D", + "Patch2D", + "Torus2D", +] diff --git a/source/pip/qsharp/magnets/geometry/complete.py b/source/pip/qsharp/magnets/geometry/complete.py new file mode 100644 index 0000000000..057abb950b --- /dev/null +++ b/source/pip/qsharp/magnets/geometry/complete.py @@ -0,0 +1,187 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Complete graph geometries for quantum simulations. + +This module provides classes for representing complete graphs and complete +bipartite graphs as hypergraphs. These structures are useful for quantum +systems with all-to-all or bipartite all-to-all interactions. +""" + +from qsharp.magnets.utilities import ( + Hyperedge, + Hypergraph, + HypergraphEdgeColoring, +) + + +class CompleteGraph(Hypergraph): + """A complete graph where every vertex is connected to every other vertex. + + In a complete graph K_n, there are n vertices and n(n-1)/2 edges, + with each pair of distinct vertices connected by exactly one edge. + + Attributes: + n: Number of vertices in the graph. + + Example: + + .. code-block:: python + >>> graph = CompleteGraph(4) + >>> graph.nvertices + 4 + >>> graph.nedges + 6 + """ + + def __init__(self, n: int, self_loops: bool = False) -> None: + """Initialize a complete graph. + + Args: + n: Number of vertices in the graph. + self_loops: If True, include self-loop edges on each vertex + for single-site terms. 
+ """ + if self_loops: + _edges = [Hyperedge([i]) for i in range(n)] + else: + _edges = [] + + # Add all pairs of vertices + for i in range(n): + for j in range(i + 1, n): + _edges.append(Hyperedge([i, j])) + super().__init__(_edges) + + self.n = n + + def edge_coloring(self) -> HypergraphEdgeColoring: + """Compute edge coloring for this complete graph.""" + coloring = HypergraphEdgeColoring(self) + for edge in self.edges(): + if len(edge.vertices) == 1: + coloring.add_edge(edge, -1) + else: + if self.n % 2 == 0: + i, j = edge.vertices + m = self.n - 1 + if j == m: + coloring.add_edge(edge, i) + elif (j - i) % 2 == 0: + coloring.add_edge(edge, (j - i) // 2) + else: + coloring.add_edge(edge, (j - i + m) // 2) + else: + m = self.n + i, j = edge.vertices + if (j - i) % 2 == 0: + coloring.add_edge(edge, (j - i) // 2) + else: + coloring.add_edge(edge, (j - i + m) // 2) + return coloring + + +class CompleteBipartiteGraph(Hypergraph): + """A complete bipartite graph with two vertex sets. + + In a complete bipartite graph K_{m,n} (m <= n), there are m + n + vertices partitioned into two sets of sizes m and n. Every vertex + in the first set is connected to every vertex in the second set, + giving m * n edges total. + + Vertices 0 to m-1 form the first set, and vertices m to m+n-1 + form the second set. + + Attributes: + m: Number of vertices in the first set. + n: Number of vertices in the second set. + + Requires: + m <= n + + Example: + + .. code-block:: python + >>> graph = CompleteBipartiteGraph(2, 3) + >>> graph.nvertices + 5 + >>> graph.nedges + 6 + """ + + def __init__(self, m: int, n: int, self_loops: bool = False) -> None: + """Initialize a complete bipartite graph. + + Args: + m: Number of vertices in the first set (vertices 0 to m-1). + n: Number of vertices in the second set (vertices m to m+n-1). + self_loops: If True, include self-loop edges on each vertex + for single-site terms. + """ + assert m <= n, "Require m <= n for CompleteBipartiteGraph." 
+ total_vertices = m + n + + if self_loops: + _edges = [Hyperedge([i]) for i in range(total_vertices)] + + else: + _edges = [] + + # Connect every vertex in first set to every vertex in second set + for i in range(m): + for j in range(m, m + n): + _edges.append(Hyperedge([i, j])) + super().__init__(_edges) + + self.m = m + self.n = n + + def edge_coloring(self) -> HypergraphEdgeColoring: + """Compute edge coloring for this complete bipartite graph.""" + coloring = HypergraphEdgeColoring(self) + m = self.m + n = self.n + for edge in self.edges(): + if len(edge.vertices) == 1: + coloring.add_edge(edge, -1) + else: + i, j = edge.vertices + coloring.add_edge(edge, (i + j - m) % n) + return coloring + + # Color edges based on the second vertex index to create n parallel partitions + # for i in range(m): + # for j in range(m, m + n): + # self.color[(i, j)] = ( + # i + j - m + # ) % n # Color edges based on second vertex index + + # Edge coloring for parallel updates + # The even case: n-1 colors are needed + # if n % 2 == 0: + # m = n - 1 + # for i in range(m): + # self.color[(i, n - 1)] = ( + # i # Connect vertex n-1 to all others with unique colors + # ) + # for j in range(1, (m - 1) // 2 + 1): + # a = (i + j) % m + # b = (i - j) % m + # if a < b: + # self.color[(a, b)] = i + # else: + # self.color[(b, a)] = i + + # The odd case: n colors are needed + # This is the round-robin tournament scheduling algorithm for odd n + # Set m = n for ease of reading + # else: + # m = n + # for i in range(m): + # for j in range(1, (m - 1) // 2 + 1): + # a = (i + j) % m + # b = (i - j) % m + # if a < b: + # self.color[(a, b)] = i + # else: + # self.color[(b, a)] = i diff --git a/source/pip/qsharp/magnets/geometry/lattice1d.py b/source/pip/qsharp/magnets/geometry/lattice1d.py new file mode 100644 index 0000000000..9586167276 --- /dev/null +++ b/source/pip/qsharp/magnets/geometry/lattice1d.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""One-dimensional lattice geometries for quantum simulations. + +This module provides classes for representing 1D lattice structures as +hypergraphs. These lattices are commonly used in quantum spin chain +simulations and other one-dimensional quantum systems. +""" + +from qsharp.magnets.utilities import ( + Hyperedge, + Hypergraph, + HypergraphEdgeColoring, +) + + +class Chain1D(Hypergraph): + """A one-dimensional open chain lattice. + + Represents a linear chain of vertices with nearest-neighbor edges. + The chain has open boundary conditions, meaning the first and last + vertices are not connected. + + Attributes: + length: Number of vertices in the chain. + + Example: + + .. code-block:: python + >>> chain = Chain1D(4) + >>> chain.nvertices + 4 + >>> chain.nedges + 3 + """ + + def __init__(self, length: int, self_loops: bool = False) -> None: + """Initialize a 1D chain lattice. + + Args: + length: Number of vertices in the chain. + self_loops: If True, include self-loop edges on each vertex + for single-site terms. + """ + if self_loops: + _edges = [Hyperedge([i]) for i in range(length)] + + else: + _edges = [] + + for i in range(length - 1): + _edges.append(Hyperedge([i, i + 1])) + + super().__init__(_edges) + self.length = length + + def edge_coloring(self) -> HypergraphEdgeColoring: + """Compute a valid edge coloring for this chain.""" + coloring = HypergraphEdgeColoring(self) + for edge in self.edges(): + if len(edge.vertices) == 1: + coloring.add_edge(edge, -1) + else: + i, j = edge.vertices + color = min(i, j) % 2 + coloring.add_edge(edge, color) + return coloring + + +class Ring1D(Hypergraph): + """A one-dimensional ring (periodic chain) lattice. + + Represents a circular chain of vertices with nearest-neighbor edges. + The ring has periodic boundary conditions, meaning the first and last + vertices are connected. + + Attributes: + length: Number of vertices in the ring. + + Example: + + .. 
code-block:: python + >>> ring = Ring1D(4) + >>> ring.nvertices + 4 + >>> ring.nedges + 4 + """ + + def __init__(self, length: int, self_loops: bool = False) -> None: + """Initialize a 1D ring lattice. + + Args: + length: Number of vertices in the ring. + self_loops: If True, include self-loop edges on each vertex + for single-site terms. + """ + if self_loops: + _edges = [Hyperedge([i]) for i in range(length)] + else: + _edges = [] + + for i in range(length): + _edges.append(Hyperedge([i, (i + 1) % length])) + super().__init__(_edges) + + self.length = length + + def edge_coloring(self) -> HypergraphEdgeColoring: + """Compute a valid edge coloring for this ring.""" + coloring = HypergraphEdgeColoring(self) + for edge in self.edges(): + if len(edge.vertices) == 1: + coloring.add_edge(edge, -1) + else: + i, j = edge.vertices + if {i, j} == {0, self.length - 1}: + color = (self.length % 2) + 1 + else: + color = min(i, j) % 2 + coloring.add_edge(edge, color) + return coloring diff --git a/source/pip/qsharp/magnets/geometry/lattice2d.py b/source/pip/qsharp/magnets/geometry/lattice2d.py new file mode 100644 index 0000000000..a69a8c7644 --- /dev/null +++ b/source/pip/qsharp/magnets/geometry/lattice2d.py @@ -0,0 +1,187 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Two-dimensional lattice geometries for quantum simulations. + +This module provides classes for representing 2D lattice structures as +hypergraphs. These lattices are commonly used in quantum spin system +simulations and other two-dimensional quantum systems. +""" + +from qsharp.magnets.utilities import ( + Hyperedge, + Hypergraph, + HypergraphEdgeColoring, +) + + +class Patch2D(Hypergraph): + """A two-dimensional open rectangular lattice. + + Represents a rectangular grid of vertices with nearest-neighbor edges. + The patch has open boundary conditions, meaning edges do not wrap around. + + Vertices are indexed in row-major order: vertex (x, y) has index y * width + x. 
+ + Attributes: + width: Number of vertices in the horizontal direction. + height: Number of vertices in the vertical direction. + + Example: + + .. code-block:: python + >>> patch = Patch2D(3, 2) + >>> str(patch) + '3x2 lattice patch with 6 vertices and 7 edges' + """ + + def __init__(self, width: int, height: int, self_loops: bool = False) -> None: + """Initialize a 2D patch lattice. + + Args: + width: Number of vertices in the horizontal direction. + height: Number of vertices in the vertical direction. + self_loops: If True, include self-loop edges on each vertex + for single-site terms. + """ + self.width = width + self.height = height + + if self_loops: + _edges = [Hyperedge([i]) for i in range(width * height)] + else: + _edges = [] + + # Horizontal edges (connecting (x, y) to (x+1, y)) + for y in range(height): + for x in range(width - 1): + _edges.append(Hyperedge([self._index(x, y), self._index(x + 1, y)])) + + # Vertical edges (connecting (x, y) to (x, y+1)) + for y in range(height - 1): + for x in range(width): + _edges.append(Hyperedge([self._index(x, y), self._index(x, y + 1)])) + super().__init__(_edges) + + def _index(self, x: int, y: int) -> int: + """Convert (x, y) coordinates to vertex index.""" + return y * self.width + x + + def __str__(self) -> str: + """Return the summary string ``"{width}x{height} lattice patch with {nvertices} vertices and {nedges} edges"``.""" + return f"{self.width}x{self.height} lattice patch with {self.nvertices} vertices and {self.nedges} edges" + + def __repr__(self) -> str: + """Return a string representation of the Patch2D geometry.""" + return f"Patch2D(width={self.width}, height={self.height})" + + def edge_coloring(self) -> HypergraphEdgeColoring: + """Compute edge coloring for this 2D patch.""" + coloring = HypergraphEdgeColoring(self) + for edge in self.edges(): + if len(edge.vertices) == 1: + coloring.add_edge(edge, -1) + continue + + u, v = edge.vertices + x_u, y_u = u % self.width, u // self.width + x_v, y_v 
= v % self.width, v // self.width + + if y_u == y_v: + color = 0 if min(x_u, x_v) % 2 == 0 else 1 + else: + color = 2 if min(y_u, y_v) % 2 == 0 else 3 + coloring.add_edge(edge, color) + return coloring + + +class Torus2D(Hypergraph): + """A two-dimensional toroidal (periodic) lattice. + + Represents a rectangular grid of vertices with nearest-neighbor edges + and periodic boundary conditions in both directions. The topology is + that of a torus. + + Vertices are indexed in row-major order: vertex (x, y) has index y * width + x. + + Attributes: + width: Number of vertices in the horizontal direction. + height: Number of vertices in the vertical direction. + + Example: + + .. code-block:: python + >>> torus = Torus2D(3, 2) + >>> str(torus) + '3x2 lattice torus with 6 vertices and 12 edges' + """ + + def __init__(self, width: int, height: int, self_loops: bool = False) -> None: + """Initialize a 2D torus lattice. + + Args: + width: Number of vertices in the horizontal direction. + height: Number of vertices in the vertical direction. + self_loops: If True, include self-loop edges on each vertex + for single-site terms. 
+ """ + self.width = width + self.height = height + + if self_loops: + _edges = [Hyperedge([i]) for i in range(width * height)] + else: + _edges = [] + + # Horizontal edges (connecting (x, y) to ((x+1) % width, y)) + for y in range(height): + for x in range(width): + _edges.append( + Hyperedge([self._index(x, y), self._index((x + 1) % width, y)]) + ) + + # Vertical edges (connecting (x, y) to (x, (y+1) % height)) + for y in range(height): + for x in range(width): + _edges.append( + Hyperedge([self._index(x, y), self._index(x, (y + 1) % height)]) + ) + + super().__init__(_edges) + + def _index(self, x: int, y: int) -> int: + """Convert (x, y) coordinates to vertex index.""" + return y * self.width + x + + def __str__(self) -> str: + """Return the summary string ``"{width}x{height} lattice torus with {nvertices} vertices and {nedges} edges"``.""" + return f"{self.width}x{self.height} lattice torus with {self.nvertices} vertices and {self.nedges} edges" + + def __repr__(self) -> str: + """Return a string representation of the Torus2D geometry.""" + return f"Torus2D(width={self.width}, height={self.height})" + + def edge_coloring(self) -> HypergraphEdgeColoring: + """Compute edge coloring for this 2D torus.""" + coloring = HypergraphEdgeColoring(self) + for edge in self.edges(): + if len(edge.vertices) == 1: + coloring.add_edge(edge, -1) + continue + + u, v = edge.vertices + x_u, y_u = u % self.width, u // self.width + x_v, y_v = v % self.width, v // self.width + + if y_u == y_v: + if {x_u, x_v} == {0, self.width - 1}: + color = 1 if self.width % 2 == 0 else 4 + else: + color = 0 if min(x_u, x_v) % 2 == 0 else 1 + else: + if {y_u, y_v} == {0, self.height - 1}: + color = 3 if self.height % 2 == 0 else 5 + else: + color = 2 if min(y_u, y_v) % 2 == 0 else 3 + coloring.add_edge(edge, color) + return coloring diff --git a/source/pip/qsharp/magnets/models/__init__.py b/source/pip/qsharp/magnets/models/__init__.py new file mode 100644 index 0000000000..224270e17e --- 
/dev/null +++ b/source/pip/qsharp/magnets/models/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Models module for quantum spin models. + +This module provides classes for representing quantum spin models +as Hamiltonians built from Pauli operators. +""" + +from .model import IsingModel, Model + +__all__ = ["Model", "IsingModel"] diff --git a/source/pip/qsharp/magnets/models/model.py b/source/pip/qsharp/magnets/models/model.py new file mode 100755 index 0000000000..d0cf8b1887 --- /dev/null +++ b/source/pip/qsharp/magnets/models/model.py @@ -0,0 +1,231 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pyright: reportPrivateImportUsage=false + +from collections.abc import Sequence +from typing import Iterator, Optional + + +"""Base Model class for quantum spin models. + +This module provides the base class for representing quantum spin models +as Hamiltonians. The Model class integrates with hypergraph geometries +to define interaction topologies and stores coefficients for each edge. +""" + +from qsharp.magnets.utilities import ( + Hyperedge, + Hypergraph, + HypergraphEdgeColoring, + PauliString, +) + + +class Model: + """Base class for quantum spin models. + + This class represents a quantum spin Hamiltonian defined on a hypergraph + geometry. The Hamiltonian is characterized by: + + - Ops: A list of PauliStrings (one entry per interaction term) + - Terms: Groupings of operator indices for Trotterization or parallel execution + + The model is built on a hypergraph geometry that defines which qubits + interact with each other. + + Attributes: + geometry: The Hypergraph defining the interaction topology. + + Example: + + .. 
code-block:: python + >>> from qsharp.magnets.geometry import Chain1D + >>> geometry = Chain1D(4) + >>> model = Model(geometry) + >>> model.set_coefficient((0, 1), 1.5) + >>> model.set_pauli_string((0, 1), PauliString.from_qubits((0, 1), "ZZ")) + >>> model.get_coefficient((0, 1)) + 1.5 + """ + + def __init__(self, geometry: Hypergraph): + """Initialize the Model. + + Creates a quantum spin model on the given geometry. + + The model stores operators lazily in ``_ops`` as interaction operators + are defined. Noncommuting collections of operators are collected in + ``_terms`` that stores the indices of its interaction operators. This + list of arrays separates terms into parallelizable groups by color. It is + initialized as one empty term group. + + Args: + geometry: Hypergraph defining the interaction topology. The number + of vertices determines the number of qubits in the model. + """ + self.geometry: Hypergraph = geometry + self._qubits: set[int] = set() + self._ops: list[PauliString] = [] + for edge in geometry.edges(): + self._qubits.update(edge.vertices) + self._terms: dict[int, dict[int, list[int]]] = {} + + def add_interaction( + self, + edge: Hyperedge, + pauli_string: Sequence[int | str] | str, + coefficient: complex = 1.0, + term: Optional[int] = None, + color: int = 0, + ) -> None: + """Add an interaction term to the model. + + Args: + edge: The Hyperedge representing the qubits involved in the interaction. + pauli_string: The PauliString operator for this interaction. + coefficient: The complex coefficient multiplying this term (default 1.0). 
+ """ + if edge not in self.geometry.edges(): + raise ValueError("Edge is not part of the model geometry.") + s = PauliString.from_qubits(edge.vertices, pauli_string, coefficient) + self._ops.append(s) + if term is not None: + if term not in self._terms: + self._terms[term] = {} + if color not in self._terms[term]: + self._terms[term][color] = [] + self._terms[term][color].append(len(self._ops) - 1) + + @property + def nqubits(self) -> int: + """Return the number of qubits in the model.""" + return len(self._qubits) + + @property + def nterms(self) -> int: + """Return the number of term groups in the model.""" + return len(self._terms) + + @property + def terms(self) -> list[int]: + """Get the list of term indices in the model.""" + return list(self._terms.keys()) + + def ncolors(self, term: int) -> int: + """Return the number of colors in a given term.""" + if term not in self._terms: + raise ValueError(f"Term {term} does not exist in the model.") + return len(self._terms[term]) + + def colors(self, term: int) -> list[int]: + """Return the list of colors in a given term.""" + if term not in self._terms: + raise ValueError(f"Term {term} does not exist in the model.") + return list(self._terms[term].keys()) + + def nops(self, term: int, color: int) -> int: + """Return the number of operators in a given term and color.""" + if term not in self._terms: + raise ValueError(f"Term {term} does not exist in the model.") + if color not in self._terms[term]: + raise ValueError(f"Color {color} does not exist in term {term}.") + return len(self._terms[term][color]) + + def ops(self, term: int, color: int) -> list[PauliString]: + """Return the list of operators in a given term and color.""" + if term not in self._terms: + raise ValueError(f"Term {term} does not exist in the model.") + if color not in self._terms[term]: + raise ValueError(f"Color {color} does not exist in term {term}.") + return [self._ops[i] for i in self._terms[term][color]] + + def __str__(self) -> str: + 
"""String representation of the model.""" + return "Generic model with {} terms on {} qubits.".format( + len(self._terms), len(self._qubits) + ) + + def __repr__(self) -> str: + """String representation of the model.""" + return self.__str__() + + +class IsingModel(Model): + """Translation-invariant Ising model on a hypergraph geometry. + + The Hamiltonian is: + H = -J * Σ_{} Z_i Z_j - h * Σ_i X_i + + - Single-vertex edges define X-field terms with coefficient ``-h``. + - Two-vertex edges define ZZ-coupling terms with coefficient ``-J``. + - Terms are grouped into two groups: ``0`` for field terms and ``1`` for + coupling terms. + """ + + def __init__(self, geometry: Hypergraph, h: float, J: float): + super().__init__(geometry) + self.h = h + self.J = J + self._terms = {0: {}, 1: {}} + + coloring: HypergraphEdgeColoring = geometry.edge_coloring() + for edge in geometry.edges(): + vertices = edge.vertices + if len(vertices) == 1: + self.add_interaction(edge, "X", -h, term=0, color=0) + elif len(vertices) == 2: + color = coloring.color(edge.vertices) + if color is None: + raise ValueError("Geometry edge coloring failed to assign a color.") + self.add_interaction(edge, "ZZ", -J, term=1, color=color) + + def __str__(self) -> str: + return ( + f"Ising model with {self.nterms} terms on {self.nqubits} qubits " + f"(h={self.h}, J={self.J})." + ) + + def __repr__(self) -> str: + return ( + f"IsingModel(nqubits={self.nqubits}, nterms={self.nterms}, " + f"h={self.h}, J={self.J})" + ) + + +class HeisenbergModel(Model): + """Translation-invariant Heisenberg model on a hypergraph geometry. + + The Hamiltonian is: + H = -J * Σ_{} (X_i X_j + Y_i Y_j + Z_i Z_j) + + - Two-vertex edges define XX, YY, and ZZ coupling terms with coefficient ``-J``. + - Terms are grouped into three parts: ``0`` for XX, ``1`` for YY, and ``2`` for ZZ. 
+ """ + + def __init__(self, geometry: Hypergraph, J: float): + super().__init__(geometry) + self.J = J + self.coloring: HypergraphEdgeColoring = geometry.edge_coloring() + self._terms = {0: {}, 1: {}, 2: {}} + for edge in geometry.edges(): + vertices = edge.vertices + if len(vertices) == 2: + color = self.coloring.color(edge.vertices) + if color is None: + raise ValueError("Geometry edge coloring failed to assign a color.") + self.add_interaction(edge, "XX", -J, term=0, color=color) + self.add_interaction(edge, "YY", -J, term=1, color=color) + self.add_interaction(edge, "ZZ", -J, term=2, color=color) + + def __str__(self) -> str: + return ( + f"Heisenberg model with {self.nterms} terms on {self.nqubits} qubits " + f"(J={self.J})." + ) + + def __repr__(self) -> str: + return ( + f"HeisenbergModel(nqubits={self.nqubits}, nterms={self.nterms}, " + f"J={self.J})" + ) diff --git a/source/pip/qsharp/magnets/trotter/__init__.py b/source/pip/qsharp/magnets/trotter/__init__.py new file mode 100644 index 0000000000..d4beaa68c5 --- /dev/null +++ b/source/pip/qsharp/magnets/trotter/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Trotter-Suzuki methods for time evolution.""" + +from .trotter import ( + TrotterStep, + TrotterExpansion, + strang_splitting, + suzuki_recursion, + yoshida_recursion, + fourth_order_trotter_suzuki, +) + +__all__ = [ + "TrotterStep", + "TrotterExpansion", + "strang_splitting", + "suzuki_recursion", + "yoshida_recursion", + "fourth_order_trotter_suzuki", +] diff --git a/source/pip/qsharp/magnets/trotter/trotter.py b/source/pip/qsharp/magnets/trotter/trotter.py new file mode 100644 index 0000000000..383aca5ae2 --- /dev/null +++ b/source/pip/qsharp/magnets/trotter/trotter.py @@ -0,0 +1,372 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Trotter schedule utilities for magnet models. 
+ +This module provides: + +- ``TrotterStep``: a schedule of ``(time, term_index)`` entries, +- recursion helpers (Suzuki and Yoshida) that raise the order by 2, +- factory helpers such as Strang splitting, and +- ``TrotterExpansion`` to apply a step repeatedly to a concrete model. +""" + +from collections.abc import Callable +from typing import Iterator, Optional +from qsharp.magnets.models import Model +from qsharp.magnets.utilities import PauliString + +import math + +try: + import cirq +except Exception as ex: + raise ImportError( + "qsharp.magnets.models requires the cirq extras. Install with 'pip install \"qsharp[cirq]\"'." + ) from ex + + +class TrotterStep: + """Schedule of Hamiltonian-term applications for one Trotter step. + + A ``TrotterStep`` stores an ordered list of ``(time, term_index)`` tuples. + Each tuple indicates that term group ``term_index`` should be applied for + evolution time ``time``. + + The constructor builds a first-order step over the provided term indices: + + .. math:: + + e^{-i H t} \\approx \\prod_k e^{-i H_k t}, \\quad H = \\sum_k H_k. + + where each supplied term index appears once with duration ``time_step``. + """ + + def __init__(self, terms: list[int] = [], time_step: float = 0.0): + """Initialize a Trotter step from explicit term indices. + + Args: + terms: Ordered term indices to include in this step. + time_step: Duration associated with each listed term. + + Notes: + If ``terms`` is empty, the step is initialized as order 0. + Otherwise, it is initialized as order 1. 
+ """ + self._nterms = len(terms) + self._time_step = time_step + self._order = 1 if self._nterms > 0 else 0 + self._repr_string: Optional[str] = None + self.terms: list[tuple[float, int]] = [(time_step, j) for j in terms] + + @property + def order(self) -> int: + """Get the order of the Trotter decomposition.""" + return self._order + + @property + def nterms(self) -> int: + """Get the number of term entries used to build this schedule.""" + return self._nterms + + @property + def time_step(self) -> float: + """Get the base time step metadata stored on this step.""" + return self._time_step + + def reduce(self) -> None: + """ + Reduce the Trotter step in place by combining consecutive terms that are the same. + + This can be useful for optimizing the Trotter sequence by merging adjacent + applications of the same term into a single application with a longer time step. + + Example: + >>> trotter = TrotterStep() + >>> trotter.terms = [(0.5, 0), (0.5, 0), (0.5, 1)] + >>> trotter.reduce() + >>> list(trotter.step()) + [(1.0, 0), (0.5, 1)] + """ + if len(self.terms) > 1: + reduced_terms: list[tuple[float, int]] = [] + current_time, current_term = self.terms[0] + + for time, term in self.terms[1:]: + if term == current_term: + current_time += time + else: + reduced_terms.append((current_time, current_term)) + current_time, current_term = time, term + + reduced_terms.append((current_time, current_term)) + self.terms = reduced_terms + + def step(self) -> Iterator[tuple[float, int]]: + """Iterate over ``(time, term_index)`` entries for this step.""" + return iter(self.terms) + + def cirq(self, model: Model) -> cirq.Circuit: + """Build a Cirq circuit for one application of this Trotter step. + + Args: + model: Model that maps each term index to grouped Pauli operators. + + Returns: + A ``cirq.Circuit`` containing ``cirq.PauliStringPhasor`` operations + in the same order as ``self.step()``. 
+ """ + _INT_TO_CIRQ = (cirq.I, cirq.X, cirq.Z, cirq.Y) + circuit = cirq.Circuit() + for time, term_index in self.step(): + for color in model.colors(term_index): + for op in model.ops(term_index, color): + pauli = cirq.PauliString( + { + cirq.LineQubit(p.qubit): _INT_TO_CIRQ[p.op] + for p in op._paulis + }, + ) + oper = cirq.PauliStringPhasor(pauli, exponent_neg=time / math.pi) + circuit.append(oper) + return circuit + + def __str__(self) -> str: + """String representation of the Trotter decomposition.""" + return f"Trotter expansion of order {self._order}: time_step={self._time_step}, num_terms={self._nterms}" + + def __repr__(self) -> str: + """String representation of the Trotter decomposition.""" + if self._repr_string is not None: + return self._repr_string + else: + return f"TrotterStep(num_terms={self._nterms}, time_step={self._time_step})" + + +def suzuki_recursion(trotter: TrotterStep) -> TrotterStep: + """ + Apply one level of Suzuki recursion to double the order of a Trotter step. + + Given a k-th order Trotter step S_k(t), this function constructs a (k+2)-nd order + step using the Suzuki fractal decomposition: + + S_{k+2}(t) = S_{k}(p t) S_{k}(p t) S_{k}((1 - 4p) t) S_{k}(p t) S_{k}(p t) + + where p = 1 / (4 - 4^{1/(2k+1)}). + + The resulting step has improved accuracy: the error scales as O(t^{k+3}) instead + of O(t^{k+1}), at the cost of 5x more exponential applications per step. + + Args: + trotter: A TrotterStep of order k to be promoted to order k+2. + + Returns: + A new TrotterStep of order k+2 constructed via Suzuki recursion. + + References: + M. Suzuki, Phys. Lett. A 146, 319 (1990). 
+ """ + + suzuki = TrotterStep() + suzuki._nterms = trotter._nterms + suzuki._time_step = trotter._time_step + suzuki._order = trotter._order + 2 + suzuki._repr_string = f"SuzukiRecursion(order={suzuki._order}, time_step={suzuki._time_step}, num_terms={suzuki._nterms})" + + p = 1 / (4 - 4 ** (1 / (2 * trotter.order + 1))) + + suzuki.terms = [(p * time, term_index) for time, term_index in trotter.step()] + suzuki.terms += [(p * time, term_index) for time, term_index in trotter.step()] + suzuki.terms += [ + ((1 - 4 * p) * time, term_index) for time, term_index in trotter.step() + ] + suzuki.terms += [(p * time, term_index) for time, term_index in trotter.step()] + suzuki.terms += [(p * time, term_index) for time, term_index in trotter.step()] + suzuki.reduce() # Combine consecutive terms that are the same + + return suzuki + + +def yoshida_recursion(trotter: TrotterStep) -> TrotterStep: + """ + Apply one level of Yoshida recursion to increase the order of a Trotter step by 2. + + Given a k-th order Trotter step S_k(t), this function constructs a (k+2)-nd order + step using Yoshida's symmetric triple-jump composition: + + S_{k+2}(t) = S_{k}(w_1 t) S_{k}(w_0 t) S_{k}(w_1 t) + + where: + w_1 = 1 / (2 - 2^{1/(2k+1)}) + w_0 = -2^{1/(2k+1)} / (2 - 2^{1/(2k+1)}) = 1 - 2 w_1 + + The resulting step has improved accuracy: the error scales as O(t^{k+3}) instead + of O(t^{k+1}), at the cost of 3x more exponential applications per step. + + Args: + trotter: A TrotterStep of order k to be promoted to order k+2. + + Returns: + A new TrotterStep of order k+2 constructed via Yoshida recursion. + + References: + H. Yoshida, Phys. Lett. A 150, 262 (1990). 
+ """ + + yoshida = TrotterStep() + yoshida._nterms = trotter._nterms + yoshida._time_step = trotter._time_step + yoshida._order = trotter._order + 2 + yoshida._repr_string = f"YoshidaRecursion(order={yoshida._order}, time_step={yoshida._time_step}, num_terms={yoshida._nterms})" + + cube_root_2 = 2 ** (1 / (2 * trotter.order + 1)) + w1 = 1 / (2 - cube_root_2) + w0 = 1 - 2 * w1 # equivalent to -cube_root_2 / (2 - cube_root_2) + + yoshida.terms = [(w1 * time, term_index) for time, term_index in trotter.step()] + yoshida.terms += [(w0 * time, term_index) for time, term_index in trotter.step()] + yoshida.terms += [(w1 * time, term_index) for time, term_index in trotter.step()] + yoshida.reduce() # Combine consecutive terms that are the same + + return yoshida + + +def strang_splitting(terms: list[int], time: float) -> TrotterStep: + """ + Create a second-order Strang splitting schedule for explicit term indices. + + The second-order Trotter formula uses symmetric splitting: + + e^{-i H t} \\approx \\prod_{k=1}^{n-1} e^{-i H_k t/2} \\, e^{-i H_n t} \\, \\prod_{k=n-1}^{1} e^{-i H_k t/2} + + This provides second-order accuracy in the time step, compared to + first-order for the basic Trotter decomposition. + + Example: + + .. code-block:: python + >>> strang = strang_splitting(terms=[0, 1, 2], time=0.5) + >>> list(strang.step()) + [(0.25, 0), (0.25, 1), (0.5, 2), (0.25, 1), (0.25, 0)] + + Args: + terms: Ordered term indices for a single symmetric step. Must be non-empty. + time: Total evolution time assigned to this second-order step. + + Returns: + A second-order ``TrotterStep``. + + References: + G. Strang, SIAM J. Numer. Anal. 5, 506 (1968). 
+ """ + strang = TrotterStep() + strang._nterms = len(terms) + strang._time_step = time + strang._order = 2 + strang._repr_string = f"StrangSplitting(time_step={time}, num_terms={len(terms)})" + strang.terms = [] + for i in range(len(terms) - 1): + strang.terms.append((time / 2, terms[i])) + strang.terms.append((time, terms[-1])) + for i in reversed(range(len(terms) - 1)): + strang.terms.append((time / 2, terms[i])) + return strang + + +def fourth_order_trotter_suzuki(terms: list[int], time: float) -> TrotterStep: + """ + Factory function for creating a fourth-order Trotter-Suzuki decomposition + using Suzuki recursion. + + This is obtained by applying one level of Suzuki recursion to the second-order + Strang splitting. The resulting fourth-order decomposition has improved accuracy + compared to the second-order Strang splitting, at the cost of more exponential + applications per step. + + Example: + + .. code-block:: python + >>> fourth_order = fourth_order_trotter_suzuki(terms=[0, 1, 2], time=0.5) + >>> list(fourth_order.step()) + [(0.1767766952966369, 0), (0.1767766952966369, 1), (0.1767766952966369, 2), (0.3535533905932738, 1), (0.3535533905932738, 0), (0.1767766952966369, 1), (0.1767766952966369, 2), (0.1767766952966369, 1), (0.1767766952966369, 0)] + """ + return suzuki_recursion(strang_splitting(terms, time)) + + +class TrotterExpansion: + """Repeated application of a Trotter method on a concrete model. + + ``TrotterExpansion`` builds one step with ``trotter_method(model.terms, dt)`` + where ``dt = time / num_steps`` and then repeats it ``num_steps`` times. + + Iteration via :meth:`step` yields ``PauliString`` operators already scaled by + the per-entry schedule time. + """ + + def __init__( + self, + trotter_method: Callable[[list[int], float], TrotterStep], + model: Model, + time: float, + num_steps: int, + ): + """Initialize a repeated-step Trotter expansion. + + Args: + trotter_method: Callable mapping ``(terms, dt)`` to a ``TrotterStep``. 
+ model: Model that defines term groups and per-term operators. + time: Total evolution time. + num_steps: Number of repeated Trotter steps. + """ + self._model = model + self._num_steps = num_steps + self._trotter_step = trotter_method(model.terms, time / num_steps) + + @property + def order(self) -> int: + """Get the order of the underlying Trotter step.""" + return self._trotter_step.order + + @property + def nterms(self) -> int: + """Get the number of Hamiltonian terms.""" + return self._model.nterms + + @property + def nsteps(self) -> int: + """Get the number of Trotter steps.""" + return self._num_steps + + @property + def total_time(self) -> float: + """Get the total evolution time (time_step * num_steps).""" + return self._trotter_step.time_step * self._num_steps + + def step(self) -> Iterator[PauliString]: + """Iterate over scaled operators for the full expansion. + + Yields: + ``PauliString`` operators with coefficients scaled by schedule time, + in execution order across all repeated steps. 
+ """ + for _ in range(self._num_steps): + for s, i in self._trotter_step.step(): + for c in self._model.colors(i): + for op in self._model.ops(i, c): + yield (op * s) + + def cirq(self) -> cirq.CircuitOperation: + """Get a repeated Cirq circuit operation for this expansion.""" + circuit = self._trotter_step.cirq(self._model).freeze() + return cirq.CircuitOperation(circuit, repetitions=self._num_steps) + + def __str__(self) -> str: + """String representation of the Trotter expansion.""" + return ( + f"TrotterExpansion(order={self.order}, num_steps={self._num_steps}, " + f"total_time={self.total_time}, num_terms={self.nterms})" + ) + + def __repr__(self) -> str: + """Repr representation of the Trotter expansion.""" + return f"TrotterExpansion({self._trotter_step!r}, num_steps={self._num_steps})" diff --git a/source/pip/qsharp/magnets/utilities/__init__.py b/source/pip/qsharp/magnets/utilities/__init__.py new file mode 100644 index 0000000000..b350f7da40 --- /dev/null +++ b/source/pip/qsharp/magnets/utilities/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Utilities module for magnets package. + +This module provides utility data structures and algorithms used across +the magnets package, including hypergraph representations. +""" + +from .hypergraph import ( + Hyperedge, + Hypergraph, + HypergraphEdgeColoring, +) +from .pauli import Pauli, PauliString, PauliX, PauliY, PauliZ + +__all__ = [ + "Hyperedge", + "Hypergraph", + "HypergraphEdgeColoring", + "Pauli", + "PauliString", + "PauliX", + "PauliY", + "PauliZ", +] diff --git a/source/pip/qsharp/magnets/utilities/hypergraph.py b/source/pip/qsharp/magnets/utilities/hypergraph.py new file mode 100644 index 0000000000..b7caffbd99 --- /dev/null +++ b/source/pip/qsharp/magnets/utilities/hypergraph.py @@ -0,0 +1,317 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Hypergraph data structures for representing quantum system geometries. + +This module provides classes for representing hypergraphs, which generalize +graphs by allowing edges (hyperedges) to connect any number of vertices. +Hypergraphs are useful for representing interaction terms in quantum +Hamiltonians, where multi-body interactions can involve more than two sites. +""" + +import random +from typing import Iterator, Optional + + +class Hyperedge: + """A hyperedge connecting one or more vertices in a hypergraph. + + A hyperedge generalizes the concept of an edge in a graph. While a + traditional edge connects exactly two vertices, a hyperedge can connect + any number of vertices. This is useful for representing: + - Single-site terms (self-loops): 1 vertex + - Two-body interactions: 2 vertices + - Multi-body interactions: 3+ vertices + Each hyperedge is defined by a set of unique vertex indices, which are + stored as a sorted tuple for consistency and hashability. + + Attributes: + vertices: Sorted tuple of vertex indices connected by this hyperedge. + + Example: + + .. code-block:: python + >>> edge = Hyperedge([2, 0, 1]) + >>> edge.vertices + (0, 1, 2) + """ + + def __init__(self, vertices: list[int]) -> None: + """Initialize a hyperedge with the given vertices. + + Args: + vertices: List of vertex indices. Will be sorted internally. + """ + self.vertices: tuple[int, ...] = tuple(sorted(set(vertices))) + + def __str__(self) -> str: + return str(self.vertices) + + def __repr__(self) -> str: + return f"Hyperedge({list(self.vertices)})" + + +class Hypergraph: + """A hypergraph consisting of vertices connected by hyperedges. + + A hypergraph is a generalization of a graph where edges (hyperedges) can + connect any number of vertices. This class serves as the base class for + various lattice geometries used in quantum simulations. + + Attributes: + _edge_set: Set of hyperedges in the hypergraph. 
+ _vertex_set: Set of all unique vertex indices in the hypergraph. + + Note: + Edge colors are managed separately by :class:`HypergraphEdgeColoring`. + Use :meth:`edge_coloring` to generate a coloring for this hypergraph. + + Example: + + .. code-block:: python + >>> edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([0, 2])] + >>> graph = Hypergraph(edges) + >>> graph.nvertices + 3 + >>> graph.nedges + 3 + """ + + def __init__(self, edges: list[Hyperedge]) -> None: + """Initialize a hypergraph with the given edges. + + Args: + edges: List of hyperedges defining the hypergraph structure. + """ + self._vertex_set = set() + self._edge_set = set(edges) + for edge in edges: + self._vertex_set.update(edge.vertices) + + @property + def nvertices(self) -> int: + """Return the number of vertices in the hypergraph.""" + return len(self._vertex_set) + + def vertices(self) -> Iterator[int]: + """Iterate over all vertex indices in the hypergraph. + + Returns: + Iterator of vertex indices in ascending order. + """ + return iter(sorted(self._vertex_set)) + + @property + def nedges(self) -> int: + """Return the number of hyperedges in the hypergraph.""" + return len(self._edge_set) + + def edges(self) -> Iterator[Hyperedge]: + """Iterate over all hyperedges in the hypergraph. + + Returns: + Iterator of all hyperedges in the hypergraph. + """ + return iter(self._edge_set) + + def add_edge(self, edge: Hyperedge) -> None: + """Add a hyperedge to the hypergraph. + + Args: + edge: The Hyperedge instance to add. + """ + self._edge_set.add(edge) + self._vertex_set.update(edge.vertices) + + def edge_coloring( + self, seed: Optional[int] = 0, trials: int = 1 + ) -> "HypergraphEdgeColoring": + """Compute a (nondeterministic) greedy edge coloring of this hypergraph. + + Args: + seed: Optional random seed for reproducibility. + trials: Number of randomized trials to attempt. The best coloring + (fewest colors) is returned. 
+ + Returns: + A :class:`HypergraphEdgeColoring` for this hypergraph. + """ + all_edges = sorted(self.edges(), key=lambda edge: edge.vertices) + + if not all_edges: + return HypergraphEdgeColoring(self) + + num_trials = max(trials, 1) + best_coloring: Optional[HypergraphEdgeColoring] = None + least_colors: Optional[int] = None + + for trial in range(num_trials): + trial_seed = None if seed is None else seed + trial + rng = random.Random(trial_seed) + + edge_order = list(all_edges) + rng.shuffle(edge_order) + + coloring = HypergraphEdgeColoring(self) + num_colors = 0 + + for edge in edge_order: + if len(edge.vertices) == 1: + coloring.add_edge(edge, -1) + continue + + assigned = False + for color in range(num_colors): + used_vertices = set().union( + *( + candidate.vertices + for candidate in coloring.edges_of_color(color) + ) + ) + if not any(vertex in used_vertices for vertex in edge.vertices): + coloring.add_edge(edge, color) + assigned = True + break + + if not assigned: + coloring.add_edge(edge, num_colors) + num_colors += 1 + + if least_colors is None or coloring.ncolors < least_colors: + least_colors = coloring.ncolors + best_coloring = coloring + + assert best_coloring is not None + return best_coloring + + def __str__(self) -> str: + return f"Hypergraph with {self.nvertices} vertices and {self.nedges} edges." + + def __repr__(self) -> str: + return f"Hypergraph({list(self._edge_set)})" + + +class HypergraphEdgeColoring: + """Edge-color assignment for a :class:`Hypergraph`. + + This class stores colors separately from :class:`Hypergraph` and enforces + the rule that multi-vertex edges sharing a color do not share any vertices. + + Conventions: + + - Colors for nontrivial edges must be nonnegative integers. + - Single-vertex edges may use a special color (for example ``-1``). + - Only nonnegative colors contribute to :attr:`ncolors`. + + Note: + Colors are keyed by edge vertex tuples (``edge.vertices``), not by + ``Hyperedge`` object identity. 
As a result, :meth:`color` accepts edge + vertex tuples directly, while :meth:`add_edge` still requires an edge + instance that belongs to :attr:`hypergraph`. + + Attributes: + hypergraph: The supporting :class:`Hypergraph` whose edges can be + colored by this instance. + """ + + def __init__(self, hypergraph: Hypergraph) -> None: + self.hypergraph = hypergraph + self._colors: dict[tuple[int, ...], int] = {} # Vertices-to-color mapping + self._used_vertices: dict[int, set[int]] = ( + {} + ) # Set of vertices used by each color + + @property + def ncolors(self) -> int: + """Return the number of distinct nonnegative colors in the coloring.""" + return len(self._used_vertices) + + def color(self, vertices: tuple[int, ...]) -> Optional[int]: + """Return the color assigned to edge vertices. + + Args: + vertices: Canonical vertex tuple for the edge to query (typically + ``edge.vertices``). + + Returns: + The color assigned to ``vertices``, or ``None`` if the edge has + not been added to this coloring. + """ + if not isinstance(vertices, tuple) or not all( + isinstance(vertex, int) for vertex in vertices + ): + raise TypeError("vertices must be tuple[int, ...]") + return self._colors.get(vertices) + + def colors(self) -> Iterator[int]: + """Iterate over distinct nonnegative colors present in the coloring. + + Returns: + Iterator of distinct nonnegative color indices. + """ + return iter(self._used_vertices.keys()) + + def add_edge(self, edge: Hyperedge, color: int) -> None: + """Add ``edge`` to this coloring with the specified ``color``. + + For multi-vertex edges, this enforces that no previously added edge + with the same color shares a vertex with ``edge``. + + Args: + edge: The Hyperedge instance to add. This must be an edge present + in :attr:`hypergraph` (typically one returned by + ``hypergraph.edges()``). + color: Color index for the edge. + + Raises: + TypeError: If ``edge`` is not a :class:`Hyperedge`. + ValueError: If ``edge`` is not part of :attr:`hypergraph`. 
+            ValueError: If ``color`` is negative for a nontrivial edge.
+            RuntimeError: If adding ``edge`` would create a same-color vertex
+                conflict.
+        """
+        if not isinstance(edge, Hyperedge):
+            raise TypeError(f"edge must be Hyperedge, got {type(edge).__name__}")
+
+        if edge not in self.hypergraph.edges():
+            raise ValueError("edge must belong to the supporting Hypergraph")
+
+        vertices = edge.vertices
+
+        if len(vertices) == 1:
+            # Single-vertex edges can be colored with a special color (e.g., -1)
+            self._colors[vertices] = color
+        else:
+            if color < 0:
+                raise ValueError(
+                    "Color index must be nonnegative for multi-vertex edges."
+                )
+            if color not in self._used_vertices:
+                self._colors[vertices] = color
+                self._used_vertices[color] = set(vertices)
+            else:
+                if any(v in self._used_vertices[color] for v in vertices):
+                    raise RuntimeError(
+                        "Edge conflicts with existing edge of same color."
+                    )
+                self._colors[vertices] = color
+                self._used_vertices[color].update(vertices)
+
+        # Every branch above already records the color; no final write needed.
+
+    def edges_of_color(self, color: int) -> Iterator[Hyperedge]:
+        """Iterate over hyperedges with a specific color.
+
+        Args:
+            color: Color index for filtering edges.
+
+        Returns:
+            Iterator of edges currently assigned to ``color``.
+        """
+        return iter(
+            [
+                edge
+                for edge in self.hypergraph.edges()
+                if self._colors.get(edge.vertices) == color
+            ]
+        )
diff --git a/source/pip/qsharp/magnets/utilities/pauli.py b/source/pip/qsharp/magnets/utilities/pauli.py
new file mode 100644
index 0000000000..4eb7b92e5b
--- /dev/null
+++ b/source/pip/qsharp/magnets/utilities/pauli.py
@@ -0,0 +1,270 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Pauli operator representation for quantum spin systems."""
+
+from collections.abc import Sequence
+
+try:
+    import cirq
+except Exception as ex:
+    raise ImportError(
+        "qsharp.magnets.utilities requires the cirq extras. Install with 'pip install \"qsharp[cirq]\"'."
+ ) from ex + + +class Pauli: + """Single-qubit Pauli term tied to an explicit qubit index. + + ``Pauli`` stores a Pauli identifier and the qubit it acts on. The Pauli + identifier can be provided either as an integer code or a label: + + - ``0`` / ``"I"`` + - ``1`` / ``"X"`` + - ``2`` / ``"Z"`` + - ``3`` / ``"Y"`` + + Note: + The integer mapping follows the internal QDK convention where ``2`` is + ``Z`` and ``3`` is ``Y``. + + Example: + + .. code-block:: python + >>> p = Pauli("Y", qubit=2) + >>> p.op + 3 + >>> p.qubit + 2 + """ + + _VALID_INTS = {0, 1, 2, 3} + _STR_TO_INT = {"I": 0, "X": 1, "Z": 2, "Y": 3} + + def __init__(self, value: int | str, qubit: int = 0) -> None: + """Initialize a Pauli operator. + + Args: + value: An integer 0-3 or one of 'I', 'X', 'Y', 'Z' (case-insensitive). + qubit: The index of the qubit this operator acts on. Defaults to 0. + + Raises: + ValueError: If ``value`` is not a valid integer/string Pauli identifier. + """ + if isinstance(value, int): + if value not in self._VALID_INTS: + raise ValueError(f"Integer value must be 0-3, got {value}.") + self._op = value + elif isinstance(value, str): + key = value.upper() + if key not in self._STR_TO_INT: + raise ValueError( + f"String value must be one of 'I', 'X', 'Y', 'Z', got '{value}'." + ) + self._op = self._STR_TO_INT[key] + else: + raise ValueError(f"Expected int or str, got {type(value).__name__}.") + self.qubit: int = qubit + + @property + def op(self) -> int: + """Integer encoding of this Pauli term. + + Returns: + ``0`` for ``I``, ``1`` for ``X``, ``2`` for ``Z``, ``3`` for ``Y``. 
+ """ + return self._op + + def __str__(self) -> str: + labels = {0: "I", 1: "X", 2: "Z", 3: "Y"} + return f"{labels[self._op]}({self.qubit})" + + def __repr__(self) -> str: + labels = {0: "I", 1: "X", 2: "Z", 3: "Y"} + return f"Pauli('{labels[self._op]}', qubit={self.qubit})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Pauli): + return NotImplemented + return self._op == other._op and self.qubit == other.qubit + + def __hash__(self) -> int: + return hash((self._op, self.qubit)) + + @property + def cirq(self): + """Return this Pauli term as a Cirq gate operation on ``LineQubit``. + + Returns: + A Cirq operation equivalent to + ``cirq.{I|X|Z|Y}.on(cirq.LineQubit(self.qubit))``. + """ + _INT_TO_CIRQ = (cirq.I, cirq.X, cirq.Z, cirq.Y) + return _INT_TO_CIRQ[self._op].on(cirq.LineQubit(self.qubit)) + + +def PauliX(qubit: int) -> Pauli: + """Create a Pauli-X operator on the given qubit.""" + return Pauli("X", qubit) + + +def PauliY(qubit: int) -> Pauli: + """Create a Pauli-Y operator on the given qubit.""" + return Pauli("Y", qubit) + + +def PauliZ(qubit: int) -> Pauli: + """Create a Pauli-Z operator on the given qubit.""" + return Pauli("Z", qubit) + + +class PauliString: + """Ordered tensor product of single-qubit ``Pauli`` terms with a coefficient. + + ``PauliString`` stores: + + - an ordered tuple of :class:`Pauli` objects (including each term's qubit), and + - a complex scalar coefficient. + + Construction options: + + - pass a sequence of :class:`Pauli` objects to ``PauliString(...)`` + - use :meth:`from_qubits` to pair qubit indices with Pauli labels/codes + + Example: + + .. code-block:: python + >>> ps = PauliString([PauliX(0), PauliZ(1)], coefficient=-1j) + >>> ps.qubits + (0, 1) + >>> ps2 = PauliString.from_qubits((0, 1), "XZ", coefficient=-1j) + >>> ps == ps2 + True + """ + + def __init__(self, paulis: Sequence[Pauli], coefficient: complex = 1.0) -> None: + """Initialize a PauliString from a sequence of Pauli operators. 
+ + Args: + paulis: A sequence of :class:`Pauli` instances, each with its + own qubit index. + coefficient: Complex coefficient multiplying the Pauli string (default 1.0). + + Raises: + TypeError: If any element is not a Pauli instance. + """ + for p in paulis: + if not isinstance(p, Pauli): + raise TypeError( + f"Expected Pauli instance, got {type(p).__name__}. " + "Use PauliString.from_qubits() for int/str values." + ) + self._paulis: tuple[Pauli, ...] = tuple(paulis) + self._coefficient: complex = coefficient + + @classmethod + def from_qubits( + cls, + qubits: tuple[int, ...], + values: Sequence[int | str] | str, + coefficient: complex = 1.0, + ) -> "PauliString": + """Create a PauliString from qubit indices and Pauli labels. + + Args: + qubits: Tuple of qubit indices. + values: Sequence of Pauli identifiers (integers 0-3 or strings + 'I', 'X', 'Y', 'Z'). A plain string like ``"XZI"`` is also + accepted and treated as individual characters. + coefficient: Complex coefficient multiplying the Pauli string. + + Returns: + A new PauliString instance. + + Raises: + ValueError: If qubits and values have different lengths, or if + any value is not a valid Pauli identifier. + """ + if len(qubits) != len(values): + raise ValueError( + f"Length mismatch: {len(qubits)} qubits vs {len(values)} values." + ) + paulis = [Pauli(v, q) for q, v in zip(qubits, values)] + return cls(paulis, coefficient=coefficient) + + @property + def qubits(self) -> tuple[int, ...]: + """Tuple of qubit indices in the same order as the stored Pauli terms. + + Returns: + Tuple of qubit indices, one per Pauli operator. + """ + return tuple(p.qubit for p in self._paulis) + + @property + def coefficient(self) -> complex: + """Complex coefficient multiplying this Pauli string.""" + return self._coefficient + + @property + def paulis(self) -> str: + """String of Pauli labels in the same order as the stored Pauli terms. + + Returns: + String of Pauli labels ('I', 'X', 'Z', 'Y'), one per Pauli operator. 
+        """
+        labels = {0: "I", 1: "X", 2: "Z", 3: "Y"}
+        return "".join(labels[p.op] for p in self._paulis)
+
+    def __iter__(self):
+        """Iterate over Pauli terms in stored order.
+
+        Yields:
+            :class:`Pauli` instances in order.
+        """
+        return iter(self._paulis)
+
+    def __len__(self) -> int:
+        return len(self._paulis)
+
+    def __getitem__(self, index: int) -> Pauli:
+        return self._paulis[index]
+
+    def __mul__(self, scalar: complex) -> "PauliString":
+        """Scale the coefficient of this PauliString by a complex scalar."""
+        return PauliString(self._paulis, coefficient=self._coefficient * scalar)
+
+    def __str__(self) -> str:
+        # Each Pauli renders itself (e.g. "X(0)"); no label table is needed here.
+        s = "".join(map(str, self._paulis))
+        return f"{self._coefficient} * {s}"
+
+    def __repr__(self) -> str:
+        labels = {0: "I", 1: "X", 2: "Z", 3: "Y"}
+        s = "".join(labels[p.op] for p in self._paulis)
+        return f"PauliString(qubits={self.qubits}, ops='{s}', coefficient={self._coefficient})"
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, PauliString):
+            return NotImplemented
+        return self._paulis == other._paulis and self._coefficient == other._coefficient
+
+    def __hash__(self) -> int:
+        return hash((self._paulis, self._coefficient))
+
+    @property
+    def cirq(self):
+        """Return the corresponding Cirq ``PauliString``.
+
+        Constructs a ``cirq.PauliString`` by applying each single-qubit
+        Pauli to its corresponding ``cirq.LineQubit``.
+
+        Returns:
+            A ``cirq.PauliString`` on ``cirq.LineQubit`` instances with
+            ``self._coefficient`` as its coefficient.
+        """
+        _INT_TO_CIRQ = (cirq.I, cirq.X, cirq.Z, cirq.Y)
+        return cirq.PauliString(
+            {cirq.LineQubit(p.qubit): _INT_TO_CIRQ[p.op] for p in self._paulis},
+            coefficient=self._coefficient,
+        )
diff --git a/source/pip/qsharp/qre/__init__.py b/source/pip/qsharp/qre/__init__.py
new file mode 100644
index 0000000000..6ba945acf1
--- /dev/null
+++ b/source/pip/qsharp/qre/__init__.py
@@ -0,0 +1,85 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License. + +from ._application import Application +from ._architecture import Architecture +from ._estimation import estimate +from ._instruction import ( + LOGICAL, + PHYSICAL, + Encoding, + ISATransform, + constraint, + InstructionSource, +) +from ._isa_enumeration import ISAQuery, ISARefNode, ISA_ROOT +from ._qre import ( + ISA, + InstructionFrontier, + Constraint, + ConstraintBound, + EstimationResult, + FactoryResult, + ISARequirements, + Block, + Trace, + block_linear_function, + constant_function, + generic_function, + linear_function, + instruction_name, + property_name, + property_name_to_key, +) +from ._results import ( + EstimationTable, + EstimationTableColumn, + EstimationTableEntry, + plot_estimates, +) +from ._trace import LatticeSurgery, PSSPC, TraceQuery, TraceTransform + +# Extend Rust Python types with additional Python-side functionality +from ._instruction import _isa_as_frame, _requirements_as_frame + +ISA.as_frame = _isa_as_frame +ISARequirements.as_frame = _requirements_as_frame + +__all__ = [ + "block_linear_function", + "constant_function", + "constraint", + "estimate", + "linear_function", + "plot_estimates", + "Application", + "Architecture", + "Block", + "Constraint", + "ConstraintBound", + "Encoding", + "EstimationResult", + "EstimationTable", + "EstimationTableColumn", + "EstimationTableEntry", + "FactoryResult", + "generic_function", + "instruction_name", + "InstructionFrontier", + "InstructionSource", + "ISA", + "ISA_ROOT", + "ISAQuery", + "ISARefNode", + "ISARequirements", + "ISATransform", + "LatticeSurgery", + "PSSPC", + "property_name", + "property_name_to_key", + "Trace", + "TraceQuery", + "TraceTransform", + "LOGICAL", + "PHYSICAL", +] diff --git a/source/pip/qsharp/qre/_application.py b/source/pip/qsharp/qre/_application.py new file mode 100644 index 0000000000..6c20621b2b --- /dev/null +++ b/source/pip/qsharp/qre/_application.py @@ -0,0 +1,172 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT License. + +from __future__ import annotations + +import types +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from types import NoneType +from typing import ( + ClassVar, + Generic, + Protocol, + TypeVar, + Generator, + get_type_hints, + cast, +) + +from ._enumeration import _enumerate_instances +from ._qre import Trace, EstimationResult +from ._trace import TraceQuery + + +class DataclassProtocol(Protocol): + __dataclass_fields__: ClassVar[dict] + + +TraceParameters = TypeVar("TraceParameters", DataclassProtocol, types.NoneType) + + +class Application(ABC, Generic[TraceParameters]): + """ + An application defines a class of quantum computation problems along with a + method to generate traces for specific problem instances. + + We distinguish between application and trace parameters. The application + parameters define which particular instance of the application we want to + consider. The trace parameters define how to generate a trace. They change + the specific way in which we solve the problem, but not the problem itself. + + For example, in quantum cryptanalysis, the application parameters could + define the key size for an RSA prime product, while the trace parameters + define which algorithm to use to break the cryptography, as well as + parameters therein. + """ + + _parallel_traces: bool = True + + @abstractmethod + def get_trace(self, parameters: TraceParameters) -> Trace: + """Return the trace corresponding to this application and parameters. + + Args: + parameters (TraceParameters): The trace parameters. + + Returns: + Trace: The trace for this application instance and parameters. + """ + + @staticmethod + def q(**kwargs) -> TraceQuery: + """Create a trace query for this application. + + Args: + **kwargs: Domain overrides forwarded to trace parameter enumeration. + + Returns: + TraceQuery: A trace query for this application type. 
+ """ + return TraceQuery(NoneType, **kwargs) + + def context(self) -> _Context: + """Create a new enumeration context for this application.""" + return _Context(self) + + def post_process( + self, parameters: TraceParameters, estimation: EstimationResult + ) -> EstimationResult: + """Post-process an estimation result for a given set of trace parameters.""" + return estimation + + def enumerate_traces( + self, + **kwargs, + ) -> Generator[Trace, None, None]: + """Yield all traces of an application given its dataclass parameters. + + Args: + **kwargs: Domain overrides forwarded to ``_enumerate_instances``. + + Yields: + Trace: A trace for each enumerated set of trace parameters. + """ + + param_type = get_type_hints(self.__class__.get_trace).get("parameters") + if param_type is types.NoneType: + yield self.get_trace(None) # type: ignore + return + + if isinstance(param_type, TypeVar): + for c in param_type.__constraints__: + if c is not types.NoneType: + param_type = c + break + + if self._parallel_traces: + instances = list(_enumerate_instances(cast(type, param_type), **kwargs)) + with ThreadPoolExecutor() as executor: + for trace in executor.map(self.get_trace, instances): + yield trace + else: + for instances in _enumerate_instances(cast(type, param_type), **kwargs): + yield self.get_trace(instances) + + def enumerate_traces_with_parameters( + self, + **kwargs, + ) -> Generator[tuple[TraceParameters, Trace], None, None]: + """Yield (parameters, trace) pairs for an application. + + Like ``enumerate_traces``, but each yielded trace is accompanied by the + trace parameters that were used to generate it. + + Args: + **kwargs: Domain overrides forwarded to ``_enumerate_instances``. + + Yields: + tuple[TraceParameters, Trace]: A pair of trace parameters and + the corresponding trace. 
+ """ + + param_type = get_type_hints(self.__class__.get_trace).get("parameters") + if param_type is types.NoneType: + yield None, self.get_trace(None) # type: ignore + return + + if isinstance(param_type, TypeVar): + for c in param_type.__constraints__: + if c is not types.NoneType: + param_type = c + break + + if self._parallel_traces: + instances = list(_enumerate_instances(cast(type, param_type), **kwargs)) + with ThreadPoolExecutor() as executor: + for instance, trace in zip( + instances, executor.map(self.get_trace, instances) + ): + yield instance, trace + else: + for instance in _enumerate_instances(cast(type, param_type), **kwargs): + yield instance, self.get_trace(instance) + + def disable_parallel_traces(self): + """Disable parallel trace generation for this application.""" + self._parallel_traces = False + + +class _Context: + """Enumeration context wrapping an application instance.""" + + application: Application + + def __init__(self, application: Application, **kwargs): + """Initialize the context for the given application. + + Args: + application (Application): The application instance. + **kwargs: Additional keyword arguments (reserved for future use). + """ + self.application = application diff --git a/source/pip/qsharp/qre/_architecture.py b/source/pip/qsharp/qre/_architecture.py new file mode 100644 index 0000000000..cd8bb52e64 --- /dev/null +++ b/source/pip/qsharp/qre/_architecture.py @@ -0,0 +1,244 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations +import copy +from typing import cast, TYPE_CHECKING + +from abc import ABC, abstractmethod + +from ._qre import ( + ISA, + _ProvenanceGraph, + Instruction, + _IntFunction, + _FloatFunction, + constant_function, + property_name_to_key, +) + +if TYPE_CHECKING: + from typing import Optional + + from ._instruction import ISATransform, Encoding + + +class Architecture(ABC): + """Abstract base class for quantum hardware architectures.""" + + @abstractmethod + def provided_isa(self, ctx: ISAContext) -> ISA: + """ + Create the ISA provided by this architecture, adding instructions + directly to the context's provenance graph. + + Args: + ctx (ISAContext): The enumeration context whose provenance graph stores + the instructions. + + Returns: + ISA: The ISA backed by the context's provenance graph. + """ + ... + + def context(self) -> ISAContext: + """Create a new enumeration context for this architecture. + + Returns: + ISAContext: A new enumeration context. + """ + return ISAContext(self) + + +class ISAContext: + """ + Context passed through enumeration, holding shared state. + """ + + def __init__(self, arch: Architecture): + """Initialize the ISA context for the given architecture. + + Args: + arch (Architecture): The architecture providing the base ISA. + """ + self._provenance: _ProvenanceGraph = _ProvenanceGraph() + + # Let the architecture create instructions directly in the graph. 
+ self._isa = arch.provided_isa(self) + + self._bindings: dict[str, ISA] = {} + self._transforms: dict[int, Architecture | ISATransform] = {0: arch} + + def _with_binding(self, name: str, isa: ISA) -> ISAContext: + """Return a new context with an additional binding (internal use).""" + ctx = copy.copy(self) + ctx._bindings = {**self._bindings, name: isa} + return ctx + + @property + def isa(self) -> ISA: + """The ISA provided by the architecture for this context.""" + return self._isa + + def add_instruction( + self, + id_or_instruction: int | Instruction, + encoding: Encoding = 0, # type: ignore + *, + arity: Optional[int] = 1, + time: int | _IntFunction = 0, + space: Optional[int] | _IntFunction = None, + length: Optional[int | _IntFunction] = None, + error_rate: float | _FloatFunction = 0.0, + transform: ISATransform | None = None, + source: list[Instruction] | None = None, + **kwargs: int, + ) -> int: + """ + Create an instruction and add it to the provenance graph. + + Can be called in two ways: + + 1. With keyword args to create a new instruction:: + + ctx.add_instruction(T, encoding=LOGICAL, time=1000, + error_rate=1e-8) + + 2. With a pre-existing ``Instruction`` object (e.g. from + ``with_id()``):: + + ctx.add_instruction(existing_instruction) + + Provenance is recorded when *transform* and/or *source* are + supplied: + + - **transform** — the ``ISATransform`` that produced the + instruction. + - **source** — input instructions consumed by the transform. + + Args: + id_or_instruction: Either an instruction ID (int) for creating + a new instruction, or an existing ``Instruction`` object. + encoding: The instruction encoding (0 = Physical, 1 = Logical). + Ignored when passing an existing ``Instruction``. + arity: The instruction arity. ``None`` for variable arity. + Ignored when passing an existing ``Instruction``. + time: Instruction time in ns (or ``_IntFunction`` for variable + arity). Ignored when passing an existing ``Instruction``. 
+ space: Instruction space in physical qubits (or ``_IntFunction`` + for variable arity). Ignored when passing an existing + ``Instruction``. + length: Arity including ancilla qubits. Ignored when passing an + existing ``Instruction``. + error_rate: Instruction error rate (or ``_FloatFunction`` for + variable arity). Ignored when passing an existing + ``Instruction``. + transform: The ``ISATransform`` that produced the instruction. + source: List of source ``Instruction`` objects consumed by the + transform. + **kwargs: Additional properties (e.g. ``distance=9``). Ignored + when passing an existing ``Instruction``. + + Returns: + The node index in the provenance graph. + + Raises: + ValueError: If an unknown property name is provided in kwargs. + """ + if transform is None and source is None: + return self._provenance.add_instruction( + cast(int, id_or_instruction), + encoding, + arity=arity, + time=time, + space=space, + length=length, + error_rate=error_rate, + **kwargs, + ) + + if isinstance(id_or_instruction, Instruction): + instr = id_or_instruction + else: + instr = _make_instruction( + id_or_instruction, + int(encoding), + arity, + time, + space, + length, + error_rate, + kwargs, + ) + + transform_id = id(transform) if transform is not None else 0 + children = [inst.source for inst in source] if source else [] + + node_index = self._provenance.add_node(instr, transform_id, children) + + if transform is not None: + self._transforms[transform_id] = transform + + return node_index + + def make_isa(self, *node_indices: int) -> ISA: + """ + Create an ISA backed by this context's provenance graph from the + given node indices. + + Args: + *node_indices (int): Node indices in the provenance graph. + + Returns: + ISA: An ISA referencing the provenance graph. 
+ """ + return self._provenance.make_isa(list(node_indices)) + + +def _make_instruction( + id: int, + encoding: int, + arity: int | None, + time: int | _IntFunction, + space: int | _IntFunction | None, + length: int | _IntFunction | None, + error_rate: float | _FloatFunction, + properties: dict[str, int], +) -> Instruction: + """Build an ``Instruction`` from keyword arguments.""" + if arity is not None: + instr = Instruction.fixed_arity( + id, + encoding, + arity, + cast(int, time), + cast(int | None, space), + cast(int | None, length), + cast(float, error_rate), + ) + else: + if isinstance(time, int): + time = constant_function(time) + if isinstance(space, int): + space = constant_function(space) + if isinstance(length, int): + length = constant_function(length) + if isinstance(error_rate, (int, float)): + error_rate = constant_function(float(error_rate)) + + instr = Instruction.variable_arity( + id, + encoding, + time, + cast(_IntFunction, space), + error_rate, + length, + ) + + for key, value in properties.items(): + prop_key = property_name_to_key(key) + if prop_key is None: + raise ValueError(f"Unknown property '{key}'.") + instr.set_property(prop_key, value) + + return instr diff --git a/source/pip/qsharp/qre/_enumeration.py b/source/pip/qsharp/qre/_enumeration.py new file mode 100644 index 0000000000..b01d706944 --- /dev/null +++ b/source/pip/qsharp/qre/_enumeration.py @@ -0,0 +1,242 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import types +from typing import ( + Generator, + Type, + TypeVar, + Literal, + Union, + cast, + get_args, + get_origin, + get_type_hints, +) +from dataclasses import MISSING +from itertools import product +from enum import Enum + + +T = TypeVar("T") + + +def _is_union_type(tp) -> bool: + """Check if a type is a Union or Python 3.10+ union (X | Y).""" + return get_origin(tp) is Union or isinstance(tp, types.UnionType) + + +def _is_type_filter(val, union_members: tuple) -> bool: + """ + Check if *val* is a union member type or a list of union member types, + i.e. a type filter for a union field (as opposed to a fixed value or + instance domain). + """ + member_set = set(union_members) + if isinstance(val, type) and val in member_set: + return True + if isinstance(val, list) and all( + isinstance(v, type) and v in member_set for v in val + ): + return True + return False + + +def _is_union_constraint_dict(val) -> bool: + """ + Check if *val* is a dict whose keys are all types, i.e. a per-member + constraint mapping for a union field. + + Example: ``{OptionA: {"number": [2, 3]}, OptionB: {}}`` + """ + return isinstance(val, dict) and all(isinstance(k, type) for k in val) + + +def _enumerate_union_members( + union_members: tuple, + val=None, +) -> list: + """ + Enumerate instances for a union-typed field. + + *val* controls which members are enumerated and how: + + - ``None`` - enumerate all members with their default domains. + - A single type (e.g. ``OptionB``) - enumerate only that member. + - A list of types (e.g. ``[OptionA, OptionB]``) - enumerate those members. + - A dict mapping types to constraint dicts + (e.g. ``{OptionA: {"number": [2, 3]}, OptionB: {}}``) - + enumerate only the listed members, forwarding the constraint dicts. 
+ """ + # No override - enumerate all members with defaults + if val is None: + domain: list = [] + for member_type in union_members: + domain.extend(_enumerate_instances(member_type)) + return domain + + # Single type + if isinstance(val, type): + return list(_enumerate_instances(val)) + + # List of types + if isinstance(val, list) and all(isinstance(v, type) for v in val): + domain = [] + for member_type in val: + domain.extend(_enumerate_instances(member_type)) + return domain + + # Dict of type → constraint dict + if _is_union_constraint_dict(val): + domain = [] + for member_type, member_kwargs in cast(dict, val).items(): + domain.extend(_enumerate_instances(member_type, **member_kwargs)) + return domain + + raise ValueError( + f"Invalid value for union field: {val!r}. " + "Expected a union member type, a list of types, or a dict mapping " + "types to constraint dicts." + ) + + +def _enumerate_instances(cls: Type[T], **kwargs) -> Generator[T, None, None]: + """ + Yield all instances of a dataclass given its class. + + The enumeration logic supports defining domains for fields using the + ``domain`` metadata key. Additionally, boolean fields are automatically + enumerated with ``[True, False]``, Enum fields with all their members, + and Literal types with their defined values. + + **Nested dataclass fields** can be constrained by passing a dict:: + + _enumerate_instances(Outer, inner={"option": True}) + + **Union-typed fields** support several override forms: + + - A single type to select one member:: + + _enumerate_instances(Config, option=OptionB) + + - A list of types to select a subset:: + + _enumerate_instances(Config, option=[OptionA, OptionB]) + + - A dict mapping types to constraint dicts:: + + _enumerate_instances(Config, option={OptionA: {"number": [2, 3]}, OptionB: {}}) + + Args: + cls (Type[T]): The dataclass type to enumerate. + **kwargs: Fixed values or domains for fields. 
If a value is a list + and the corresponding field is kw_only, it is treated as a domain + to enumerate over. For nested dataclass fields a ``dict`` value + is forwarded as keyword arguments. For union-typed fields a type, + list of types, or ``dict[type, dict]`` controls member selection + and constraints. + + Returns: + Generator[T, None, None]: A generator yielding instances of the + dataclass. + + Raises: + ValueError: If a field cannot be enumerated (no domain found). + """ + + names = [] + values = [] + fixed_kwargs = {} + + if (fields := getattr(cls, "__dataclass_fields__", None)) is None: + # There are no fields defined for this class, so just yield a single + # instance + yield cls(**kwargs) + return + + # Resolve type hints to handle stringified types from __future__.annotations + type_hints = get_type_hints(cls) + + for field in fields.values(): # type: ignore + name = field.name + # Get resolved type or fallback to field.type + current_type = type_hints.get(name, field.type) + + if name in kwargs: + val = kwargs[name] + + is_union = _is_union_type(current_type) + union_members = get_args(current_type) if is_union else () + + # Union field with a type filter or constraint dict + if is_union and ( + _is_type_filter(val, union_members) or _is_union_constraint_dict(val) + ): + names.append(name) + values.append(_enumerate_union_members(union_members, val)) + continue + + # Nested dataclass field with a dict of constraints + if ( + isinstance(val, dict) + and not is_union + and isinstance(current_type, type) + and hasattr(current_type, "__dataclass_fields__") + ): + names.append(name) + values.append(list(_enumerate_instances(current_type, **val))) + continue + + # If kw_only and list, it's a domain to enumerate + if field.kw_only and isinstance(val, list): + names.append(name) + values.append(val) + else: + # Otherwise, it's a fixed value + fixed_kwargs[name] = val + continue + + if not field.kw_only: + # We don't enumerate non-kw-only fields that aren't in 
kwargs + continue + + # Derived domain logic + names.append(name) + + domain = field.metadata.get("domain", None) + if domain is not None: + values.append(domain) + continue + + if current_type is bool: + values.append([True, False]) + continue + + if isinstance(current_type, type) and issubclass(current_type, Enum): + values.append(list(current_type)) + continue + + if get_origin(current_type) is Literal: + values.append(list(get_args(current_type))) + continue + + # Union types (e.g., OptionA | OptionB or Union[OptionA, OptionB]) + if _is_union_type(current_type): + values.append(_enumerate_union_members(get_args(current_type), None)) + continue + + # Nested dataclass types + if isinstance(current_type, type) and hasattr( + current_type, "__dataclass_fields__" + ): + values.append(list(_enumerate_instances(current_type))) + continue + + if field.default is not MISSING: + values.append([field.default]) + continue + + raise ValueError(f"Cannot enumerate field {name}.") + + for instance_values in product(*values): + yield cls(**fixed_kwargs, **dict(zip(names, instance_values))) diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py new file mode 100644 index 0000000000..174f0cee90 --- /dev/null +++ b/source/pip/qsharp/qre/_estimation.py @@ -0,0 +1,216 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 

from __future__ import annotations

from typing import cast, Optional, Any


from ._application import Application
from ._architecture import Architecture
from ._qre import (
    _estimate_parallel,
    _estimate_with_graph,
    _EstimationCollection,
    Trace,
)
from ._trace import TraceQuery, PSSPC, LatticeSurgery
from ._isa_enumeration import ISAQuery
from ._results import EstimationTable, EstimationTableEntry


def estimate(
    application: Application,
    architecture: Architecture,
    isa_query: ISAQuery,
    trace_query: Optional[TraceQuery] = None,
    *,
    max_error: float = 1.0,
    post_process: bool = False,
    use_graph: bool = True,
    name: Optional[str] = None,
) -> EstimationTable:
    """
    Estimate the resource requirements for a given application instance and
    architecture.

    The application instance might return multiple traces. Each of the traces
    is transformed by the trace query, which applies several trace transforms in
    sequence. Each transform may return multiple traces. Similarly, the
    architecture's ISA is transformed by the ISA query, which applies several
    ISA transforms in sequence, each of which may return multiple ISAs. The
    estimation is performed for each combination of transformed trace and ISA.
    The results are collected into an EstimationTable and returned.

    The collection only contains the results that are optimal with respect to
    the total number of qubits and the total runtime.

    Note:
        The pruning strategy used when ``use_graph`` is set to True (default)
        filters ISA instructions by comparing their per-instruction space, time,
        and error independently. However, the total qubit count of a result
        depends on the interaction between factory space and runtime:
        ``factory_qubits = copies × factory_space`` where copies are determined
        by ``count.div_ceil(runtime / factory_time)``. Because of this, an ISA
        instruction that is dominated on per-instruction metrics can still
        contribute to a globally Pareto-optimal result (e.g., a factory with
        higher time may need fewer copies, leading to fewer total qubits). As a
        consequence, ``use_graph=True`` may miss some results that
        ``use_graph=False`` would find. Use ``use_graph=False`` when completeness of
        the Pareto frontier is required.

    Args:
        application (Application): The quantum application to be estimated.
        architecture (Architecture): The target quantum architecture.
        isa_query (ISAQuery): The ISA query to enumerate ISAs from the architecture.
        trace_query (TraceQuery): The trace query to enumerate traces from the
            application.
        max_error (float): The maximum allowed error for the estimation results.
        post_process (bool): If True, use the Python-threaded estimation path
            (intended for future post-processing logic). If False (default),
            use the Rust parallel estimation path.
        use_graph (bool): If True (default), use the Rust estimation path that
            builds a graph of ISAs and prunes suboptimal ISAs during estimation.
            If False, use the Rust estimation path that does not perform any
            pruning and simply enumerates all ISAs for each trace.
        name (Optional[str]): An optional name for the estimation. If given, this
            will be added as a first column to the results table for all entries.

    Returns:
        EstimationTable: A table containing the optimal estimation results.
    """

    app_ctx = application.context()
    arch_ctx = architecture.context()

    # Default pipeline: PSSPC trace transform composed with lattice surgery.
    if trace_query is None:
        trace_query = PSSPC.q() * LatticeSurgery.q()

    if post_process:
        # Enumerate traces with their parameters so we can post-process later
        params_and_traces = cast(
            list[tuple[Any, Trace]],
            list(trace_query.enumerate(app_ctx, track_parameters=True)),
        )
        num_traces = len(params_and_traces)

        # Phase 1: Run all estimates in Rust (parallel, fast).
        traces_only = [trace for _, trace in params_and_traces]

        if use_graph:
            isa_query.populate(arch_ctx)
            arch_ctx._provenance.build_pareto_index()

            num_isas = arch_ctx._provenance.total_isa_count()

            collection = _estimate_with_graph(
                cast(list[Trace], traces_only), arch_ctx._provenance, max_error, True
            )
            # The graph path materializes the ISAs it considered; needed for
            # the per-trace re-estimation in Phases 2 and 4 below.
            isas = collection.isas
        else:
            isas = list(isa_query.enumerate(arch_ctx))

            num_isas = len(isas)

            collection = _estimate_parallel(
                cast(list[Trace], traces_only), isas, max_error, True
            )

        total_jobs = collection.total_jobs
        successful = collection.successful_estimates
        summaries = collection.all_summaries  # (trace_idx, isa_idx, qubits, runtime)

        # Phase 2: Learn per-trace runtime multiplier and qubit multiplier from
        # one sample each: if post_process changes runtime or qubit count it
        # will affect the Pareto optimality, but the changes depend only on the
        # trace, not on the ISA.
        trace_multipliers: dict[int, tuple[float, float]] = {}
        trace_sample_isa: dict[int, int] = {}
        for t_idx, isa_idx, _q, r in summaries:
            # Keep the first ISA seen for each trace as its sample.
            if t_idx not in trace_sample_isa:
                trace_sample_isa[t_idx] = isa_idx
        for t_idx, isa_idx in trace_sample_isa.items():
            params, trace = params_and_traces[t_idx]
            sample = trace.estimate(isas[isa_idx], max_error)
            if sample is not None:
                pre_q = sample.qubits
                pre_r = sample.runtime
                pp = app_ctx.application.post_process(params, sample)
                # Guard against division by zero on degenerate samples.
                if pp is not None and pre_r > 0 and pre_q > 0:
                    trace_multipliers[t_idx] = (pp.qubits / pre_q, pp.runtime / pre_r)

        # Phase 3: Estimate post-pp values and filter to Pareto candidates.
        estimated_pp: list[tuple[int, int, int, int]] = (
            []
        )  # (t_idx, isa_idx, est_q, est_r)
        for t_idx, isa_idx, q, r in summaries:
            # Traces without a learned multiplier keep their raw values.
            mult_q, mult_r = trace_multipliers.get(t_idx, (0.0, 0.0))
            est_q = int(q * mult_q) if mult_q > 0 else q
            est_r = int(r * mult_r) if mult_r > 0 else r
            estimated_pp.append((t_idx, isa_idx, est_q, est_r))

        # Build approximate post-pp Pareto frontier to identify candidates.
        estimated_pp.sort(key=lambda x: (x[2], x[3]))  # sort by qubits, then runtime
        approx_pareto: list[tuple[int, int, int, int]] = []
        min_r = float("inf")
        for item in estimated_pp:
            if item[3] < min_r:
                approx_pareto.append(item)
                min_r = item[3]

        # Phase 4: Re-estimate and post-process only the Pareto candidates.
        pp_collection = _EstimationCollection()
        for t_idx, isa_idx, _q, _r in approx_pareto:
            params, trace = params_and_traces[t_idx]
            result = trace.estimate(isas[isa_idx], max_error)
            if result is not None:
                pp_result = app_ctx.application.post_process(params, result)
                if pp_result is not None:
                    pp_collection.insert(pp_result)
        collection = pp_collection
    else:
        traces = list(trace_query.enumerate(app_ctx))
        num_traces = len(traces)

        if use_graph:
            isa_query.populate(arch_ctx)
            arch_ctx._provenance.build_pareto_index()

            num_isas = arch_ctx._provenance.total_isa_count()

            collection = _estimate_with_graph(
                cast(list[Trace], traces), arch_ctx._provenance, max_error, False
            )
        else:
            isas = list(isa_query.enumerate(arch_ctx))

            num_isas = len(isas)

            # Use the Rust parallel estimation path
            collection = _estimate_parallel(
                cast(list[Trace], traces), isas, max_error, False
            )

        total_jobs = collection.total_jobs
        successful = collection.successful_estimates

    # Post-process the results and add them to a results table
    table = EstimationTable()

    table.name = name

    if name is not None:
        table.insert_column(0, "name", lambda entry: name)

    table.extend(
        EstimationTableEntry.from_result(result, arch_ctx) for result in collection
    )

    # Fill in the stats for this estimation run
    table.stats.num_traces = num_traces
    table.stats.num_isas = num_isas
    table.stats.total_jobs = total_jobs
    table.stats.successful_estimates = successful
    table.stats.pareto_results = len(collection)

    return table
diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py
new file mode 100644
index 0000000000..e48bcecd43
--- /dev/null
+++ b/source/pip/qsharp/qre/_instruction.py
@@ -0,0 +1,463 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Generator, Iterable, Optional
from enum import IntEnum

import pandas as pd

from ._architecture import ISAContext, Architecture
from ._enumeration import _enumerate_instances
from ._isa_enumeration import (
    ISA_ROOT,
    _BindingNode,
    _ComponentQuery,
    ISAQuery,
)
from ._qre import (
    ISA,
    Constraint,
    ConstraintBound,
    Instruction,
    ISARequirements,
    instruction_name,
    property_name_to_key,
)


class Encoding(IntEnum):
    """Instruction encoding level; values match the Rust-side convention."""

    PHYSICAL = 0
    LOGICAL = 1


# Convenience module-level aliases for the two encoding levels.
PHYSICAL = Encoding.PHYSICAL
LOGICAL = Encoding.LOGICAL


def constraint(
    id: int,
    encoding: Encoding = PHYSICAL,
    *,
    arity: Optional[int] = 1,
    error_rate: Optional[ConstraintBound] = None,
    **kwargs: bool,
) -> Constraint:
    """
    Create an instruction constraint.

    Args:
        id (int): The instruction ID.
        encoding (Encoding): The instruction encoding. PHYSICAL (0) or LOGICAL (1).
        arity (Optional[int]): The instruction arity. If None, instruction is
            assumed to have variable arity. Default is 1.
        error_rate (Optional[ConstraintBound]): The constraint on the error rate.
        **kwargs (bool): Required properties that matching instructions must have.
            Valid property names: distance. Set to True to require the property.

    Returns:
        Constraint: The instruction constraint.

    Raises:
        ValueError: If an unknown property name is provided in kwargs.
    """
    c = Constraint(id, encoding, arity, error_rate)

    for key, value in kwargs.items():
        # Only truthy kwargs add a required property; falsy values are ignored.
        if value:
            if (prop_key := property_name_to_key(key)) is None:
                raise ValueError(f"Unknown property '{key}'")

            c.add_property(prop_key)

    return c


class ISATransform(ABC):
    """
    Abstract base class for transformations between ISAs (e.g., QEC schemes).

    An ISA transform defines a mapping from a required input ISA (e.g.,
    architecture constraints) to a provided output ISA (logical instructions).
    It supports enumeration of configuration parameters.
    """

    @staticmethod
    @abstractmethod
    def required_isa() -> ISARequirements:
        """
        Return the requirements that an implementation ISA must satisfy.

        Returns:
            ISARequirements: The requirements for the underlying ISA.
        """
        ...

    @abstractmethod
    def provided_isa(
        self, impl_isa: ISA, ctx: ISAContext
    ) -> Generator[ISA, None, None]:
        """
        Yields ISAs provided by this transform given an implementation ISA.

        Args:
            impl_isa (ISA): The implementation ISA that satisfies requirements.
            ctx (ISAContext): The enumeration context whose provenance graph
                stores the instructions.

        Yields:
            ISA: A provided logical ISA.
        """
        ...

    @classmethod
    def enumerate_isas(
        cls,
        impl_isa: ISA | Iterable[ISA],
        ctx: ISAContext,
        **kwargs,
    ) -> Generator[ISA, None, None]:
        """
        Enumerate all valid ISAs for this transform given implementation ISAs.

        This method iterates over all instances of the transform class (enumerating
        hyperparameters) and filters implementation ISAs against requirements.

        Args:
            impl_isa (ISA | Iterable[ISA]): One or more implementation ISAs.
            ctx (ISAContext): The enumeration context.
            **kwargs: Arguments passed to parameter enumeration.

        Yields:
            ISA: Valid provided ISAs.
        """
        # Normalize the single-ISA case to a one-element list.
        isas = [impl_isa] if isinstance(impl_isa, ISA) else impl_isa
        for isa in isas:
            if not isa.satisfies(cls.required_isa()):
                continue

            for component in _enumerate_instances(cls, **kwargs):
                # Register the transform instance so provenance lookups by
                # id(component) resolve later (see ISAContext._transforms).
                ctx._transforms[id(component)] = component
                yield from component.provided_isa(isa, ctx)

    @classmethod
    def q(cls, *, source: ISAQuery | None = None, **kwargs) -> ISAQuery:
        """
        Create an ISAQuery node for this transform.

        Args:
            source (Node | None): The source node providing implementation ISAs.
                Defaults to ISA_ROOT.
            **kwargs: Additional arguments for parameter enumeration.

        Returns:
            ISAQuery: An enumeration node representing this transform.
        """
        return _ComponentQuery(
            cls, source=source if source is not None else ISA_ROOT, kwargs=kwargs
        )

    @classmethod
    def bind(cls, name: str, node: ISAQuery) -> _BindingNode:
        """
        Create a BindingNode for this transform.

        This is a convenience method equivalent to ``cls.q().bind(name, node)``.

        Args:
            name (str): The name to bind the transform's output to.
            node (Node): The child node that can reference this binding.

        Returns:
            BindingNode: A binding node enclosing this transform.
        """
        return cls.q().bind(name, node)


@dataclass(slots=True)
class InstructionSource:
    """Graph connecting instructions to the transforms that generated them."""

    # Flat node storage; node indices index into this list.
    nodes: list[_InstructionSourceNode] = field(default_factory=list, init=False)
    # Indices of the root nodes (one per top-level ISA instruction).
    roots: list[int] = field(default_factory=list, init=False)

    @classmethod
    def from_isa(cls, ctx: ISAContext, isa: ISA) -> InstructionSource:
        """
        Construct an InstructionSource graph from an ISA.

        The instruction source graph contains more information than the
        provenance graph in the context, as it connects the instructions to the
        transforms and architectures that generated them.

        Args:
            ctx (ISAContext): The enumeration context containing the provenance graph.
            isa (ISA): Instructions in the ISA will serve as root nodes in the source graph.

        Returns:
            InstructionSource: The instruction source graph for the estimation result.
        """

        def _make_node(
            graph: InstructionSource, source_table: dict[int, int], source: int
        ) -> int:
            # Memoize on the provenance node index so shared subtrees are
            # created only once.
            if source in source_table:
                return source_table[source]

            children = [
                _make_node(graph, source_table, child)
                for child in ctx._provenance.children(source)
                if child != 0
            ]

            node = graph.add_node(
                ctx._provenance.instruction(source),
                ctx._transforms.get(ctx._provenance.transform_id(source)),
                children,
            )

            source_table[source] = node
            return node

        graph = cls()
        source_table: dict[int, int] = {}

        for inst in isa:
            node_idx = isa.node_index(inst.id)
            # Index 0 is the provenance graph's sentinel root; skip it.
            if node_idx is not None and node_idx != 0:
                node = _make_node(graph, source_table, node_idx)
                graph.add_root(node)

        return graph

    def add_root(self, node_id: int) -> None:
        """Add a root node to the instruction source graph.

        Args:
            node_id (int): The index of the node to add as a root.
        """
        self.roots.append(node_id)

    def add_node(
        self,
        instruction: Instruction,
        transform: Optional[ISATransform | Architecture],
        children: list[int],
    ) -> int:
        """Add a node to the instruction source graph.

        Args:
            instruction (Instruction): The instruction for this node.
            transform (Optional[ISATransform | Architecture]): The transform
                that produced the instruction.
            children (list[int]): Indices of child nodes.

        Returns:
            int: The index of the newly added node.
        """
        node_id = len(self.nodes)
        self.nodes.append(_InstructionSourceNode(instruction, transform, children))
        return node_id

    def __str__(self) -> str:
        """Return a formatted string representation of the instruction source graph."""

        def _format_node(node: _InstructionSourceNode, indent: int = 0) -> str:
            result = " " * indent + f"{instruction_name(node.instruction.id) or '??'}"
            if node.transform is not None:
                result += f" @ {node.transform}"
            for child_index in node.children:
                result += "\n" + _format_node(self.nodes[child_index], indent + 2)
            return result

        return "\n".join(
            _format_node(self.nodes[root_index]) for root_index in self.roots
        )

    def __getitem__(self, id: int) -> _InstructionSourceNodeReference:
        """
        Retrieve the first instruction source root node with the given
        instruction ID. Raises KeyError if no such node exists.

        Args:
            id (int): The instruction ID to search for.

        Returns:
            _InstructionSourceNodeReference: The first instruction source node with the
                given instruction ID.
        """
        if (node := self.get(id)) is not None:
            return node

        raise KeyError(f"Instruction ID {id} not found in instruction source graph.")

    def __contains__(self, id: int) -> bool:
        """
        Check if there is an instruction source root node with the given
        instruction ID.

        Args:
            id (int): The instruction ID to search for.

        Returns:
            bool: True if a node with the given instruction ID exists, False otherwise.
        """
        for root in self.roots:
            if self.nodes[root].instruction.id == id:
                return True

        return False

    def get(
        self, id: int, default: Optional[_InstructionSourceNodeReference] = None
    ) -> Optional[_InstructionSourceNodeReference]:
        """
        Retrieve the first instruction source root node with the given
        instruction ID. Returns default if no such node exists.

        Args:
            id (int): The instruction ID to search for.
            default (Optional[_InstructionSourceNodeReference]): The value to return if no
                node with the given ID is found. Default is None.

        Returns:
            Optional[_InstructionSourceNodeReference]: The first instruction source node with the
                given instruction ID, or default if no such node exists.
        """
        for root in self.roots:
            if self.nodes[root].instruction.id == id:
                return _InstructionSourceNodeReference(self, root)

        return default


@dataclass(frozen=True, slots=True)
class _InstructionSourceNode:
    """A node in the instruction source graph."""

    instruction: Instruction
    transform: Optional[ISATransform | Architecture]
    children: list[int]


class _InstructionSourceNodeReference:
    """Reference to a node in an InstructionSource graph."""

    def __init__(self, graph: InstructionSource, node_id: int):
        """Initialize a reference to a node in the instruction source graph.

        Args:
            graph (InstructionSource): The owning instruction source graph.
            node_id (int): The index of the referenced node.
        """
        self.graph = graph
        self.node_id = node_id

    @property
    def instruction(self) -> Instruction:
        """The instruction at this node."""
        return self.graph.nodes[self.node_id].instruction

    @property
    def transform(self) -> Optional[ISATransform | Architecture]:
        """The transform that produced this node's instruction, if any."""
        return self.graph.nodes[self.node_id].transform

    def __str__(self) -> str:
        """Return a string representation of the referenced node."""
        return str(self.graph.nodes[self.node_id])

    def __getitem__(self, id: int) -> _InstructionSourceNodeReference:
        """
        Retrieve the first child instruction source node with the given
        instruction ID. Raises KeyError if no such node exists.

        Args:
            id (int): The instruction ID to search for.

        Returns:
            _InstructionSourceNodeReference: The first child instruction source node with the
                given instruction ID.
        """
        if (node := self.get(id)) is not None:
            return node

        raise KeyError(
            f"Instruction ID {id} not found in children of instruction {instruction_name(self.instruction.id) or '??'}."
        )

    def get(
        self, id: int, default: Optional[_InstructionSourceNodeReference] = None
    ) -> Optional[_InstructionSourceNodeReference]:
        """
        Retrieve the first child instruction source node with the given
        instruction ID. Returns default if no such node exists.

        Args:
            id (int): The instruction ID to search for.
            default (Optional[_InstructionSourceNodeReference]): The value to return if no
                node with the given ID is found. Default is None.

        Returns:
            Optional[_InstructionSourceNodeReference]: The first child instruction source
                node with the given instruction ID, or default if no such node
                exists.
        """

        for child_id in self.graph.nodes[self.node_id].children:
            if self.graph.nodes[child_id].instruction.id == id:
                return _InstructionSourceNodeReference(self.graph, child_id)

        return default


def _isa_as_frame(self: ISA) -> pd.DataFrame:
    """Convert an ISA to a pandas DataFrame.

    Args:
        self (ISA): The ISA to convert.

    Returns:
        pd.DataFrame: A DataFrame with columns for id, encoding, arity,
            space, time, and error.
    """
    data = {
        "id": [instruction_name(inst.id) for inst in self],
        "encoding": [Encoding(inst.encoding).name for inst in self],
        "arity": [inst.arity for inst in self],
        # Variable-arity instructions (arity is None) have no fixed metrics.
        "space": [
            inst.expect_space() if inst.arity is not None else None for inst in self
        ],
        "time": [
            inst.expect_time() if inst.arity is not None else None for inst in self
        ],
        "error": [
            inst.expect_error_rate() if inst.arity is not None else None
            for inst in self
        ],
    }

    df = pd.DataFrame(data)
    df.set_index("id", inplace=True)
    return df


def _requirements_as_frame(self: ISARequirements) -> pd.DataFrame:
    """Convert ISA requirements to a pandas DataFrame.

    Args:
        self (ISARequirements): The requirements to convert.

    Returns:
        pd.DataFrame: A DataFrame with columns for id, encoding, and arity.
    """
    data = {
        "id": [instruction_name(inst.id) for inst in self],
        "encoding": [Encoding(inst.encoding).name for inst in self],
        "arity": [inst.arity for inst in self],
    }

    df = pd.DataFrame(data)
    df.set_index("id", inplace=True)
    return df
diff --git a/source/pip/qsharp/qre/_isa_enumeration.py b/source/pip/qsharp/qre/_isa_enumeration.py
new file mode 100644
index 0000000000..7543c071ed
--- /dev/null
+++ b/source/pip/qsharp/qre/_isa_enumeration.py
@@ -0,0 +1,428 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

import functools
import itertools
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Generator

from ._architecture import ISAContext
from ._enumeration import _enumerate_instances
from ._qre import ISA


class ISAQuery(ABC):
    """
    Abstract base class for all nodes in the ISA enumeration tree.

    Enumeration nodes define the structure of the search space for ISAs starting
    from architectures and modified by ISA transforms such as error correction
    schemes. They can be composed using operators like ``+`` (sum) and ``*``
    (product) to build complex enumeration strategies.
    """

    @abstractmethod
    def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]:
        """
        Yields all ISA instances represented by this enumeration node.

        Args:
            ctx (ISAContext): The enumeration context containing shared state,
                e.g., access to the underlying architecture.

        Yields:
            ISA: A possible ISA that can be generated from this node.
        """
        pass

    def populate(self, ctx: ISAContext) -> int:
        """
        Populate the provenance graph with instructions from this node.

        Unlike ``enumerate``, this does not yield ISA objects.
Each transform
+        queries the graph for Pareto-optimal instructions matching its
+        requirements, and adds produced instructions directly to the graph.
+
+        Args:
+            ctx (ISAContext): The enumeration context whose provenance graph
+                will be populated.
+
+        Returns:
+            int: The starting node index of the instructions contributed by
+                this subtree. Used by consumers to scope graph queries to only
+                see their source's nodes.
+        """
+        # Default implementation: consume enumerate for its side effects
+        start = ctx._provenance.raw_node_count()
+        for _ in self.enumerate(ctx):
+            pass
+        return start
+
+    def __add__(self, other: ISAQuery) -> _SumNode:
+        """
+        Perform a union of two enumeration nodes.
+
+        Enumerating the sum node yields all ISAs from this node, followed by all
+        ISAs from the other node. Duplicate ISAs may be produced if both nodes
+        yield the same ISA.
+
+        Args:
+            other (ISAQuery): The other enumeration node.
+
+        Returns:
+            _SumNode: A node representing the union of both enumerations.
+
+        Example:
+
+            The following enumerates ISAs from both SurfaceCode and ColorCode:
+
+            .. code-block:: python
+                for isa in SurfaceCode.q() + ColorCode.q():
+                    ...
+        """
+        if isinstance(self, _SumNode) and isinstance(other, _SumNode):
+            sources = self.sources + other.sources
+            return _SumNode(sources)
+        elif isinstance(self, _SumNode):
+            sources = self.sources + [other]
+            return _SumNode(sources)
+        elif isinstance(other, _SumNode):
+            sources = [self] + other.sources
+            return _SumNode(sources)
+        else:
+            return _SumNode([self, other])
+
+    def __mul__(self, other: ISAQuery) -> _ProductNode:
+        """
+        Perform the cross product of two enumeration nodes.
+
+        Enumerating the product node yields ISAs resulting from the Cartesian
+        product of ISAs from both nodes. The ISAs are combined using
+        concatenation (logical union). This means that instructions in the
+        other enumeration node with the same ID as an instruction in this
+        enumeration node will overwrite the instruction from this node.
+
+        Args:
+            other (ISAQuery): The other enumeration node.
+
+        Returns:
+            _ProductNode: A node representing the product of both enumerations.
+
+        Example:
+
+            The following enumerates ISAs formed by combining ISAs from a
+            surface code and a factory:
+
+            .. code-block:: python
+
+                for isa in SurfaceCode.q() * Factory.q():
+                    ...
+        """
+        if isinstance(self, _ProductNode) and isinstance(other, _ProductNode):
+            sources = self.sources + other.sources
+            return _ProductNode(sources)
+        elif isinstance(self, _ProductNode):
+            sources = self.sources + [other]
+            return _ProductNode(sources)
+        elif isinstance(other, _ProductNode):
+            sources = [self] + other.sources
+            return _ProductNode(sources)
+        else:
+            return _ProductNode([self, other])
+
+    def bind(self, name: str, node: ISAQuery) -> "_BindingNode":
+        """Create a _BindingNode with this node as the component.
+
+        Args:
+            name: The name to bind the component to.
+            node: The child enumeration node that may contain ISARefNodes.
+
+        Returns:
+            A _BindingNode with self as the component.
+
+        Example:
+
+            .. code-block:: python
+                ExampleErrorCorrection.q().bind("c", ISARefNode("c") * ISARefNode("c"))
+        """
+        return _BindingNode(name=name, component=self, node=node)
+
+
+@dataclass
+class RootNode(ISAQuery):
+    """
+    Represents the architecture's base ISA.
+    Reads from the context instead of holding a reference.
+    """
+
+    def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]:
+        """
+        Yields the architecture ISA from the context.
+
+        Args:
+            ctx (ISAContext): The enumeration context.
+
+        Yields:
+            ISA: The architecture's provided ISA, called root.
+        """
+        yield ctx._isa
+
+    def populate(self, ctx: ISAContext) -> int:
+        """Architecture ISA is already in the graph from ``ISAContext.__init__``.
+
+        Returns:
+            int: 1, since architecture nodes start at index 1.
+ """ + return 1 + + +# Singleton instance for convenience +ISA_ROOT = RootNode() + + +@dataclass +class _ComponentQuery(ISAQuery): + """ + Query node that enumerates ISAs based on a component type and source. + + This node takes a component type (which must have an ``enumerate_isas`` class + method) and a source node. It enumerates the source node to get base ISAs, + and then calls ``enumerate_isas`` on the component type for each base ISA + to generate derived ISAs. + + Attributes: + component: The component type to query (e.g., a QEC code class). + source: The source node providing input ISAs (default: ISA_ROOT). + kwargs: Additional keyword arguments passed to ``enumerate_isas``. + """ + + component: type + source: ISAQuery = field(default_factory=lambda: ISA_ROOT) + kwargs: dict = field(default_factory=dict) + + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: + """ + Yields ISAs generated by the component from source ISAs. + + Args: + ctx (Context): The enumeration context. + + Yields: + ISA: A generated ISA instance. + """ + for isa in self.source.enumerate(ctx): + yield from self.component.enumerate_isas(isa, ctx, **self.kwargs) + + def populate(self, ctx: ISAContext) -> int: + """ + Populate the graph by querying matching instructions. + + Runs the source first to ensure dependency instructions are in + the graph, then queries the graph for all instructions matching + this component's requirements within the source's node range. + For each matching ISA × each hyperparameter instance, calls + ``provided_isa`` to add new instructions to the graph. + + Returns: + int: The starting node index of this component's own additions. 
+ """ + source_start = self.source.populate(ctx) + impl_isas = ctx._provenance.query_satisfying( + self.component.required_isa(), min_node_idx=source_start + ) + own_start = ctx._provenance.raw_node_count() + for instance in _enumerate_instances(self.component, **self.kwargs): + ctx._transforms[id(instance)] = instance + for impl_isa in impl_isas: + for _ in instance.provided_isa(impl_isa, ctx): + pass + return own_start + + +@dataclass +class _ProductNode(ISAQuery): + """ + Node representing the Cartesian product of multiple source nodes. + + Attributes: + sources: A list of source nodes to combine. + """ + + sources: list[ISAQuery] + + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: + """ + Yields ISAs formed by combining ISAs from all source nodes. + + Args: + ctx (Context): The enumeration context. + + Yields: + ISA: A combined ISA instance. + """ + source_generators = [source.enumerate(ctx) for source in self.sources] + yield from ( + functools.reduce(lambda a, b: a + b, isa_tuple) + for isa_tuple in itertools.product(*source_generators) + ) + + def populate(self, ctx: ISAContext) -> int: + """Populate the graph from each source sequentially (no cross product). + + Returns: + int: The starting node index before any source populated. + """ + first = ctx._provenance.raw_node_count() + for source in self.sources: + source.populate(ctx) + return first + + +@dataclass +class _SumNode(ISAQuery): + """ + Node representing the union of multiple source nodes. + + Attributes: + sources: A list of source nodes to enumerate sequentially. + """ + + sources: list[ISAQuery] + + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: + """ + Yields ISAs from each source node in sequence. + + Args: + ctx (Context): The enumeration context. + + Yields: + ISA: An ISA instance from one of the sources. 
+ """ + for source in self.sources: + yield from source.enumerate(ctx) + + def populate(self, ctx: ISAContext) -> int: + """Populate the graph from each source sequentially. + + Returns: + int: The starting node index before any source populated. + """ + first = ctx._provenance.raw_node_count() + for source in self.sources: + source.populate(ctx) + return first + + +@dataclass +class ISARefNode(ISAQuery): + """ + A reference to a bound ISA in the enumeration context. + + This node looks up the binding from the context and yields the bound ISA. + + Args: + name: The name of the bound ISA to reference. + """ + + name: str + + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: + """ + Yields the bound ISA from the context. + + Args: + ctx (Context): The enumeration context containing bindings. + + Yields: + ISA: The bound ISA. + + Raises: + ValueError: If the name is not bound in the context. + """ + if self.name not in ctx._bindings: + raise ValueError(f"Undefined component reference: '{self.name}'") + yield ctx._bindings[self.name] + + def populate(self, ctx: ISAContext) -> int: + """Instructions already in graph from the bound component. + + Returns: + int: 1, since bound component nodes start at index 1. + """ + return 1 + + +@dataclass +class _BindingNode(ISAQuery): + """ + Enumeration node that binds a component to a name. + + This node enables the as_/ref pattern where multiple positions in the + enumeration tree share the same component instance. The bound component + is enumerated once, and its value is shared across all ISARefNodes with + the same name via the context. + + For multiple bindings, nest BindingNode instances. + + Args: + name: The name to bind the component to. + component: An EnumerationNode (e.g., _ComponentQuery) that produces the bound ISAs. + node: The child enumeration node that may contain ISARefNodes. + + Example: + + .. 
code-block:: python + ctx = EnumerationContext(architecture=arch) + + # Bind a code and reference it multiple times + BindingNode( + name="c", + component=ExampleErrorCorrection.q(), + node=ISARefNode("c") * ISARefNode("c"), + ).enumerate(ctx) + + # Multiple bindings via nesting + BindingNode( + name="c", + component=ExampleErrorCorrection.q(), + node=BindingNode( + name="f", + component=ExampleFactory.q(source=ISARefNode("c")), + node=ISARefNode("c") * ISARefNode("f"), + ), + ).enumerate(ctx) + """ + + name: str + component: ISAQuery + node: ISAQuery + + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: + """ + Enumerate child nodes with the bound component in context. + + Args: + ctx (Context): The enumeration context. + + Yields: + ISA: An ISA instance from the child node. + """ + # Enumerate all ISAs from the component node + for isa in self.component.enumerate(ctx): + # Add binding to context and enumerate child node + new_ctx = ctx._with_binding(self.name, isa) + yield from self.node.enumerate(new_ctx) + + def populate(self, ctx: ISAContext) -> int: + """Populate the graph from both the component and the child node. + + Returns: + int: The starting node index of the component's additions. + """ + comp_start = self.component.populate(ctx) + self.node.populate(ctx) + return comp_start diff --git a/source/pip/qsharp/qre/_qre.py b/source/pip/qsharp/qre/_qre.py new file mode 100644 index 0000000000..2d1aaa7aa5 --- /dev/null +++ b/source/pip/qsharp/qre/_qre.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +# flake8: noqa E402 +# pyright: reportAttributeAccessIssue=false + +from .._native import ( + _binom_ppf, + block_linear_function, + Block, + constant_function, + Constraint, + ConstraintBound, + _estimate_parallel, + _estimate_with_graph, + _EstimationCollection, + EstimationResult, + FactoryResult, + _FloatFunction, + generic_function, + instruction_name, + Instruction, + InstructionFrontier, + _IntFunction, + ISA, + ISARequirements, + _ProvenanceGraph, + linear_function, + LatticeSurgery, + PSSPC, + Trace, + property_name_to_key, + property_name, + _float_to_bits, + _float_from_bits, +) diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi new file mode 100644 index 0000000000..da2a822063 --- /dev/null +++ b/source/pip/qsharp/qre/_qre.pyi @@ -0,0 +1,1673 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations +from typing import Any, Callable, Iterator, Optional, overload + +import pandas as pd + +class ISA: + def __add__(self, other: ISA) -> ISA: + """ + Concatenate two ISAs (logical union). Instructions in the second + operand overwrite instructions in the first operand if they have the + same ID. + """ + ... + + def __contains__(self, id: int) -> bool: + """ + Check if the ISA contains an instruction with the given ID. + + Args: + id (int): The instruction ID. + + Returns: + bool: True if the ISA contains an instruction with the given ID, False otherwise. + """ + ... + + def satisfies(self, requirements: ISARequirements) -> bool: + """ + Check if the ISA satisfies the given ISA requirements. + """ + ... + + def __getitem__(self, id: int) -> Instruction: + """ + Get an instruction by its ID. + + Args: + id (int): The instruction ID. + + Returns: + Instruction: The instruction. + """ + ... + + def get( + self, id: int, default: Optional[Instruction] = None + ) -> Optional[Instruction]: + """ + Get an instruction by its ID, or return a default value if not found. 
+ + Args: + id (int): The instruction ID. + default (Optional[Instruction]): The default value to return if the + instruction is not found. + + Returns: + Optional[Instruction]: The instruction, or the default value if not found. + """ + ... + + def __len__(self) -> int: + """ + Return the number of instructions in the ISA. + + Returns: + int: The number of instructions. + """ + ... + + def node_index(self, id: int) -> Optional[int]: + """ + Return the provenance graph node index for the given instruction ID. + + Args: + id (int): The instruction ID. + + Returns: + Optional[int]: The node index, or None if not found. + """ + ... + + def add_node(self, instruction_id: int, node_index: int) -> None: + """ + Add a pre-existing provenance graph node to the ISA. + + Args: + instruction_id (int): The instruction ID. + node_index (int): The node index in the provenance graph. + """ + ... + + def as_frame(self) -> pd.DataFrame: + """ + Return a pandas DataFrame representation of the ISA. + + The DataFrame will have one row per instruction, with columns for + instruction properties such as time, space, and error rate. The exact + columns may vary based on the properties of the instructions in the ISA. + + Returns: + pd.DataFrame: A DataFrame representation of the ISA. + """ + ... + + def __iter__(self) -> Iterator[Instruction]: + """ + Return an iterator over the instructions. + + Note: + The order of instructions is not guaranteed. + + Returns: + Iterator[Instruction]: The instruction iterator. + """ + ... + + def __str__(self) -> str: + """ + Return a string representation of the ISA. + + Note: + The order of instructions in the output is not guaranteed. + + Returns: + str: A string representation of the ISA. + """ + ... + +class ISARequirements: + @overload + def __new__(cls, *constraints: Constraint) -> ISARequirements: ... + @overload + def __new__(cls, constraints: list[Constraint], /) -> ISARequirements: ... 
+ def __new__(cls, *constraints: Constraint | list[Constraint]) -> ISARequirements: + """ + Create an ISA requirements specification from a list of instructions + constraints. + + Args: + constraints (list[Constraint] | *Constraint): The list of instruction + constraints. + """ + ... + + def __len__(self) -> int: + """ + Return the number of constraints in the requirements specification. + + Returns: + int: The number of constraints. + """ + ... + + def __iter__(self) -> Iterator[Constraint]: + """ + Return an iterator over the constraints. + + Note: + The order of constraints is not guaranteed. + + Returns: + Iterator[Constraint]: The constraint iterator. + """ + ... + + def as_frame(self) -> pd.DataFrame: + """ + Return a pandas DataFrame representation of the ISA requirements. + + The DataFrame will have one row per instruction, with columns for + constraint properties such as encoding. + + Returns: + pd.DataFrame: A DataFrame representation of the ISA requirements. + """ + ... + +class Instruction: + @staticmethod + def fixed_arity( + id: int, + encoding: int, + arity: int, + time: int, + space: Optional[int], + length: Optional[int], + error_rate: float, + ) -> Instruction: + """ + Create an instruction with a fixed arity. + + Note: + This function is not intended to be called directly by the user, use qre.instruction instead. + + Args: + id (int): The instruction ID. + encoding (int): The instruction encoding. 0 = Physical, 1 = Logical. + arity (int): The instruction arity. + time (int): The instruction time in ns. + space (Optional[int]): The instruction space in number of physical + qubits. If None, length is used. + length (Optional[int]): The arity including ancilla qubits. If None, + arity is used. + error_rate (float): The instruction error rate. + + Returns: + Instruction: The instruction. + """ + ... 
+ + @staticmethod + def variable_arity( + id: int, + encoding: int, + time_fn: _IntFunction, + space_fn: _IntFunction, + error_rate_fn: _FloatFunction, + length_fn: Optional[_IntFunction], + ) -> Instruction: + """ + Create an instruction with variable arity. + + Note: + This function is not intended to be called directly by the user, use qre.instruction instead. + + Args: + id (int): The instruction ID. + encoding (int): The instruction encoding. 0 = Physical, 1 = Logical. + time_fn (_IntFunction): The time function. + space_fn (_IntFunction): The space function. + error_rate_fn (_FloatFunction): The error rate function. + length_fn (Optional[_IntFunction]): The length function. + If None, space_fn is used. + + Returns: + Instruction: The instruction. + """ + ... + + def with_id(self, id: int) -> Instruction: + """ + Return a copy of the instruction with the given ID. + + Note: + The created instruction will not inherit the source property of the + original instruction and must be set by the user if intended. + + Args: + id (int): The instruction ID. + + Returns: + Instruction: A copy of the instruction with the given ID. + """ + ... + + @property + def id(self) -> int: + """ + The instruction ID. + + Returns: + int: The instruction ID. + """ + ... + + @property + def encoding(self) -> int: + """ + The instruction encoding. 0 = Physical, 1 = Logical. + + Returns: + int: The instruction encoding. + """ + ... + + @property + def arity(self) -> Optional[int]: + """ + The instruction arity. + + Returns: + Optional[int]: The instruction arity. + """ + ... + + def space(self, arity: Optional[int] = None) -> Optional[int]: + """ + The instruction space in number of physical qubits. + + Args: + arity (Optional[int]): The specific arity to check. + + Returns: + Optional[int]: The instruction space in number of physical qubits. + """ + ... + + def time(self, arity: Optional[int] = None) -> Optional[int]: + """ + The instruction time in ns. 
+ + Args: + arity (Optional[int]): The specific arity to check. + + Returns: + Optional[int]: The instruction time in ns. + """ + ... + + def error_rate(self, arity: Optional[int] = None) -> Optional[float]: + """ + The instruction error rate. + + Args: + arity (Optional[int]): The specific arity to check. + + Returns: + Optional[float]: The instruction error rate. + """ + ... + + def expect_space(self, arity: Optional[int] = None) -> int: + """ + The instruction space in number of physical qubits. Raises an error if not found. + + Args: + arity (Optional[int]): The specific arity to check. + + Returns: + int: The instruction space in number of physical qubits. + """ + ... + + def expect_time(self, arity: Optional[int] = None) -> int: + """ + The instruction time in ns. Raises an error if not found. + + Args: + arity (Optional[int]): The specific arity to check. + + Returns: + int: The instruction time in ns. + """ + ... + + def expect_error_rate(self, arity: Optional[int] = None) -> float: + """ + The instruction error rate. Raises an error if not found. + + Args: + arity (Optional[int]): The specific arity to check. + + Returns: + float: The instruction error rate. + """ + ... + + def set_source(self, index: int) -> None: + """ + Set the source index for the instruction. + + Args: + index (int): The source index to set. + """ + ... + + @property + def source(self) -> int: + """ + Get the source index for the instruction. + + Returns: + int: The source index for the instruction. + """ + ... + + def set_property(self, key: int, value: int) -> None: + """ + Set a property on the instruction. + + Args: + key (int): The property key. + value (int): The property value. + """ + ... + + def get_property(self, key: int) -> Optional[int]: + """ + Get a property by its key. + + Args: + key (int): The property key. + + Returns: + Optional[int]: The property value, or None if not found. + """ + ... 
+ + def has_property(self, key: int) -> bool: + """ + Check if the instruction has a property with the given key. + + Args: + key (int): The property key. + + Returns: + bool: True if the instruction has the property, False otherwise. + """ + ... + + def get_property_or(self, key: int, default: int) -> int: + """ + Get a property by its key, or return a default value if not found. + + Args: + key (int): The property key. + default (int): The default value to return if the property is not found. + + Returns: + int: The property value, or the default value if not found. + """ + ... + + def __getitem__(self, key: int) -> int: + """ + Get a property by its key, or raise an error if not found. + + Args: + key (int): The property key. + + Returns: + int: The property value. + """ + ... + + def __str__(self) -> str: + """ + Return a string representation of the instruction. + + Returns: + str: A string representation of the instruction. + """ + ... + +class ConstraintBound: + """ + A bound for a constraint. + """ + + @staticmethod + def lt(value: float) -> ConstraintBound: + """ + Create a less than constraint bound. + + Args: + value (float): The value. + + Returns: + ConstraintBound: The constraint bound. + """ + ... + + @staticmethod + def le(value: float) -> ConstraintBound: + """ + Create a less equal constraint bound. + + Args: + value (float): The value. + + Returns: + ConstraintBound: The constraint bound. + """ + ... + + @staticmethod + def eq(value: float) -> ConstraintBound: + """ + Create an equal constraint bound. + + Args: + value (float): The value. + + Returns: + ConstraintBound: The constraint bound. + """ + ... + + @staticmethod + def gt(value: float) -> ConstraintBound: + """ + Create a greater than constraint bound. + + Args: + value (float): The value. + + Returns: + ConstraintBound: The constraint bound. + """ + ... + + @staticmethod + def ge(value: float) -> ConstraintBound: + """ + Create a greater equal constraint bound. 
+
+        Args:
+            value (float): The value.
+
+        Returns:
+            ConstraintBound: The constraint bound.
+        """
+        ...
+
+class Constraint:
+    """
+    An instruction constraint that can be used to describe ISA requirements
+    for ISA transformations.
+    """
+
+    def __new__(
+        cls,
+        id: int,
+        encoding: int,
+        arity: Optional[int],
+        error_rate: Optional[ConstraintBound],
+    ) -> Constraint:
+        """
+        Note:
+            This function is not intended to be called directly by the user, use qre.constraint instead.
+
+        Args:
+            id (int): The instruction ID.
+            encoding (int): The instruction encoding. 0 = Physical, 1 = Logical.
+            arity (Optional[int]): The instruction arity. If None, instruction is
+                assumed to have variable arity.
+            error_rate (Optional[ConstraintBound]): The constraint on the error rate.
+
+        Returns:
+            Constraint: The instruction constraint.
+        """
+        ...
+
+    @property
+    def id(self) -> int:
+        """
+        The instruction ID.
+
+        Returns:
+            int: The instruction ID.
+        """
+        ...
+
+    @property
+    def encoding(self) -> int:
+        """
+        The instruction encoding. 0 = Physical, 1 = Logical.
+
+        Returns:
+            int: The instruction encoding.
+        """
+        ...
+
+    @property
+    def arity(self) -> Optional[int]:
+        """
+        The instruction arity.
+
+        Returns:
+            Optional[int]: The instruction arity.
+        """
+        ...
+
+    @property
+    def error_rate(self) -> Optional[ConstraintBound]:
+        """
+        The constraint on the instruction error rate.
+
+        Returns:
+            Optional[ConstraintBound]: The constraint on the instruction error rate.
+        """
+        ...
+
+    def add_property(self, property: int) -> None:
+        """
+        Add a property requirement to the constraint.
+
+        Args:
+            property (int): The property key that must be present in matching instructions.
+        """
+        ...
+
+    def has_property(self, property: int) -> bool:
+        """
+        Check if the constraint requires a specific property.
+
+        Args:
+            property (int): The property key to check.
+
+        Returns:
+            bool: True if the constraint requires this property, False otherwise.
+        """
+        ...
+ +class _IntFunction: ... +class _FloatFunction: ... + +@overload +def constant_function(value: int) -> _IntFunction: ... +@overload +def constant_function(value: float) -> _FloatFunction: ... +def constant_function( + value: int | float, +) -> _IntFunction | _FloatFunction: + """ + Create a constant function. + + Args: + value (int | float): The constant value. + + Returns: + _IntFunction | _FloatFunction: The constant function. + """ + ... + +@overload +def linear_function(slope: int) -> _IntFunction: ... +@overload +def linear_function(slope: float) -> _FloatFunction: ... +def linear_function( + slope: int | float, +) -> _IntFunction | _FloatFunction: + """ + Create a linear function. + + Args: + slope (int | float): The slope. + + Returns: + _IntFunction | _FloatFunction: The linear function. + """ + ... + +@overload +def block_linear_function(block_size: int, slope: int, offset: int) -> _IntFunction: ... +@overload +def block_linear_function( + block_size: int, slope: float, offset: float +) -> _FloatFunction: ... +def block_linear_function( + block_size: int, slope: int | float, offset: int | float +) -> _IntFunction | _FloatFunction: + """ + Create a block linear function that takes an arity (number of qubits) as + input. Given an arity, it will compute the number of blocks `num_blocks` by + computing `ceil(arity / block_size)` and then return `slope * num_blocks + + offset`. + + Args: + block_size (int): The block size. + slope (int | float): The slope. + offset (int | float): The offset. + + Returns: + _IntFunction | _FloatFunction: The block linear function. + """ + ... + +@overload +def generic_function(func: Callable[[int], int]) -> _IntFunction: ... +@overload +def generic_function(func: Callable[[int], float]) -> _FloatFunction: ... +def generic_function( + func: Callable[[int], int | float], +) -> _IntFunction | _FloatFunction: + """ + Create a generic function from a Python callable. 
+
+    Note:
+        Only use this function if the other function constructors
+        (constant_function, linear_function, and block_linear_function) do not
+        meet your needs, as using a Python callable can have performance
+        implications. If using this function, keep the logic in the callable as
+        simple as possible to minimize overhead.
+
+    Args:
+        func (Callable[[int], int | float]): The Python callable.
+
+    Returns:
+        _IntFunction | _FloatFunction: The generic function.
+    """
+    ...
+
+class _ProvenanceGraph:
+    """
+    Represents the provenance graph of instructions in a trace. Each node in
+    the graph corresponds to an instruction and the transform from which it was
+    produced, and edges represent transformations applied to instructions during
+    enumeration.
+    """
+
+    def add_node(
+        self, instruction: Instruction, transform_id: int, children: list[int]
+    ) -> int:
+        """
+        Add a node to the provenance graph.
+
+        Args:
+            instruction (Instruction): The instruction corresponding to the node.
+            transform_id (int): The transform ID corresponding to the node.
+            children (list[int]): The list of child node indices in the provenance graph.
+
+        Returns:
+            int: The index of the added node in the provenance graph.
+        """
+        ...
+
+    def instruction(self, node_index: int) -> Instruction:
+        """
+        Return the instruction for a given node index.
+
+        Args:
+            node_index (int): The index of the node in the provenance graph.
+
+        Returns:
+            Instruction: The instruction corresponding to the node.
+        """
+        ...
+
+    def transform_id(self, node_index: int) -> int:
+        """
+        Return the transform ID for a given node index.
+
+        Args:
+            node_index (int): The index of the node in the provenance graph.
+
+        Returns:
+            int: The transform ID corresponding to the node.
+        """
+        ...
+
+    def children(self, node_index: int) -> list[int]:
+        """
+        Return the list of child node indices for a given node index.
+
+        Args:
+            node_index (int): The index of the node in the provenance graph.
+ + Returns: + list[int]: The list of child node indices. + """ + ... + + def num_nodes(self) -> int: + """ + Return the number of nodes in the provenance graph. + + Returns: + int: The number of nodes in the provenance graph. + """ + ... + + def num_edges(self) -> int: + """ + Return the number of edges in the provenance graph. + + Returns: + int: The number of edges in the provenance graph. + """ + ... + + @overload + def add_instruction( + self, + instruction: Instruction, + ) -> int: ... + @overload + def add_instruction( + self, + id: int, + encoding: int = 0, + *, + arity: Optional[int] = 1, + time: int | _IntFunction = ..., + space: Optional[int | _IntFunction] = None, + length: Optional[int | _IntFunction] = None, + error_rate: float | _FloatFunction = ..., + **kwargs: int, + ) -> int: ... + def add_instruction( + self, + id_or_instruction: int | Instruction, + encoding: int = 0, + *, + arity: Optional[int] = 1, + time: int | _IntFunction = ..., + space: Optional[int | _IntFunction] = None, + length: Optional[int | _IntFunction] = None, + error_rate: float | _FloatFunction = ..., + **kwargs: int, + ) -> int: + """ + Add an instruction to the provenance graph with no transform or + children. + + Can be called with a pre-existing ``Instruction`` or with keyword + args to create one inline. + + Args: + id_or_instruction: An instruction ID (int) or ``Instruction``. + encoding: 0 = Physical, 1 = Logical. Ignored for ``Instruction``. + arity: Instruction arity, ``None`` for variable. Ignored for + ``Instruction``. + time: Time in ns (or ``_IntFunction``). Ignored for ``Instruction``. + space: Space in physical qubits (or ``_IntFunction``). Ignored for + ``Instruction``. + length: Arity including ancillas. Ignored for ``Instruction``. + error_rate: Error rate (or ``_FloatFunction``). Ignored for + ``Instruction``. + **kwargs: Additional properties (e.g. ``distance=9``). + + Returns: + int: The node index of the added instruction. + """ + ... 
+ + def make_isa(self, node_indices: list[int]) -> ISA: + """ + Create an ISA backed by this provenance graph from the given node + indices. + + Args: + node_indices: A list of node indices in the provenance graph. + + Returns: + ISA: An ISA referencing this provenance graph. + """ + ... + + def build_pareto_index(self) -> None: + """ + Builds the per-instruction-ID Pareto index. + + For each instruction ID, retains only the Pareto-optimal nodes w.r.t. + (space, time, error_rate) evaluated at arity 1. Must be called after + all nodes have been added. + """ + ... + + def query_satisfying( + self, + requirements: ISARequirements, + min_node_idx: Optional[int] = None, + ) -> list[ISA]: + """ + Return ISAs formed from Pareto-optimal graph nodes satisfying the + given requirements. + + For each constraint in requirements, selects matching Pareto-optimal + nodes. Returns the Cartesian product of per-constraint matches, + augmented with one representative node per unconstrained instruction + ID. + + Must be called after ``build_pareto_index``. + + Args: + requirements: The ISA requirements to satisfy. + min_node_idx: If provided, only consider Pareto nodes at or above + this index for constrained groups. + + Returns: + list[ISA]: ISAs formed from matching Pareto-optimal nodes. + """ + ... + + def raw_node_count(self) -> int: + """ + Return the raw node count (including the sentinel at index 0). + + Returns: + int: The number of nodes in the graph. + """ + ... + + def total_isa_count(self) -> int: + """ + Return the total number of ISAs that can be formed from Pareto-optimal + nodes. + + Requires ``build_pareto_index`` to have been called. + + Returns: + int: The total number of ISAs that can be formed. + """ + ... + +class EstimationResult: + """ + Represents the result of a resource estimation. + """ + + def __new__( + cls, *, qubits: int = 0, runtime: int = 0, error: float = 0.0 + ) -> EstimationResult: + """ + Create a new estimation result. 
+ + Args: + qubits (int): The number of logical qubits. + runtime (int): The runtime in nanoseconds. + error (float): The error probability of the computation. + + Returns: + EstimationResult: The estimation result. + """ + ... + + @property + def qubits(self) -> int: + """ + The number of logical qubits. + + Returns: + int: The number of logical qubits. + """ + ... + + @qubits.setter + def qubits(self, qubits: int) -> None: + """ + Set the number of logical qubits. + + Args: + qubits (int): The number of logical qubits to set. + """ + ... + + @property + def runtime(self) -> int: + """ + The runtime in nanoseconds. + + Returns: + int: The runtime in nanoseconds. + """ + ... + + @runtime.setter + def runtime(self, runtime: int) -> None: + """ + Set the runtime. + + Args: + runtime (int): The runtime in nanoseconds to set. + """ + ... + + @property + def error(self) -> float: + """ + The error probability of the computation. + + Returns: + float: The error probability of the computation. + """ + ... + + @error.setter + def error(self, error: float) -> None: + """ + Set the error probability. + + Args: + error (float): The error probability to set. + """ + ... + + @property + def factories(self) -> dict[int, FactoryResult]: + """ + The factory results. + + Returns: + dict[int, FactoryResult]: A dictionary mapping factory IDs to their results. + """ + ... + + @property + def isa(self) -> ISA: + """ + The ISA used for the estimation. + + Returns: + ISA: The ISA used for the estimation. + """ + ... + + @property + def properties(self) -> dict[int, bool | int | float | str]: + """ + Custom properties from application generation and trace transform. + + Returns: + dict[int, bool | int | float | str]: A dictionary mapping property keys to their values. + """ + ... + + def set_property(self, key: int, value: bool | int | float | str) -> None: + """ + Set a custom property. + + Args: + key (int) The property key. + value (bool | int | float | str): The property value. 
All values of type `int`, `float`, `bool`, and `str` + are supported. Any other value is converted to a string using its `__str__` method. + """ + ... + + def __str__(self) -> str: + """ + Return a string representation of the estimation result. + + Returns: + str: A string representation of the estimation result. + """ + ... + +class _EstimationCollection: + """ + Represents a collection of estimation results. Results are stored as a 2D + Pareto frontier with physical qubits and runtime as objectives. + """ + + def __new__(cls) -> _EstimationCollection: + """ + Create a new estimation collection. + + Returns: + _EstimationCollection: The estimation collection. + """ + ... + + def insert(self, result: EstimationResult) -> None: + """ + Insert an estimation result into the collection. + + Args: + result (EstimationResult): The estimation result to insert. + """ + ... + + def __len__(self) -> int: + """ + Return the number of estimation results in the collection. + + Returns: + int: The number of estimation results. + """ + ... + + def __iter__(self) -> Iterator[EstimationResult]: + """ + Return an iterator over the estimation results. + + Returns: + Iterator[EstimationResult]: The estimation result iterator. + """ + ... + + @property + def total_jobs(self) -> int: + """ + Return the total number of (trace, ISA) estimation jobs. + + Returns: + int: The total number of jobs. + """ + ... + + @property + def successful_estimates(self) -> int: + """ + Return the number of estimation jobs that completed successfully + (before Pareto filtering). + + Returns: + int: The number of successful estimates. + """ + ... + + @property + def all_summaries(self) -> list[tuple[int, int, int, int]]: + """ + Return lightweight summaries of ALL successful estimates as a list + of (trace_index, isa_index, qubits, runtime) tuples. + + Returns: + list[tuple[int, int, int, int]]: List of (trace_index, isa_index, + qubits, runtime) for every successful estimation. + """ + ... 
+ + @property + def isas(self) -> list[ISA]: + """ + Return the list of ISAs for which estimates were performed. + + Returns: + list[ISA]: The list of ISAs. + """ + ... + +class FactoryResult: + """ + Represents the result of a factory used in resource estimation. + """ + + @property + def copies(self) -> int: + """ + The number of factory copies. + + Returns: + int: The number of factory copies. + """ + ... + + @property + def runs(self) -> int: + """ + The number of factory runs. + + Returns: + int: The number of factory runs. + """ + ... + + @property + def error_rate(self) -> float: + """ + The error rate of the factory. + + Returns: + float: The error rate of the factory. + """ + ... + + @property + def states(self) -> int: + """ + The number of states produced by the factory. + + Returns: + int: The number of states produced by the factory. + """ + ... + +class Trace: + """ + Represents a quantum program optimized for resource estimation. + + A trace originates from a quantum application and can be modified via trace + transformations. It consists of blocks of operations. + """ + + def __new__(cls, compute_qubits: int) -> Trace: + """ + Create a new trace. + + Returns: + Trace: The trace. + """ + ... + + def clone_empty(self, compute_qubits: Optional[int] = None) -> Trace: + """ + Create a new trace with the same metadata but empty block. + + Args: + compute_qubits (Optional[int]): The number of compute qubits. If None, + the number of compute qubits of the original trace is used. + + Returns: + Trace: The new trace. + """ + ... + + @classmethod + def from_json(cls, json: str) -> Trace: + """ + Create a trace from a JSON string. + + Args: + json (str): The JSON string. + + Returns: + Trace: The trace. + """ + ... + + def to_json(self) -> str: + """ + Serializes the trace to a JSON string. + + Returns: + str: The JSON string representation of the trace. + """ + ... + + @property + def compute_qubits(self) -> int: + """ + The number of compute qubits. 
+ + Returns: + int: The number of compute qubits. + """ + ... + + @compute_qubits.setter + def compute_qubits(self, qubits: int) -> None: + """ + Set the number of compute qubits. + + Args: + qubits (int): The number of compute qubits to set. + """ + ... + + @property + def base_error(self) -> float: + """ + The base error of the trace. + + Returns: + float: The base error of the trace. + """ + ... + + def increment_base_error(self, amount: float) -> None: + """ + Increments the base error. + + Args: + amount (float): The amount to increment. + """ + ... + + @property + def memory_qubits(self) -> Optional[int]: + """ + The number of memory qubits, if set. + + Returns: + Optional[int]: The number of memory qubits, or None if not set. + """ + ... + + def has_memory_qubits(self) -> bool: + """ + Check if the trace has memory qubits set. + + Returns: + bool: True if memory qubits are set, False otherwise. + """ + ... + + @memory_qubits.setter + def memory_qubits(self, qubits: int) -> None: + """ + Set the number of memory qubits. + + Args: + qubits (int): The number of memory qubits. + """ + ... + + def increment_memory_qubits(self, amount: int) -> None: + """ + Increments the number of memory qubits. If memory qubits have not been + set, initializes them to 0 before incrementing. + + Args: + amount (int): The amount to increment. + """ + ... + + def increment_resource_state(self, resource_id: int, amount: int) -> None: + """ + Increments a resource state count. + + Args: + resource_id (int): The resource state ID. + amount (int): The amount to increment. + """ + ... + + def set_property(self, key: int, value: Any) -> None: + """ + Set a property. All values of type `int`, `float`, `bool`, and `str` + are supported. Any other value is converted to a string using its + `__str__` method. + + Args: + key (int): The property key. + value (Any): The property value. + """ + ... 
+ + def get_property(self, key: int) -> Optional[int | float | bool | str]: + """ + Get a property. + + Args: + key (int): The property key. + + Returns: + Optional[int | float | bool | str]: The property value, or None if not found. + """ + ... + + def has_property(self, key: int) -> bool: + """ + Check if a property with the given key exists. + + Args: + key (int): The property key. + + Returns: + bool: True if the property exists, False otherwise. + """ + ... + + @property + def total_qubits(self) -> int: + """ + The total number of qubits (compute + memory). + + Returns: + int: The total number of qubits. + """ + ... + + @property + def depth(self) -> int: + """ + The trace depth. + + Returns: + int: The trace depth. + """ + ... + + @property + def num_gates(self) -> int: + """ + The total number of gates in the trace. + + Returns: + int: The total number of gates. + """ + ... + + def estimate( + self, isa: ISA, max_error: Optional[float] = None + ) -> Optional[EstimationResult]: + """ + Estimate resources for the trace given a logical ISA. + + Args: + isa (ISA): The logical ISA. + max_error (Optional[float]): The maximum allowed error. If None, + Pareto points are computed. + + Returns: + Optional[EstimationResult]: The estimation result if max_error is + provided, otherwise valid Pareto points. + """ + ... # The implementation in Rust returns Option, so it fits + + @property + def resource_states(self) -> dict[int, int]: + """ + The resource states used in the trace. + + Returns: + dict[int, int]: A dictionary mapping instruction IDs to their counts. + """ + ... + + def add_operation( + self, id: int, qubits: list[int], params: list[float] = [] + ) -> None: + """ + Add an operation to the trace. + + Args: + id (int): The operation ID. + qubits (list[int]): The qubits involved in the operation. + params (list[float]): The operation parameters. + """ + ... + + def root_block(self) -> Block: + """ + Return the root block of the trace. 
+ + Returns: + Block: The root block of the trace. + """ + ... + + def add_block(self, repetitions: int = 1) -> Block: + """ + Add a block to the trace. + + Args: + repetitions (int): The number of times the block is repeated. + + Returns: + Block: The block. + """ + ... + + @property + def required_isa(self) -> ISARequirements: + """ + The required ISA for the trace. + + Returns: + ISARequirements: The required ISA for the trace. + """ + ... + + def __str__(self) -> str: + """ + Return a string representation of the trace. + + Returns: + str: A string representation of the trace. + """ + ... + +class Block: + """ + Represents a block of operations in a trace. + + An operation in a block can either refer to an instruction applied to some + qubits or can be another block to create a hierarchical structure. Blocks + can be repeated. + """ + + def add_operation( + self, id: int, qubits: list[int], params: list[float] = [] + ) -> None: + """ + Add an operation to the block. + + Args: + id (int): The operation ID. + qubits (list[int]): The qubits involved in the operation. + params (list[float]): The operation parameters. + """ + ... + + def add_block(self, repetitions: int = 1) -> Block: + """ + Add a nested block to the block. + + Args: + repetitions (int): The number of times the block is repeated. + + Returns: + Block: The block. + """ + ... + + def __str__(self) -> str: + """ + Return a string representation of the block. + + Returns: + str: A string representation of the block. + """ + ... + +class PSSPC: + def __new__(cls, num_ts_per_rotation: int, ccx_magic_states: bool) -> PSSPC: ... + def transform(self, trace: Trace) -> Optional[Trace]: ... + +class LatticeSurgery: + def __new__(cls, slow_down_factor: float) -> LatticeSurgery: ... + def transform(self, trace: Trace) -> Optional[Trace]: ... + +class InstructionFrontier: + """ + Represents a Pareto frontier of instructions with space, time, and error + rates as objectives. 
+ """ + + def __new__(cls, *, with_error_objective: bool = True) -> InstructionFrontier: + """ + Create a new instruction frontier. + + Args: + with_error_objective (bool): If True (default), the frontier uses + three objectives (space, time, error rate). If False, it uses + two objectives (space, time). + """ + ... + + def insert(self, point: Instruction): + """ + Insert an instruction into the frontier. + + Args: + point (Instruction): The instruction to insert. + """ + ... + + def extend(self, points: list[Instruction]) -> None: + """ + Extend the frontier with a list of instructions. + + Args: + points (list[Instruction]): The instructions to insert. + """ + ... + + def __len__(self) -> int: + """ + Return the number of instructions in the frontier. + + Returns: + int: The number of instructions. + """ + ... + + def __iter__(self) -> Iterator[Instruction]: + """ + Return an iterator over the instructions in the frontier. + + Returns: + Iterator[Instruction]: The iterator. + """ + ... + + @staticmethod + def load( + filename: str, *, with_error_objective: bool = True + ) -> InstructionFrontier: + """ + Load an instruction frontier from a file. + + Args: + filename (str): The file name. + with_error_objective (bool): If True (default), the frontier uses + three objectives (space, time, error rate). If False, it uses + two objectives (space, time). + + Returns: + InstructionFrontier: The loaded instruction frontier. + """ + ... + + def dump(self, filename: str) -> None: + """ + Dump the instruction frontier to a file. + + Args: + filename (str): The file name. + """ + ... + +def _estimate_parallel( + traces: list[Trace], + isas: list[ISA], + max_error: float = 1.0, + post_process: bool = False, +) -> _EstimationCollection: + """ + Estimate resources for multiple traces and ISAs in parallel. + + Args: + traces (list[Trace]): The list of traces. + isas (list[ISA]): The list of ISAs. + max_error (float): The maximum allowed error. The default is 1.0. 
+ post_process (bool): If True, computes auxiliary data such as result + summaries needed for post-processing after estimation. + + Returns: + _EstimationCollection: The estimation collection. + """ + ... + +def _estimate_with_graph( + traces: list[Trace], + graph: _ProvenanceGraph, + max_error: float = 1.0, + post_process: bool = False, +) -> _EstimationCollection: + """ + Estimate resources using a Pareto-filtered provenance graph. + + Instead of forming the full Cartesian product of ISAs × traces, this + function enumerates per-trace instruction combinations from the + Pareto-optimal subsets in the frozen graph. + + Args: + traces (list[Trace]): The list of traces to estimate. + graph (_ProvenanceGraph): The provenance graph to use for estimation. + max_error (float): The maximum allowed error. The default is 1.0. + post_process (bool): If True, computes auxiliary data such as result + summaries and ISAs needed for post-processing after estimation. + + Returns: + _EstimationCollection: The estimation collection. + """ + ... + +def _binom_ppf(q: float, n: int, p: float) -> int: + """ + A replacement for SciPy's binom.ppf that is faster and does not require + SciPy as a dependency. + """ + ... + +def _float_to_bits(f: float) -> int: + """Convert a float to its bit representation as an integer.""" + ... + +def _float_from_bits(b: int) -> float: + """Convert a float from its bit representation as an integer.""" + ... + +def instruction_name(id: int) -> Optional[str]: + """ + Return the name of an instruction given its ID, if known. + + Args: + id (int): The instruction ID. + + Returns: + Optional[str]: The name of the instruction, or None if the ID is not recognized. + """ + ... + +def property_name_to_key(name: str) -> Optional[int]: + """ + Convert a property name to its corresponding key, if known. + + Args: + name (str): The property name. + + Returns: + Optional[int]: The property key, or None if the name is not recognized. + """ + ... 
+ +def property_name(id: int) -> Optional[str]: + """ + Convert a property key to its corresponding name, if known. + + Args: + id (int): The property key. + + Returns: + Optional[str]: The property name, or None if the key is not recognized. + """ + ... diff --git a/source/pip/qsharp/qre/_results.py b/source/pip/qsharp/qre/_results.py new file mode 100644 index 0000000000..1a47a0d975 --- /dev/null +++ b/source/pip/qsharp/qre/_results.py @@ -0,0 +1,394 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, Callable, Any, Iterable + +import pandas as pd + +from ._architecture import ISAContext +from ._qre import ( + FactoryResult, + instruction_name, + EstimationResult, +) +from ._instruction import InstructionSource +from .property_keys import ( + PHYSICAL_COMPUTE_QUBITS, + PHYSICAL_MEMORY_QUBITS, + PHYSICAL_FACTORY_QUBITS, +) + + +class EstimationTable(list["EstimationTableEntry"]): + """A table of quantum resource estimation results. + + Extends ``list[EstimationTableEntry]`` and provides configurable columns for + displaying estimation data. By default the table includes *qubits*, + *runtime* (displayed as a ``pandas.Timedelta``), and *error* columns. + Additional columns can be added or inserted with ``add_column`` and + ``insert_column``. 
+ """ + + def __init__(self): + """Initialize an empty estimation table with default columns.""" + super().__init__() + + self.name: Optional[str] = None + self.stats = EstimationTableStats() + + self._columns: list[tuple[str, EstimationTableColumn]] = [ + ("qubits", EstimationTableColumn(lambda entry: entry.qubits)), + ( + "runtime", + EstimationTableColumn( + lambda entry: entry.runtime, + formatter=lambda x: pd.Timedelta(x, unit="ns"), + ), + ), + ("error", EstimationTableColumn(lambda entry: entry.error)), + ] + + def add_column( + self, + name: str, + function: Callable[[EstimationTableEntry], Any], + formatter: Optional[Callable[[Any], Any]] = None, + ) -> None: + """Add a column to the estimation table. + + Args: + name (str): The name of the column. + function (Callable[[EstimationTableEntry], Any]): A function that + takes an EstimationTableEntry and returns the value for this + column. + formatter (Optional[Callable[[Any], Any]]): An optional function + that formats the output of ``function`` for display purposes. + """ + self._columns.append((name, EstimationTableColumn(function, formatter))) + + def insert_column( + self, + index: int, + name: str, + function: Callable[[EstimationTableEntry], Any], + formatter: Optional[Callable[[Any], Any]] = None, + ) -> None: + """Insert a column at the specified index in the estimation table. + + Args: + index (int): The index at which to insert the column. + name (str): The name of the column. + function (Callable[[EstimationTableEntry], Any]): A function that + takes an EstimationTableEntry and returns the value for this + column. + formatter (Optional[Callable[[Any], Any]]): An optional function + that formats the output of ``function`` for display purposes. 
+ """ + self._columns.insert(index, (name, EstimationTableColumn(function, formatter))) + + def add_qubit_partition_column(self) -> None: + """Add columns for the physical compute, factory, and memory qubit counts.""" + self.add_column( + "physical_compute_qubits", + lambda entry: entry.properties.get(PHYSICAL_COMPUTE_QUBITS, 0), + ) + self.add_column( + "physical_factory_qubits", + lambda entry: entry.properties.get(PHYSICAL_FACTORY_QUBITS, 0), + ) + self.add_column( + "physical_memory_qubits", + lambda entry: entry.properties.get(PHYSICAL_MEMORY_QUBITS, 0), + ) + + def add_factory_summary_column(self) -> None: + """Add a column to the estimation table that summarizes the factories used in the estimation.""" + + def summarize_factories(entry: EstimationTableEntry) -> str: + if not entry.factories: + return "None" + return ", ".join( + f"{factory_result.copies}×{instruction_name(id)}" + for id, factory_result in entry.factories.items() + ) + + self.add_column("factories", summarize_factories) + + def as_frame(self): + """Convert the estimation table to a ``pandas.DataFrame``. + + Each row corresponds to an ``EstimationTableEntry`` and each + column is determined by the columns registered on this table. Column + formatters, when present, are applied to the values before they are + placed in the frame. + + Returns: + pandas.DataFrame: A DataFrame representation of the estimation + results. + """ + return pd.DataFrame( + [ + { + column_name: ( + column.formatter(column.function(entry)) + if column.formatter is not None + else column.function(entry) + ) + for column_name, column in self._columns + } + for entry in self + ] + ) + + def plot(self, **kwargs): + """Plot this table's results. + + Convenience wrapper around ``plot_estimates``. All keyword + arguments are forwarded. + + Returns: + matplotlib.figure.Figure: The figure containing the plot. 
+ """ + return plot_estimates(self, **kwargs) + + +@dataclass(frozen=True, slots=True) +class EstimationTableColumn: + """Definition of a single column in an ``EstimationTable``. + + Attributes: + function: A callable that extracts the raw column value from an + ``EstimationTableEntry``. + formatter: An optional callable that transforms the raw value for + display purposes (e.g. converting nanoseconds to a + ``pandas.Timedelta``). + """ + + function: Callable[[EstimationTableEntry], Any] + formatter: Optional[Callable[[Any], Any]] = None + + +@dataclass(frozen=True, slots=True) +class EstimationTableEntry: + """A single row in an ``EstimationTable``. + + Each entry represents one Pareto-optimal estimation result for a + particular combination of application trace and architecture ISA. + + Attributes: + qubits: Total number of physical qubits required. + runtime: Total runtime of the algorithm in nanoseconds. + error: Total estimated error probability. + source: The instruction source derived from the architecture ISA used + for this estimation. + factories: A mapping from instruction id to the + ``FactoryResult`` describing the magic-state factory used + and the number of copies required. + properties: Additional key-value properties attached to the + estimation result. + """ + + qubits: int + runtime: int + error: float + source: InstructionSource + factories: dict[int, FactoryResult] = field(default_factory=dict) + properties: dict[int, int | float | bool | str] = field(default_factory=dict) + + @classmethod + def from_result( + cls, result: EstimationResult, ctx: ISAContext + ) -> EstimationTableEntry: + """Create an entry from an estimation result and architecture context. + + Args: + result (EstimationResult): The raw estimation result. + ctx (ISAContext): The architecture context used for the estimation. + + Returns: + EstimationTableEntry: A new table entry populated from the result. 
+ """ + return cls( + qubits=result.qubits, + runtime=result.runtime, + error=result.error, + source=InstructionSource.from_isa(ctx, result.isa), + factories=result.factories.copy(), + properties=result.properties.copy(), + ) + + +@dataclass(slots=True) +class EstimationTableStats: + """Statistics for a single estimation run. + + Attributes: + num_traces (int): Number of traces evaluated. + num_isas (int): Number of ISAs evaluated. + total_jobs (int): Total estimation jobs executed. + successful_estimates (int): Number of jobs that produced a result. + pareto_results (int): Number of Pareto-optimal results retained. + """ + + num_traces: int = 0 + num_isas: int = 0 + total_jobs: int = 0 + successful_estimates: int = 0 + pareto_results: int = 0 + + +# Mapping from runtime unit name to its value in nanoseconds. +_TIME_UNITS: dict[str, float] = { + "ns": 1, + "µs": 1e3, + "us": 1e3, + "ms": 1e6, + "s": 1e9, + "min": 60e9, + "hours": 3600e9, + "days": 86_400e9, + "weeks": 604_800e9, + "months": 31 * 86_400e9, + "years": 365 * 86_400e9, + "decades": 10 * 365 * 86_400e9, + "centuries": 100 * 365 * 86_400e9, +} + +# Ordered subset of _TIME_UNITS used for default x-axis tick labels. +_TICK_UNITS: list[tuple[str, float]] = [ + ("1 ns", _TIME_UNITS["ns"]), + ("1 µs", _TIME_UNITS["µs"]), + ("1 ms", _TIME_UNITS["ms"]), + ("1 s", _TIME_UNITS["s"]), + ("1 min", _TIME_UNITS["min"]), + ("1 hour", _TIME_UNITS["hours"]), + ("1 day", _TIME_UNITS["days"]), + ("1 week", _TIME_UNITS["weeks"]), + ("1 month", _TIME_UNITS["months"]), + ("1 year", _TIME_UNITS["years"]), + ("1 decade", _TIME_UNITS["decades"]), + ("1 century", _TIME_UNITS["centuries"]), +] + + +def plot_estimates( + data: EstimationTable | Iterable[EstimationTable], + *, + runtime_unit: Optional[str] = None, + figsize: tuple[float, float] = (15, 8), + scatter_args: dict[str, Any] = {"marker": "x"}, +): + """Plot estimation results displaying qubits vs runtime. 
+ + Creates a log-log scatter plot where the x-axis shows the total runtime and + the y-axis shows the total number of physical qubits. + + *data* may be a single ``EstimationTable`` or an iterable of tables. When + multiple tables are provided, each is plotted as a separate series. If a + table has a ``EstimationTable.name`` (set via the *name* parameter of + ``estimate``), it is used as the legend label for that series. + + When *runtime_unit* is ``None`` (the default), the x-axis uses + human-readable time-unit tick labels spanning nanoseconds to centuries. + When a unit string is given (e.g. ``"hours"``), all runtimes are scaled to + that unit and the x-axis label includes the unit while the ticks are plain + numbers. + + Supported *runtime_unit* values: ``"ns"``, ``"µs"`` (or ``"us"``), ``"ms"``, + ``"s"``, ``"min"``, ``"hours"``, ``"days"``, ``"weeks"``, ``"months"``, + ``"years"``. + + Args: + data: A single EstimationTable or an iterable of + EstimationTable objects to plot. + runtime_unit: Optional time unit to scale the x-axis to. + figsize: Figure dimensions in inches as ``(width, height)``. + scatter_args: Additional keyword arguments to pass to + ``matplotlib.axes.Axes.scatter`` when plotting the points. + + Returns: + matplotlib.figure.Figure: The figure containing the plot. + + Raises: + ImportError: If matplotlib is not installed. + ValueError: If all tables are empty or *runtime_unit* is not + recognised. + """ + try: + import matplotlib.pyplot as plt + except ImportError: + raise ImportError( + "Missing optional 'matplotlib' dependency. 
To install run: " + "pip install matplotlib" + ) + + # Normalize to a list of tables + if isinstance(data, EstimationTable): + tables = [data] + else: + tables = list(data) + + if not tables or all(len(t) == 0 for t in tables): + raise ValueError("Cannot plot an empty EstimationTable.") + + if runtime_unit is not None and runtime_unit not in _TIME_UNITS: + raise ValueError( + f"Unknown runtime_unit {runtime_unit!r}. " + f"Supported units: {', '.join(_TIME_UNITS)}" + ) + + fig, ax = plt.subplots(figsize=figsize) + ax.set_ylabel("Physical qubits") + ax.set_xscale("log") + ax.set_yscale("log") + + all_xs: list[float] = [] + has_labels = False + + for table in tables: + if len(table) == 0: + continue + + ys = [entry.qubits for entry in table] + + if runtime_unit is not None: + scale = _TIME_UNITS[runtime_unit] + xs = [entry.runtime / scale for entry in table] + else: + xs = [float(entry.runtime) for entry in table] + + all_xs.extend(xs) + + label = table.name + if label is not None: + has_labels = True + + ax.scatter(x=xs, y=ys, label=label, **scatter_args) + + if runtime_unit is not None: + ax.set_xlabel(f"Runtime ({runtime_unit})") + else: + ax.set_xlabel("Runtime") + + time_labels, time_units = zip(*_TICK_UNITS) + + cutoff = ( + next( + (i for i, x in enumerate(time_units) if x > max(all_xs)), + len(time_units) - 1, + ) + + 1 + ) + + ax.set_xticks(time_units[:cutoff]) + ax.set_xticklabels(time_labels[:cutoff], rotation=90) + + if has_labels: + ax.legend() + + plt.close(fig) + + return fig diff --git a/source/pip/qsharp/qre/_trace.py b/source/pip/qsharp/qre/_trace.py new file mode 100644 index 0000000000..965454ac95 --- /dev/null +++ b/source/pip/qsharp/qre/_trace.py @@ -0,0 +1,196 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations +from abc import ABC, abstractmethod +from dataclasses import dataclass, KW_ONLY, field +from itertools import product +from types import NoneType +from typing import Any, Optional, Generator, Type, TYPE_CHECKING + +if TYPE_CHECKING: + from ._application import _Context +from ._enumeration import _enumerate_instances +from ._qre import PSSPC as _PSSPC, LatticeSurgery as _LatticeSurgery, Trace + + +class TraceTransform(ABC): + """Abstract base class for trace transformations.""" + + @abstractmethod + def transform(self, trace: Trace) -> Optional[Trace]: + """Apply this transformation to a trace. + + Args: + trace (Trace): The input trace. + + Returns: + Optional[Trace]: The transformed trace, or None if the + transformation is not applicable. + """ + ... + + @classmethod + def q(cls, **kwargs) -> TraceQuery: + """Create a trace query for this transform type. + + Args: + **kwargs: Domain overrides for parameter enumeration. + + Returns: + TraceQuery: A trace query wrapping this transform type. + """ + return TraceQuery(cls, **kwargs) + + +@dataclass +class PSSPC(TraceTransform): + """Pauli-based computation trace transform (PSSPC). + + Converts rotation gates and optionally CCX gates into T-state-based + operations suitable for lattice surgery resource estimation. + + Attributes: + num_ts_per_rotation (int): Number of T states used per rotation + gate. Default is 20. + ccx_magic_states (bool): If True, CCX gates are treated as magic + states rather than being decomposed into T gates. Default is + False. + """ + + _: KW_ONLY + num_ts_per_rotation: int = field( + default=20, metadata={"domain": list(range(5, 21))} + ) + ccx_magic_states: bool = field(default=False) + + def __post_init__(self): + self._psspc = _PSSPC(self.num_ts_per_rotation, self.ccx_magic_states) + + def transform(self, trace: Trace) -> Optional[Trace]: + """Apply the PSSPC transformation to a trace. + + Args: + trace (Trace): The input trace. 
+ + Returns: + Optional[Trace]: The transformed trace. + """ + return self._psspc.transform(trace) + + +@dataclass +class LatticeSurgery(TraceTransform): + """Lattice surgery trace transform. + + Converts a trace into a form suitable for lattice-surgery-based + resource estimation. + + Attributes: + slow_down_factor (float): Multiplicative factor applied to the + trace depth. Default is 1.0. + """ + + _: KW_ONLY + slow_down_factor: float = field(default=1.0, metadata={"domain": [1.0]}) + + def __post_init__(self): + self._lattice_surgery = _LatticeSurgery(self.slow_down_factor) + + def transform(self, trace: Trace) -> Optional[Trace]: + """Apply the lattice surgery transformation to a trace. + + Args: + trace (Trace): The input trace. + + Returns: + Optional[Trace]: The transformed trace. + """ + return self._lattice_surgery.transform(trace) + + +class _Node(ABC): + """Abstract base class for trace enumeration nodes.""" + + @abstractmethod + def enumerate(self, ctx: _Context) -> Generator[Trace, None, None]: ... + + +class TraceQuery(_Node): + """A query that enumerates transformed traces from an application. + + A trace query chains a sequence of trace transforms, each with optional + keyword arguments to override their default parameter domains. + """ + + # This is a sequence of trace transforms together with possible kwargs to + # override their default domains. The first element might be + sequence: list[tuple[Type, dict[str, Any]]] + + def __init__(self, t: Type, **kwargs): + self.sequence = [(t, kwargs)] + + def enumerate( + self, ctx: _Context, track_parameters: bool = False + ) -> Generator[Trace | tuple[Any, Trace], None, None]: + """Enumerate transformed traces from the application context. + + Args: + ctx (_Context): The application enumeration context. + track_parameters (bool): If True, yield ``(parameters, trace)`` + tuples instead of plain traces. Default is False. 
+ + Yields: + Trace | tuple[Any, Trace]: A transformed trace, or a + ``(parameters, trace)`` tuple when *track_parameters* is True. + """ + sequence = self.sequence + kwargs = {} + if len(sequence) > 0 and sequence[0][0] is NoneType: + kwargs = sequence[0][1] + sequence = sequence[1:] + + if track_parameters: + source = ctx.application.enumerate_traces_with_parameters(**kwargs) + else: + source = ((None, t) for t in ctx.application.enumerate_traces(**kwargs)) + + for params, trace in source: + if not sequence: + yield (params, trace) if track_parameters else trace + continue + + transformer_instances = [] + + for t, transformer_kwargs in sequence: + instances = _enumerate_instances(t, **transformer_kwargs) + transformer_instances.append(instances) + + # TODO: make parallel + for combination in product(*transformer_instances): + transformed = trace + for transformer in combination: + transformed = transformer.transform(transformed) + yield (params, transformed) if track_parameters else transformed + + def __mul__(self, other: TraceQuery) -> TraceQuery: + """Chain another trace query onto this one. + + Args: + other (TraceQuery): The trace query to append. + + Returns: + TraceQuery: A new query with the combined transform sequence. + + Raises: + ValueError: If *other* begins with a None transform. + """ + new_query = TraceQuery.__new__(TraceQuery) + + if len(other.sequence) > 0 and other.sequence[0][0] is NoneType: + raise ValueError( + "Cannot multiply with a TraceQuery that has a None transform at the beginning of its sequence." + ) + + new_query.sequence = self.sequence + other.sequence + return new_query diff --git a/source/pip/qsharp/qre/application/__init__.py b/source/pip/qsharp/qre/application/__init__.py new file mode 100644 index 0000000000..7e39460c7e --- /dev/null +++ b/source/pip/qsharp/qre/application/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from ._cirq import CirqApplication +from ._qsharp import QSharpApplication + +__all__ = ["CirqApplication", "QSharpApplication"] diff --git a/source/pip/qsharp/qre/application/_cirq.py b/source/pip/qsharp/qre/application/_cirq.py new file mode 100644 index 0000000000..f2bc7272f3 --- /dev/null +++ b/source/pip/qsharp/qre/application/_cirq.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from __future__ import annotations + +from dataclasses import dataclass + +import cirq + +from .._application import Application +from .._qre import Trace +from ..interop import trace_from_cirq + + +@dataclass +class CirqApplication(Application[None]): + """Application that produces a resource estimation trace from a Cirq circuit. + + Accepts either a Cirq ``Circuit`` object or an OpenQASM string. When a + QASM string is provided, it is parsed into a circuit using + ``cirq.contrib.qasm_import`` (requires the optional ``ply`` dependency). + + Args: + circuit_or_qasm: A Cirq Circuit or an OpenQASM string. + classical_control_probability: Probability that a classically + controlled operation is included in the trace. Defaults to 0.5. + """ + + circuit_or_qasm: str | cirq.CIRCUIT_LIKE + classical_control_probability: float = 0.5 + + def __post_init__(self): + if isinstance(self.circuit_or_qasm, str): + try: + from cirq.contrib.qasm_import import circuit_from_qasm + + self._circuit = circuit_from_qasm(self.circuit_or_qasm) + except ImportError: + raise ImportError( + "Missing optional 'ply' dependency. To install run: " + "pip install ply" + ) + else: + self._circuit = self.circuit_or_qasm + + def get_trace(self, parameters: None = None) -> Trace: + """Return the resource estimation trace for the Cirq circuit. + + Args: + parameters (None): Unused. Defaults to None. + + Returns: + Trace: The resource estimation trace. 
+ """ + return trace_from_cirq(self._circuit) diff --git a/source/pip/qsharp/qre/application/_qsharp.py b/source/pip/qsharp/qre/application/_qsharp.py new file mode 100644 index 0000000000..abdad2bce4 --- /dev/null +++ b/source/pip/qsharp/qre/application/_qsharp.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from __future__ import annotations + +from pathlib import Path +from dataclasses import dataclass, field +from typing import Callable + +from ...estimator import LogicalCounts +from .._qre import Trace +from .._application import Application +from ..interop import trace_from_entry_expr_cached + + +@dataclass +class QSharpApplication(Application[None]): + """Application that produces a resource estimation trace from Q# code. + + Accepts a Q# entry expression string, a callable, or pre-computed + ``LogicalCounts``. + + Attributes: + cache_dir (Path): Directory for caching compiled traces. + use_cache (bool): Whether to use the trace cache. Default is False. + """ + + cache_dir: Path = field( + default=Path.home() / ".cache" / "re3" / "qsharp", repr=False + ) + use_cache: bool = field(default=False, repr=False) + + def __init__(self, entry_expr: str | Callable | LogicalCounts): + """Initialize the Q# application. + + Args: + entry_expr (str | Callable | LogicalCounts): The Q# entry + expression, a callable returning logical counts, or + pre-computed logical counts. + """ + self._entry_expr = entry_expr + + def get_trace(self, parameters: None = None) -> Trace: + """Return the resource estimation trace for the Q# program. + + Args: + parameters (None): Unused. Defaults to None. + + Returns: + Trace: The resource estimation trace. 
+ """ + # TODO: make caching work for `Callable` as well + if self.use_cache and isinstance(self._entry_expr, str): + cache_path = self.cache_dir / f"{self._entry_expr}.json" + else: + cache_path = None + + return trace_from_entry_expr_cached(self._entry_expr, cache_path) diff --git a/source/pip/qsharp/qre/instruction_ids.py b/source/pip/qsharp/qre/instruction_ids.py new file mode 100644 index 0000000000..cec4a9c070 --- /dev/null +++ b/source/pip/qsharp/qre/instruction_ids.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pyright: reportAttributeAccessIssue=false + + +from .._native import instruction_ids + +for name in instruction_ids.__all__: + globals()[name] = getattr(instruction_ids, name) diff --git a/source/pip/qsharp/qre/instruction_ids.pyi b/source/pip/qsharp/qre/instruction_ids.pyi new file mode 100644 index 0000000000..e8b8c0e739 --- /dev/null +++ b/source/pip/qsharp/qre/instruction_ids.pyi @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +# Paulis +PAULI_I: int +PAULI_X: int +PAULI_Y: int +PAULI_Z: int + +# Clifford gates +H: int +H_XZ: int +H_XY: int +H_YZ: int +SQRT_X: int +SQRT_X_DAG: int +SQRT_Y: int +SQRT_Y_DAG: int +S: int +SQRT_Z: int +S_DAG: int +SQRT_Z_DAG: int +CNOT: int +CX: int +CY: int +CZ: int +SWAP: int + +# State preparation +PREP_X: int +PREP_Y: int +PREP_Z: int + +# Generic Cliffords +ONE_QUBIT_CLIFFORD: int +TWO_QUBIT_CLIFFORD: int +N_QUBIT_CLIFFORD: int + +# Measurements +MEAS_X: int +MEAS_Y: int +MEAS_Z: int +MEAS_RESET_X: int +MEAS_RESET_Y: int +MEAS_RESET_Z: int +MEAS_XX: int +MEAS_YY: int +MEAS_ZZ: int +MEAS_XZ: int +MEAS_XY: int +MEAS_YZ: int + +# Non-Clifford gates +SQRT_SQRT_X: int +SQRT_SQRT_X_DAG: int +SQRT_SQRT_Y: int +SQRT_SQRT_Y_DAG: int +SQRT_SQRT_Z: int +T: int +SQRT_SQRT_Z_DAG: int +T_DAG: int +CCX: int +CCY: int +CCZ: int +CSWAP: int +AND: int +AND_DAG: int +RX: int +RY: int +RZ: int +CRX: int +CRY: int +CRZ: int +RXX: int +RYY: int +RZZ: int + +# Generic unitary gates +ONE_QUBIT_UNITARY: int +TWO_QUBIT_UNITARY: int + +# Multi-qubit Pauli measurement +MULTI_PAULI_MEAS: int + +# Some generic logical instructions +LATTICE_SURGERY: int + +# Memory/compute operations (used in compute parts of memory-compute layouts) +READ_FROM_MEMORY: int +WRITE_TO_MEMORY: int +MEMORY: int + +# Some special hardware physical instructions +CYCLIC_SHIFT: int + +# Generic operation (for unified RE) +GENERIC: int diff --git a/source/pip/qsharp/qre/interop/__init__.py b/source/pip/qsharp/qre/interop/__init__.py new file mode 100644 index 0000000000..bbf927d3e8 --- /dev/null +++ b/source/pip/qsharp/qre/interop/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from ._cirq import trace_from_cirq, PushBlock, PopBlock +from ._qsharp import trace_from_entry_expr, trace_from_entry_expr_cached +from ._qir import trace_from_qir + +__all__ = [ + "trace_from_cirq", + "trace_from_entry_expr", + "trace_from_entry_expr_cached", + "trace_from_qir", + "PushBlock", + "PopBlock", +] diff --git a/source/pip/qsharp/qre/interop/_cirq.py b/source/pip/qsharp/qre/interop/_cirq.py new file mode 100644 index 0000000000..fe84dfe4c5 --- /dev/null +++ b/source/pip/qsharp/qre/interop/_cirq.py @@ -0,0 +1,463 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import random +from dataclasses import dataclass +from math import pi +from typing import Iterable + +import cirq +from cirq import ( + HPowGate, + XPowGate, + YPowGate, + ZPowGate, + CXPowGate, + CZPowGate, + CCXPowGate, + CCZPowGate, + MeasurementGate, + ResetChannel, + GateOperation, + ClassicallyControlledOperation, + PhaseGradientGate, + SwapPowGate, +) +from qsharp.qre import Trace, Block +from qsharp.qre.instruction_ids import ( + H, + PAULI_X, + PAULI_Y, + PAULI_Z, + SQRT_X, + SQRT_X_DAG, + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_Y, + SQRT_Y_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + S, + S_DAG, + T, + T_DAG, + CX, + CZ, + RX, + RY, + RZ, + MEAS_Z, + CCX, + CCZ, + SWAP, +) + +_TOLERANCE = 1e-8 + + +def _approx_eq(a: float, b: float) -> bool: + """Check whether two floats are approximately equal.""" + return abs(a - b) <= _TOLERANCE + + +def trace_from_cirq( + circuit: cirq.CIRCUIT_LIKE, *, classical_control_probability: float = 0.5 +) -> Trace: + """Convert a Cirq circuit into a resource estimation Trace. + + Iterates through all moments and operations in the circuit, converting + each gate into trace operations. Gates with a ``_to_trace`` method are + converted directly; others are recursively decomposed via Cirq's + ``_decompose_with_context_`` or ``_decompose_`` protocols. 
+ + Args: + circuit: The Cirq circuit to convert. + classical_control_probability: Probability that a classically + controlled operation is included in the trace. Defaults to 0.5. + + Returns: + A Trace representing the resource profile of the circuit. + """ + + if isinstance(circuit, cirq.Circuit): + # circuit is already in the expected format, so we can process it directly. + pass + elif isinstance(circuit, cirq.Gate): + circuit = cirq.Circuit(circuit.on(*cirq.LineQid.for_gate(circuit))) + else: + # circuit is OP_TREE + circuit = cirq.Circuit(circuit) + + context = _CirqTraceBuilder(circuit, classical_control_probability) + + for moment in circuit: + for op in moment.operations: + context.handle_op(op) + + return context.trace + + +class _CirqTraceBuilder: + """Builds a resource estimation ``Trace`` from a Cirq circuit. + + This class walks the operations produced by ``trace_from_cirq`` and + translates each one into trace instructions. It maintains the state + needed during the conversion: + + * A ``Trace`` instance that accumulates the result. + * A stack of ``Block`` objects so that ``PushBlock`` / ``PopBlock`` + markers can create nested repeated sections. + * A qubit-id mapping (``_QidToTraceId``) that assigns each Cirq qubit + a sequential integer index. + * A Cirq ``DecompositionContext`` for gates that need recursive + decomposition. + + Args: + circuit: The Cirq circuit being converted. + classical_control_probability: Probability that a classically + controlled operation is included in the trace. 
+ """ + + def __init__(self, circuit: cirq.Circuit, classical_control_probability: float): + self._trace = Trace(len(circuit.all_qubits())) + self._classical_control_probability = classical_control_probability + self._blocks = [self._trace.root_block()] + self._q_to_id = _QidToTraceId(circuit.all_qubits()) + self._decomp_context = cirq.DecompositionContext( + qubit_manager=cirq.GreedyQubitManager("trace_from_cirq") + ) + + def push_block(self, repetitions: int): + """Open a new repeated block with the given number of repetitions.""" + block = self.block.add_block(repetitions) + self._blocks.append(block) + + def pop_block(self): + """Close the current repeated block, returning to the parent.""" + self._blocks.pop() + + @property + def trace(self) -> Trace: + """The accumulated trace, with ``compute_qubits`` updated to reflect + all qubits seen so far (including any allocated during decomposition).""" + self._trace.compute_qubits = len(self._q_to_id) + return self._trace + + @property + def block(self) -> Block: + """The innermost open block in the trace.""" + return self._blocks[-1] + + @property + def q_to_id(self) -> _QidToTraceId: + """Mapping from Cirq ``Qid`` to integer trace qubit index.""" + return self._q_to_id + + @property + def classical_control_probability(self) -> float: + """Probability used to stochastically include classically controlled + operations.""" + return self._classical_control_probability + + @property + def decomp_context(self) -> cirq.DecompositionContext: + """Cirq decomposition context shared across all recursive + decompositions.""" + return self._decomp_context + + def handle_op( + self, + op: cirq.OP_TREE | TraceGate | PushBlock | PopBlock, + ) -> None: + """Recursively convert a single operation into trace instructions. + + Supported operation forms: + + - ``TraceGate``: A raw trace instruction, added directly to the + current block. + - ``PushBlock`` / ``PopBlock``: Control block nesting with + repetitions. 
+ - ``GateOperation``: Dispatched via ``_to_trace`` if available on + the gate, otherwise decomposed via + ``_decompose_with_context_`` or ``_decompose_``. + - ``ClassicallyControlledOperation``: Included with the probability + given by ``classical_control_probability``. + - ``list`` / iterable: Each element is handled recursively. + - Any other ``cirq.Operation``: Decomposed via + ``_decompose_with_context_``. + + Args: + op: The operation to convert. + """ + if isinstance(op, TraceGate): + qs = [ + self.q_to_id[q] + for q in ([op.qubits] if isinstance(op.qubits, cirq.Qid) else op.qubits) + ] + + if op.params is None: + self.block.add_operation(op.id, qs) + else: + self.block.add_operation( + op.id, qs, op.params if isinstance(op.params, list) else [op.params] + ) + elif isinstance(op, PushBlock): + self.push_block(op.repetitions) + elif isinstance(op, PopBlock): + self.pop_block() + elif isinstance(op, cirq.Operation): + if isinstance(op, GateOperation): + gate = op.gate + + if hasattr(gate, "_to_trace"): + for sub_op in gate._to_trace(self.decomp_context, op): # type: ignore + self.handle_op(sub_op) + elif hasattr(gate, "_decompose_with_context_"): + for sub_op in gate._decompose_with_context_(op.qubits, self.decomp_context): # type: ignore + self.handle_op(sub_op) + elif hasattr(gate, "_decompose_"): + # decompose the gate and handle the resulting operations recursively + for sub_op in gate._decompose_(op.qubits): # type: ignore + self.handle_op(sub_op) + else: + for sub_op in op._decompose_with_context_(self.decomp_context): # type: ignore + self.handle_op(sub_op) + elif isinstance(op, ClassicallyControlledOperation): + if random.random() < self.classical_control_probability: + self.handle_op(op.without_classical_controls()) + else: + for sub_op in op._decompose_with_context_(self.decomp_context): # type: ignore + self.handle_op(sub_op) + else: + # op is Iterable[OP_TREE] + for sub_op in op: + self.handle_op(sub_op) + + +@dataclass(frozen=True, slots=True) 
+class PushBlock: + """Signals the start of a repeated block in the trace. + + Args: + repetitions: Number of times the block is repeated. + """ + + repetitions: int + + +@dataclass(frozen=True, slots=True) +class PopBlock: + """Signals the end of the current repeated block in the trace.""" + + ... + + +@dataclass(frozen=True, slots=True) +class TraceGate: + """A raw trace instruction emitted during Cirq circuit conversion. + + Attributes: + id (int): The instruction ID. + qubits (list[cirq.Qid] | cirq.Qid): The target qubits. + params (list[float] | float | None): Optional gate parameters. + """ + + id: int + qubits: list[cirq.Qid] | cirq.Qid + params: list[float] | float | None = None + + +class _QidToTraceId(dict): + """Mapping from Cirq qubits to integer trace qubit indices. + + Initialized with a set of known qubits. If an unknown qubit is looked + up, it is automatically assigned the next available index. + """ + + def __init__(self, init: Iterable[cirq.Qid]): + super().__init__({q: i for i, q in enumerate(init)}) + + def __getitem__(self, key: cirq.Qid) -> int: + """ + If the key is not present, add it to the mapping with the next available id. 
+ """ + + if key not in self: + self[key] = len(self) + return super().__getitem__(key) + + +def h_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + """Convert an HPowGate into trace instructions.""" + if _approx_eq(abs(self.exponent), 1): + yield TraceGate(H, [op.qubits[0]]) + else: + yield from op._decompose_with_context_(context) # type: ignore + + +def x_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + """Convert an XPowGate into trace instructions.""" + q = [op.qubits[0]] + exp = self.exponent + if _approx_eq(exp, 1) or _approx_eq(exp, -1): + yield TraceGate(PAULI_X, q) + elif _approx_eq(exp, 0.5): + yield TraceGate(SQRT_X, q) + elif _approx_eq(exp, -0.5): + yield TraceGate(SQRT_X_DAG, q) + elif _approx_eq(exp, 0.25): + yield TraceGate(SQRT_SQRT_X, q) + elif _approx_eq(exp, -0.25): + yield TraceGate(SQRT_SQRT_X_DAG, q) + else: + yield TraceGate(RX, q, exp * pi) + + +def y_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + """Convert a YPowGate into trace instructions.""" + q = [op.qubits[0]] + exp = self.exponent + if _approx_eq(exp, 1) or _approx_eq(exp, -1): + yield TraceGate(PAULI_Y, q) + elif _approx_eq(exp, 0.5): + yield TraceGate(SQRT_Y, q) + elif _approx_eq(exp, -0.5): + yield TraceGate(SQRT_Y_DAG, q) + elif _approx_eq(exp, 0.25): + yield TraceGate(SQRT_SQRT_Y, q) + elif _approx_eq(exp, -0.25): + yield TraceGate(SQRT_SQRT_Y_DAG, q) + else: + yield TraceGate(RY, q, exp * pi) + + +def z_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + """Convert a ZPowGate into trace instructions.""" + q = [op.qubits[0]] + exp = self.exponent + if _approx_eq(exp, 1) or _approx_eq(exp, -1): + yield TraceGate(PAULI_Z, q) + elif _approx_eq(exp, 0.5): + yield TraceGate(S, q) + elif _approx_eq(exp, -0.5): + yield TraceGate(S_DAG, q) + elif _approx_eq(exp, 0.25): + yield TraceGate(T, q) + elif _approx_eq(exp, -0.25): + yield TraceGate(T_DAG, q) + 
else: + yield TraceGate(RZ, q, exp * pi) + + +def cx_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + """Convert a CXPowGate into trace instructions.""" + if _approx_eq(abs(self.exponent), 1): + yield TraceGate(CX, [op.qubits[0], op.qubits[1]]) + else: + yield from op._decompose_with_context_(context) # type: ignore + + +def cz_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + """Convert a CZPowGate into trace instructions.""" + exp = self.exponent + c, t = op.qubits[0], op.qubits[1] + if _approx_eq(abs(exp), 1): + yield TraceGate(CZ, [c, t]) + elif _approx_eq(exp, 0.5): + # controlled S gate + yield TraceGate(T, [c]) + yield TraceGate(T, [t]) + yield TraceGate(CZ, [c, t]) + yield TraceGate(T_DAG, [t]) + yield TraceGate(CZ, [c, t]) + elif _approx_eq(exp, -0.5): + # controlled S† gate + yield TraceGate(T_DAG, [c]) + yield TraceGate(T_DAG, [t]) + yield TraceGate(CZ, [c, t]) + yield TraceGate(T, [t]) + yield TraceGate(CZ, [c, t]) + else: + rads = exp / 2 * pi + yield TraceGate(RZ, [c], [rads]) + yield TraceGate(RZ, [t], [rads]) + yield TraceGate(CZ, [c, t]) + yield TraceGate(RZ, [t], [-rads]) + yield TraceGate(CZ, [c, t]) + + +def swap_pow_gate_to_trace( + self, context: cirq.DecompositionContext, op: cirq.Operation +): + """Convert a SwapPowGate into trace instructions.""" + if _approx_eq(abs(self.exponent), 1): + yield TraceGate(SWAP, [op.qubits[0], op.qubits[1]]) + else: + yield from op._decompose_with_context_(context) # type: ignore + + +def ccx_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + """Convert a CCXPowGate into trace instructions.""" + if _approx_eq(abs(self.exponent), 1): + yield TraceGate(CCX, [op.qubits[0], op.qubits[1], op.qubits[2]]) + else: + yield from op._decompose_with_context_(context) # type: ignore + + +def ccz_pow_gate_to_trace(self, context: cirq.DecompositionContext, op: cirq.Operation): + """Convert a CCZPowGate into trace 
instructions.""" + if _approx_eq(abs(self.exponent), 1): + yield TraceGate(CCZ, [op.qubits[0], op.qubits[1], op.qubits[2]]) + else: + yield from op._decompose_with_context_(context) # type: ignore + + +def measurement_gate_to_trace( + self, context: cirq.DecompositionContext, op: cirq.Operation +): + """Convert a MeasurementGate into trace instructions.""" + for q in op.qubits: + yield TraceGate(MEAS_Z, [q]) + + +def reset_channel_to_trace( + self, context: cirq.DecompositionContext, op: cirq.Operation +): + """Convert a ResetChannel into trace instructions (no-op).""" + yield from () + + +# Attach _to_trace methods to Cirq gate classes so that handle_op can +# convert them directly into trace instructions without decomposition. +HPowGate._to_trace = h_pow_gate_to_trace +XPowGate._to_trace = x_pow_gate_to_trace +YPowGate._to_trace = y_pow_gate_to_trace +ZPowGate._to_trace = z_pow_gate_to_trace +CXPowGate._to_trace = cx_pow_gate_to_trace +CZPowGate._to_trace = cz_pow_gate_to_trace +SwapPowGate._to_trace = swap_pow_gate_to_trace +CCXPowGate._to_trace = ccx_pow_gate_to_trace +CCZPowGate._to_trace = ccz_pow_gate_to_trace +MeasurementGate._to_trace = measurement_gate_to_trace +ResetChannel._to_trace = reset_channel_to_trace + +# Decomposition overrides + + +def phase_gradient_decompose(self, qubits): + """Override PhaseGradientGate._decompose_ to skip rotations with very small angles. + + The original implementation may lead to floating-point overflows for + large values of i. + """ + + for i, q in enumerate(qubits): + exp = self.exponent / 2**i + if exp < 1e-16: + break + yield cirq.Z(q) ** exp + + +PhaseGradientGate._decompose_ = phase_gradient_decompose diff --git a/source/pip/qsharp/qre/interop/_qir.py b/source/pip/qsharp/qre/interop/_qir.py new file mode 100644 index 0000000000..ebfb9559d1 --- /dev/null +++ b/source/pip/qsharp/qre/interop/_qir.py @@ -0,0 +1,136 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +import pyqir + +from ..._native import QirInstructionId +from ..._simulation import AggregateGatesPass +from .. import instruction_ids as ids +from .._qre import Trace + +# Maps QirInstructionId to (instruction_id, arity) where arity is: +# 1 = single-qubit gate: tuple is (op, qubit) +# 2 = two-qubit gate: tuple is (op, qubit1, qubit2) +# 3 = three-qubit gate: tuple is (op, qubit1, qubit2, qubit3) +# -1 = single-qubit rotation: tuple is (op, angle, qubit) +# -2 = two-qubit rotation: tuple is (op, angle, qubit1, qubit2) +_GATE_MAP: list[tuple[QirInstructionId, int, int]] = [ + # Single-qubit gates + (QirInstructionId.I, ids.PAULI_I, 1), + (QirInstructionId.H, ids.H, 1), + (QirInstructionId.X, ids.PAULI_X, 1), + (QirInstructionId.Y, ids.PAULI_Y, 1), + (QirInstructionId.Z, ids.PAULI_Z, 1), + (QirInstructionId.S, ids.S, 1), + (QirInstructionId.SAdj, ids.S_DAG, 1), + (QirInstructionId.SX, ids.SQRT_X, 1), + (QirInstructionId.SXAdj, ids.SQRT_X_DAG, 1), + (QirInstructionId.T, ids.T, 1), + (QirInstructionId.TAdj, ids.T_DAG, 1), + # Two-qubit gates + (QirInstructionId.CNOT, ids.CNOT, 2), + (QirInstructionId.CX, ids.CX, 2), + (QirInstructionId.CY, ids.CY, 2), + (QirInstructionId.CZ, ids.CZ, 2), + (QirInstructionId.SWAP, ids.SWAP, 2), + # Three-qubit gates + (QirInstructionId.CCX, ids.CCX, 3), + # Single-qubit rotations (op, angle, qubit) + (QirInstructionId.RX, ids.RX, -1), + (QirInstructionId.RY, ids.RY, -1), + (QirInstructionId.RZ, ids.RZ, -1), + # Two-qubit rotations (op, angle, qubit1, qubit2) + (QirInstructionId.RXX, ids.RXX, -2), + (QirInstructionId.RYY, ids.RYY, -2), + (QirInstructionId.RZZ, ids.RZZ, -2), +] + +_MEAS_MAP: list[tuple[QirInstructionId, int]] = [ + (QirInstructionId.M, ids.MEAS_Z), + (QirInstructionId.MZ, ids.MEAS_Z), + (QirInstructionId.MResetZ, ids.MEAS_RESET_Z), +] + +_SKIP = ( + # Resets qubit to |0⟩ without measuring; we do not currently account for + # that in resource estimation + QirInstructionId.RESET, + # 
Runtime qubit state transfer; an implementation detail, not a logical operation + QirInstructionId.Move, + # Reads a measurement result from classical memory; purely classical I/O + QirInstructionId.ReadResult, + # The following are classical output recording operations that do not represent + # quantum operations and have no impact on resource estimation. + QirInstructionId.ResultRecordOutput, + QirInstructionId.BoolRecordOutput, + QirInstructionId.IntRecordOutput, + QirInstructionId.DoubleRecordOutput, + QirInstructionId.TupleRecordOutput, + QirInstructionId.ArrayRecordOutput, +) + + +def trace_from_qir(input: str | bytes) -> Trace: + """Convert a QIR program into a resource-estimation Trace. + + Parses the QIR module, extracts quantum gates, and builds a Trace that + can be used for resource estimation. Conditional branches are resolved + by always following the false path (assuming measurement results are Zero). + + Args: + input: QIR input as LLVM IR text (str) or bitcode (bytes). + + Returns: + A Trace containing the quantum operations from the QIR program. 
+ """ + context = pyqir.Context() + + if isinstance(input, str): + mod = pyqir.Module.from_ir(context, input) + else: + mod = pyqir.Module.from_bitcode(context, input) + + gates, num_qubits, _ = AggregateGatesPass().run(mod) + + trace = Trace(compute_qubits=num_qubits) + + for gate in gates: + # NOTE: AggregateGatesPass does not return QirInstruction objects + assert isinstance(gate, tuple) + _add_gate(trace, gate) + + return trace + + +def _add_gate(trace: Trace, gate: tuple) -> None: + """Add a single QIR gate tuple to the trace.""" + op = gate[0] + + for qir_id, instr_id, arity in _GATE_MAP: + if op == qir_id: + if arity == 1: + trace.add_operation(instr_id, [gate[1]]) + elif arity == 2: + trace.add_operation(instr_id, [gate[1], gate[2]]) + elif arity == 3: + trace.add_operation(instr_id, [gate[1], gate[2], gate[3]]) + elif arity == -1: + trace.add_operation(instr_id, [gate[2]], [gate[1]]) + elif arity == -2: + trace.add_operation(instr_id, [gate[2], gate[3]], [gate[1]]) + return + + for qir_id, instr_id in _MEAS_MAP: + if op == qir_id: + trace.add_operation(instr_id, [gate[1]]) + return + + for skip_id in _SKIP: + if op == skip_id: + return + + # The only unhandled QirInstructionId is CorrelatedNoise + assert op == QirInstructionId.CorrelatedNoise, f"Unexpected QIR instruction: {op}" + raise NotImplementedError(f"Unsupported QIR instruction: {op}") diff --git a/source/pip/qsharp/qre/interop/_qsharp.py b/source/pip/qsharp/qre/interop/_qsharp.py new file mode 100644 index 0000000000..b7c31767f0 --- /dev/null +++ b/source/pip/qsharp/qre/interop/_qsharp.py @@ -0,0 +1,154 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +from pathlib import Path +import time +from typing import Callable, Optional + +from ..._qsharp import logical_counts +from ...estimator import LogicalCounts +from .._qre import Trace +from ..instruction_ids import CCX, MEAS_Z, RZ, T, READ_FROM_MEMORY, WRITE_TO_MEMORY +from ..property_keys import ( + EVALUATION_TIME, + ALGORITHM_COMPUTE_QUBITS, + ALGORITHM_MEMORY_QUBITS, +) + + +def _bucketize_rotation_counts( + rotation_count: int, rotation_depth: int +) -> list[tuple[int, int]]: + """ + Return a list of (count, depth) pairs representing the rotation layers in + the trace. + + The following properties hold for the returned list ``result``: + - sum(depth for _, depth in result) == rotation_depth + - sum(count * depth for count, depth in result) == rotation_count + - count > 0 for each (count, _) in result + - count <= qubit_count for each (count, _) in result holds by definition + when rotation_count <= rotation_depth * qubit_count + + Args: + rotation_count: Total number of rotations. + rotation_depth: Total depth of the rotation layers. + + Returns: + A list of (count, depth) pairs, where 'count' is the number of + rotations in a layer and 'depth' is the depth of that layer. + """ + if rotation_depth == 0: + return [] + + base = rotation_count // rotation_depth + extra = rotation_count % rotation_depth + + result: list[tuple[int, int]] = [] + if extra > 0: + result.append((base + 1, extra)) + if rotation_depth - extra > 0: + result.append((base, rotation_depth - extra)) + return result + + +def trace_from_entry_expr(entry_expr: str | Callable | LogicalCounts) -> Trace: + """Convert a Q# entry expression into a resource-estimation Trace. + + Evaluates the entry expression to obtain logical counts, then builds + a trace containing the corresponding quantum operations. + + Args: + entry_expr (str | Callable | LogicalCounts): A Q# entry expression + string, a callable, or pre-computed logical counts. 
+ + Returns: + Trace: A trace representing the resource profile of the program. + """ + + start = time.time_ns() + counts = ( + logical_counts(entry_expr) + if not isinstance(entry_expr, LogicalCounts) + else entry_expr + ) + evaluation_time = time.time_ns() - start + + ccx_count = counts.get("cczCount", 0) + counts.get("ccixCount", 0) + + # Q# logical counts report total number of qubits (compute + memory) + num_qubits = counts.get("numQubits", 0) + # Compute qubits may be reported separately + compute_qubits = counts.get("numComputeQubits", num_qubits) + memory_qubits = num_qubits - compute_qubits + + trace = Trace(compute_qubits) + + rotation_count = counts.get("rotationCount", 0) + rotation_depth = counts.get("rotationDepth", rotation_count) + + if rotation_count != 0 and rotation_depth != 0: + for count, depth in _bucketize_rotation_counts(rotation_count, rotation_depth): + block = trace.add_block(repetitions=depth) + for i in range(count): + block.add_operation(RZ, [i]) + + if t_count := counts.get("tCount", 0): + block = trace.add_block(repetitions=t_count) + block.add_operation(T, [0]) + + if ccx_count: + block = trace.add_block(repetitions=ccx_count) + block.add_operation(CCX, [0, 1, 2]) + + if meas_count := counts.get("measurementCount", 0): + block = trace.add_block(repetitions=meas_count) + block.add_operation(MEAS_Z, [0]) + + if memory_qubits != 0: + trace.memory_qubits = memory_qubits + + if rfm_count := counts.get("readFromMemoryCount", 0): + block = trace.add_block(repetitions=rfm_count) + block.add_operation(READ_FROM_MEMORY, [0, compute_qubits]) + + if wtm_count := counts.get("writeToMemoryCount", 0): + block = trace.add_block(repetitions=wtm_count) + block.add_operation(WRITE_TO_MEMORY, [0, compute_qubits]) + + trace.set_property(EVALUATION_TIME, evaluation_time) + trace.set_property(ALGORITHM_COMPUTE_QUBITS, compute_qubits) + trace.set_property(ALGORITHM_MEMORY_QUBITS, memory_qubits) + return trace + + +def trace_from_entry_expr_cached( + 
entry_expr: str | Callable | LogicalCounts, cache_path: Optional[Path] +) -> Trace: + """Convert a Q# entry expression into a Trace, with optional caching. + + If *cache_path* is provided and exists, the trace is loaded from disk. + Otherwise, the trace is computed via ``trace_from_entry_expr`` and + optionally written to *cache_path*. + + Args: + entry_expr (str | Callable | LogicalCounts): A Q# entry expression + string, a callable, or pre-computed logical counts. + cache_path (Optional[Path]): Path for reading/writing the cached + trace. If None, caching is disabled. + + Returns: + Trace: A trace representing the resource profile of the program. + """ + if cache_path and cache_path.exists(): + return Trace.from_json(cache_path.read_text()) + + trace = trace_from_entry_expr(entry_expr) + + if cache_path: + cache_path.parent.mkdir(parents=True, exist_ok=True) + cache_path.write_text(trace.to_json()) + + return trace diff --git a/source/pip/qsharp/qre/models/__init__.py b/source/pip/qsharp/qre/models/__init__.py new file mode 100644 index 0000000000..3da76797ac --- /dev/null +++ b/source/pip/qsharp/qre/models/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .factories import Litinski19Factory, MagicUpToClifford, RoundBasedFactory +from .qec import ( + SurfaceCode, + ThreeAux, + OneDimensionalYokedSurfaceCode, + TwoDimensionalYokedSurfaceCode, +) +from .qubits import GateBased, Majorana + +__all__ = [ + "GateBased", + "Litinski19Factory", + "Majorana", + "MagicUpToClifford", + "RoundBasedFactory", + "SurfaceCode", + "ThreeAux", + "OneDimensionalYokedSurfaceCode", + "TwoDimensionalYokedSurfaceCode", +] diff --git a/source/pip/qsharp/qre/models/factories/__init__.py b/source/pip/qsharp/qre/models/factories/__init__.py new file mode 100644 index 0000000000..e652dfc983 --- /dev/null +++ b/source/pip/qsharp/qre/models/factories/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT License. + +from ._litinski import Litinski19Factory +from ._round_based import RoundBasedFactory +from ._utils import MagicUpToClifford + +__all__ = ["Litinski19Factory", "MagicUpToClifford", "RoundBasedFactory"] diff --git a/source/pip/qsharp/qre/models/factories/_litinski.py b/source/pip/qsharp/qre/models/factories/_litinski.py new file mode 100644 index 0000000000..ffe4b2558d --- /dev/null +++ b/source/pip/qsharp/qre/models/factories/_litinski.py @@ -0,0 +1,395 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from dataclasses import dataclass +from math import ceil +from typing import Generator + +from ..._architecture import ISAContext +from ..._qre import ISARequirements, ConstraintBound, ISA +from ..._instruction import ISATransform, constraint, LOGICAL +from ...instruction_ids import T, CNOT, H, MEAS_Z, CCZ + + +@dataclass +class Litinski19Factory(ISATransform): + """ + T and CCZ factories based on the paper + [arXiv:1905.06903](https://arxiv.org/abs/1905.06903). + + It contains two categories of estimates. If the input T error rate is + similar to the Clifford error, it produces magic state instructions based on + Table 1 in the paper. If the input T error rate is at most 10 times higher + than the Clifford error rate, it produces magic state instructions based on + Table 2 in the paper. + + It requires Clifford error rates of at most 0.1% for CNOT, H, and MEAS_Z + instructions. If these instructions have different error rates, the maximum + error rate is assumed. 
+ + References: + + - Daniel Litinski: Magic state distillation: not as costly as you think, + [arXiv:1905.06903](https://arxiv.org/abs/1905.06903) + """ + + def __post_init__(self): + self._initialize_entries() + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + # T error rate may be at least 10x higher than Clifford error rates + constraint(T, error_rate=ConstraintBound.le(1e-2)), + constraint(H, error_rate=ConstraintBound.le(1e-3)), + constraint(CNOT, arity=2, error_rate=ConstraintBound.le(1e-3)), + constraint(MEAS_Z, error_rate=ConstraintBound.le(1e-3)), + ) + + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: + """Yield ISAs with T and CCZ factory instructions. + + Args: + impl_isa (ISA): The implementation ISA providing physical gates. + ctx (ISAContext): The enumeration context. + + Yields: + ISA: An ISA containing distilled T and/or CCZ instructions. + """ + h = impl_isa[H] + cnot = impl_isa[CNOT] + meas_z = impl_isa[MEAS_Z] + t = impl_isa[T] + + clifford_error_rate = max( + h.expect_error_rate(), + cnot.expect_error_rate(), + meas_z.expect_error_rate(), + ) + + t_error_rate = t.expect_error_rate() + + entries_by_state = None + + if clifford_error_rate <= 1e-4: + if t_error_rate <= 1e-4: + entries_by_state = self._entries[1e-4][0] + elif t_error_rate <= 1e-3: + entries_by_state = self._entries[1e-4][1] + else: + # NOTE: This assertion is valid due to the constraint bound in the + # required_isa method + assert clifford_error_rate <= 1e-3 + if t_error_rate <= 1e-3: + entries_by_state = self._entries[1e-3][0] + elif t_error_rate <= 1e-2: + entries_by_state = self._entries[1e-3][1] + + if entries_by_state is None: + return + + t_entries = entries_by_state.get(T, []) + ccz_entries = entries_by_state.get(CCZ, []) + + syndrome_extraction_time = ( + 4 * impl_isa[CNOT].expect_time() + + impl_isa[H].expect_time() + + impl_isa[MEAS_Z].expect_time() + ) + + def make_node(entry: _Entry) -> 
int: + # Convert cycles (number of syndrome extraction cycles) to time + # based on fast surface code + time = ceil(syndrome_extraction_time * entry.cycles) + + # NOTE: If the protocol outputs multiple states, we assume that the + # space cost is divided by the number of output states. This is a + # simplification that allows us to fit all protocols in the ISA, but + # it may not be accurate for all protocols. + return ctx.add_instruction( + entry.state, + arity=3 if entry.state == CCZ else 1, + encoding=LOGICAL, + space=ceil(entry.space / entry.output_states), + time=time, + error_rate=entry.error_rate, + transform=self, + source=[cnot, h, meas_z, t], + ) + + # Yield combinations of T and CCZ entries + if ccz_entries: + for t_entry in t_entries: + for ccz_entry in ccz_entries: + yield ctx.make_isa( + make_node(t_entry), + make_node(ccz_entry), + ) + else: + # Table 2 scenarios: only T gates available + for t_entry in t_entries: + yield ctx.make_isa(make_node(t_entry)) + + def _initialize_entries(self): + """Initialize the distillation protocol lookup tables.""" + self._entries = { + # Assuming a Clifford error rate of at most 1e-4: + 1e-4: ( + # Assuming a T error rate of at most 1e-4 (Table 1): + { + T: [ + _Entry(_Protocol(15, 1, 7, 3, 3), 4.4e-8, 810, 18.1), + _Entry(_Protocol(15, 1, 9, 3, 3), 9.3e-10, 1_150, 18.1), + _Entry(_Protocol(15, 1, 11, 5, 5), 1.9e-11, 2_070, 30.0), + _Entry( + [ + (_Protocol(15, 1, 9, 3, 3), 4), + (_Protocol(20, 4, 15, 7, 9), 1), + ], + 2.4e-15, + 16_400, + 90.3, + ), + _Entry( + [ + (_Protocol(15, 1, 9, 3, 3), 4), + (_Protocol(15, 1, 25, 9, 9), 1), + ], + 6.3e-25, + 18_600, + 67.8, + ), + _Entry(_Protocol(15, 1, 9, 3, 3), 1.5e-9, 762, 36.2), + ], + CCZ: [ + _Entry( + [ + (_Protocol(15, 1, 7, 3, 3), 4), + (_Protocol(8, 1, 15, 7, 9, CCZ), 1), + ], + 7.2e-14, + 12_400, + 36.1, + ), + ], + }, + # Assuming a T error rate of at most 1e-3 (10x higher than Clifford, Table 2): + { + T: [ + _Entry(_Protocol(15, 1, 9, 3, 3), 2.1e-8, 1_150, 
18.2), + _Entry( + [ + (_Protocol(15, 1, 7, 3, 3), 6), + (_Protocol(20, 4, 13, 5, 7), 1), + ], + 1.4e-12, + 13_200, + 70, + ), + _Entry( + [ + (_Protocol(15, 1, 9, 3, 3), 4), + (_Protocol(20, 4, 15, 7, 9), 1), + ], + 6.6e-15, + 16_400, + 91.2, + ), + _Entry( + [ + (_Protocol(15, 1, 9, 3, 3), 4), + (_Protocol(15, 1, 25, 9, 9), 1), + ], + 4.2e-22, + 18_600, + 68.4, + ), + ], + CCZ: [], + }, + ), + # Assuming a Clifford error rate of at most 1e-3: + 1e-3: ( + # Assuming a T error rate of at most 1e-3 (Table 1): + { + T: [ + _Entry(_Protocol(15, 1, 17, 7, 7), 4.5e-8, 4_620, 42.6), + _Entry( + [ + (_Protocol(15, 1, 13, 5, 5), 6), + (_Protocol(20, 4, 23, 11, 13), 1), + ], + 1.4e-10, + 43_300, + 130, + ), + _Entry( + [ + (_Protocol(15, 1, 13, 5, 5), 4), + (_Protocol(20, 4, 27, 13, 15), 1), + ], + 2.6e-11, + 46_800, + 157, + ), + _Entry( + [ + (_Protocol(15, 1, 11, 5, 5), 6), + (_Protocol(15, 1, 25, 11, 11), 1), + ], + 2.7e-12, + 30_700, + 82.5, + ), + _Entry( + [ + (_Protocol(15, 1, 13, 5, 5), 6), + (_Protocol(15, 1, 29, 11, 13), 1), + ], + 3.3e-14, + 39_100, + 97.5, + ), + _Entry( + [ + (_Protocol(15, 1, 15, 7, 7), 6), + (_Protocol(15, 1, 41, 17, 17), 1), + ], + 4.5e-20, + 73_400, + 128, + ), + ], + CCZ: [ + _Entry( + [ + (_Protocol(15, 1, 13, 7, 7), 6), + (_Protocol(8, 1, 25, 15, 15, CCZ), 1), + ], + 5.2e-11, + 47_000, + 60, + ), + ], + }, + # Assuming a T error rate of at most 1e-2 (10x higher than Clifford, Table 2): + { + T: [ + _Entry( + [ + (_Protocol(15, 1, 13, 5, 5), 6), + (_Protocol(20, 4, 21, 11, 13), 1), + ], + 5.7e-9, + 40_700, + 130, + ), + _Entry( + [ + (_Protocol(15, 1, 11, 5, 5), 6), + (_Protocol(15, 1, 21, 9, 11), 1), + ], + 2.1e-10, + 27_400, + 85.7, + ), + _Entry( + [ + (_Protocol(15, 1, 11, 5, 5), 6), + (_Protocol(15, 1, 23, 11, 11), 1), + ], + 2.5e-11, + 29_500, + 85.7, + ), + _Entry( + [ + (_Protocol(15, 1, 11, 5, 5), 6), + (_Protocol(15, 1, 25, 11, 11), 1), + ], + 6.4e-12, + 30_700, + 85.7, + ), + _Entry( + [ + (_Protocol(15, 1, 13, 7, 7), 8), + 
(_Protocol(15, 1, 29, 13, 13), 1), + ], + 1.5e-13, + 52_400, + 97.5, + ), + ], + CCZ: [], + }, + ), + } + + +@dataclass(frozen=True, slots=True) +class _Entry: + """A single distillation protocol entry from the Litinski tables. + + Attributes: + protocol (list[tuple[_Protocol, int]] | _Protocol): The distillation + protocol or pipeline of protocols. + error_rate (float): Output error rate of the protocol. + space (int): Space cost in physical qubits. + cycles (float): Number of syndrome extraction cycles. + """ + + protocol: list[tuple[_Protocol, int]] | _Protocol + error_rate: float + # Space estimation in number of physical qubits + space: int + # Number of code cycles to estimate time; a code cycle corresponds to + # measuring all surface-code check operators exactly once. + cycles: float + + @property + def output_states(self) -> int: + """Return the number of output magic states.""" + if isinstance(self.protocol, list): + return self.protocol[-1][0].output_states + else: + return self.protocol.output_states + + @property + def state(self) -> int: + """Return the magic state instruction ID (T or CCZ).""" + if isinstance(self.protocol, list): + return self.protocol[-1][0].state + else: + return self.protocol.state + + +@dataclass(frozen=True, slots=True) +class _Protocol: + """Parameters for a single distillation protocol. + + Attributes: + input_states (int): Number of input T states. + output_states (int): Number of output T states. + d_x (int): Spatial X distance. + d_z (int): Spatial Z distance. + d_m (int): Temporal distance. + state (int): Magic state instruction ID. Default is T. 
+ """ + + # Number of input T states in protocol + input_states: int + # Number of output T states in protocol + output_states: int + # Spatial X distance (arXiv:1905.06903, Section 2) + d_x: int + # Spatial Z distance (arXiv:1905.06903, Section 2) + d_z: int + # Temporal distance (arXiv:1905.06903, Section 2) + d_m: int + # Magic state + state: int = T diff --git a/source/pip/qsharp/qre/models/factories/_round_based.py b/source/pip/qsharp/qre/models/factories/_round_based.py new file mode 100644 index 0000000000..c4b9379448 --- /dev/null +++ b/source/pip/qsharp/qre/models/factories/_round_based.py @@ -0,0 +1,451 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import hashlib +import logging +from dataclasses import dataclass, field +from itertools import combinations_with_replacement +from math import ceil +from pathlib import Path +from typing import Callable, Generator, Iterable, Optional, Sequence + +from ..._qre import ISA, InstructionFrontier, ISARequirements, Instruction, _binom_ppf +from ..._instruction import ( + LOGICAL, + PHYSICAL, + ISAQuery, + ISATransform, + constraint, +) +from ..._architecture import ISAContext +from ...instruction_ids import CNOT, LATTICE_SURGERY, T, MEAS_ZZ +from ..qec import SurfaceCode + + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class RoundBasedFactory(ISATransform): + """ + A magic state factory that produces T gate instructions using round-based + distillation pipelines. + + This factory explores combinations of distillation units (such as "15-to-1 + RM prep" and "15-to-1 space efficient") to find optimal configurations that + minimize time and space while achieving target error rates. It supports + both physical-level distillation (when the input T gate is physically + encoded) and logical-level distillation (using lattice surgery via surface + codes). 
+ + In order to account for the success probability of distillation rounds, the + factory models the pipeline using a failure probability requirement + (defaulting to 1%) that each round must meet. The number of distillation + units per round is adjusted to meet this requirement, which in turn affects + the overall space requirements. + + Space requirements are calculated using a user-provided function that + aggregates per-round space (e.g., sum or max). The ``sum`` function models + the case in which qubits are not reused across rounds, while the ``max`` + function models the case in which qubits are reused across rounds. + + For the enumeration of logical-level distillation units, the factory relies + on a user-provided ``ISAQuery`` (defaulting to ``SurfaceCode.q()``) to explore + different surface code configurations and their corresponding lattice + surgery instructions. These need to be provided by the user and cannot + automatically be derived from the provided implementation ISA, as they can + only contain a subset of the required instructions. The user needs to + ensure that the provided query matches the architecture for which this + factory is being used. + + Results are cached to disk for efficiency. + + Attributes: + code_query: ISAQuery + Query to enumerate QEC codes for logical distillation units. + Defaults to SurfaceCode.q(). + physical_qubit_calculation: Callable[[Iterable], int] + Function to calculate total physical qubits from per-round space + requirements, e.g., sum or max. Defaults to sum. + cache_dir: Path + Directory for caching computed factory configurations. Defaults to + ~/.cache/re3/round_based. + use_cache: bool + Whether to use cached results. Defaults to True. + + References: + + - Sergei Bravyi, Alexei Kitaev: Universal Quantum Computation with ideal + Clifford gates and noisy ancillas, + [arXiv:quant-ph/0403025](https://arxiv.org/abs/quant-ph/0403025) + - Michael E. Beverland, Prakash Murali, Matthias Troyer, Krysta M. 
Svore, + Torsten Hoefler, Vadym Kliuchnikov, Guang Hao Low, Mathias Soeken, Aarthi + Sundaram, Alexander Vaschillo: Assessing requirements to scale to + practical quantum advantage, + [arXiv:2211.07629](https://arxiv.org/pdf/2211.07629) + """ + + code_query: ISAQuery = field(default_factory=lambda: SurfaceCode.q()) + physical_qubit_calculation: Callable[[Iterable], int] = field(default=sum) + # optional: make cache directory configurable + cache_dir: Path = field( + default=Path.home() / ".cache" / "re3" / "round_based", repr=False + ) + use_cache: bool = field(default=True, repr=False) + + @staticmethod + def required_isa() -> ISARequirements: + # NOTE: A T gate is required, but a CNOT is only required to explore + # physical units. + return ISARequirements( + constraint(T), + ) + + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: + cache_path = self._cache_path(impl_isa) + + # 1) Try to load from cache + if self.use_cache and cache_path.exists(): + cached_states = InstructionFrontier.load(str(cache_path)) + for state in cached_states: + yield ctx.make_isa( + ctx.add_instruction(state, transform=self, source=[impl_isa[T]]) + ) + return + + # 2) Compute as before + t_gate_error = impl_isa[T].expect_error_rate() + + units: list[_DistillationUnit] = [] + initial_unit = [] + + # Physical units? 
+ if impl_isa[T].encoding == PHYSICAL: + clifford_gate = impl_isa.get(CNOT) or impl_isa.get(MEAS_ZZ) + if clifford_gate is None: + raise ValueError( + "CNOT or MEAS_ZZ instruction is required for physical units" + ) + + gate_time = clifford_gate.expect_time() + clifford_error = clifford_gate.expect_error_rate() + units.extend(self._physical_units(gate_time, clifford_error)) + else: + initial_unit.append( + _DistillationUnit( + 1, + impl_isa[T].expect_time(), + impl_isa[T].expect_space(), + [1, 0], + [0], + ) + ) + + for code_isa in self.code_query.enumerate(ctx): + units.extend(self._logical_units(code_isa[LATTICE_SURGERY])) + + optimal_states = InstructionFrontier() + + for r in range(1, 4 - len(initial_unit)): + for k in combinations_with_replacement(units, r): + pipeline = _Pipeline.try_create( + initial_unit + list(k), + t_gate_error, + physical_qubit_calculation=self.physical_qubit_calculation, + ) + if pipeline is not None: + state = self._state_from_pipeline(pipeline) + optimal_states.insert(state) + logger.debug(f"Optimal states after {r} rounds: {len(optimal_states)}") + + # 3) Save to cache, then yield + if self.use_cache: + optimal_states.dump(str(cache_path)) + + for state in optimal_states: + yield ctx.make_isa( + ctx.add_instruction(state, transform=self, source=[impl_isa[T]]) + ) + + def _physical_units(self, gate_time, clifford_error) -> list[_DistillationUnit]: + """Return physical distillation units for the given gate parameters.""" + return [ + _DistillationUnit( + num_input_states=15, + time=24 * gate_time, + space=31, + error_rate_coeffs=[35, 0.0, 0.0, 7.1 * clifford_error], + failure_probability_coeffs=[15, 356 * clifford_error], + name="15-to-1 RM prep", + ), + _DistillationUnit( + num_input_states=15, + time=45 * gate_time, + space=12, + error_rate_coeffs=[35, 0.0, 0.0, 7.1 * clifford_error], + failure_probability_coeffs=[15, 356 * clifford_error], + name="15-to-1 space efficient", + ), + ] + + def _logical_units( + self, 
lattice_surgery_instruction: Instruction + ) -> list[_DistillationUnit]: + """Return logical distillation units derived from a lattice surgery instruction.""" + logical_cycle_time = lattice_surgery_instruction.expect_time(1) + logical_error = lattice_surgery_instruction.expect_error_rate(1) + + return [ + _DistillationUnit( + num_input_states=15, + time=11 * logical_cycle_time, + space=lattice_surgery_instruction.expect_space(31), + error_rate_coeffs=[35, 0.0, 0.0, 7.1 * logical_error], + failure_probability_coeffs=[15, 356 * logical_error], + name="15-to-1 RM prep", + ), + _DistillationUnit( + num_input_states=15, + time=13 * logical_cycle_time, + space=lattice_surgery_instruction.expect_space(20), + error_rate_coeffs=[35, 0.0, 0.0, 7.1 * logical_error], + failure_probability_coeffs=[15, 356 * logical_error], + name="15-to-1 space efficient", + ), + ] + + def _state_from_pipeline(self, pipeline: _Pipeline) -> Instruction: + """Create a T-gate instruction from a distillation pipeline.""" + return Instruction.fixed_arity( + T, + int(LOGICAL), + 1, + pipeline.time, + pipeline.space, + None, + pipeline.error_rate, + ) + + def _cache_key(self, impl_isa: ISA) -> str: + """Build a deterministic key from factory configuration and impl_isa.""" + parts = [ + f"factory={type(self).__qualname__}", + f"code_query={repr(self.code_query)}", + f"physical_qubit_calculation={self.physical_qubit_calculation.__name__}", + ] + + # Include full instruction details, sorted by id for determinism + for instr in sorted(impl_isa, key=lambda i: i.id): + parts.append( + f"id={instr.id}|encoding={instr.encoding}|arity={instr.arity}" + f"|time={instr.time()}|space={instr.space()}" + f"|error_rate={instr.error_rate()}" + ) + + data = "\n".join(parts).encode("utf-8") + return hashlib.sha256(data).hexdigest() + + def _cache_path(self, impl_isa: ISA) -> Path: + """Return the cache file path for the given implementation ISA.""" + self.cache_dir.mkdir(parents=True, exist_ok=True) + return 
self.cache_dir / f"{self._cache_key(impl_isa)}.json" + + +class _Pipeline: + """A multi-round distillation pipeline.""" + + def __init__( + self, + units: Sequence[_DistillationUnit], + initial_input_error_rate: float, + *, + failure_probability_requirement: float = 0.01, + physical_qubit_calculation: Callable[[Iterable], int] = sum, + ): + self.failure_probability_requirement = failure_probability_requirement + self.rounds: list["_DistillationRound"] = [] + self.output_error_rate: float = initial_input_error_rate + self.physical_qubit_calculation = physical_qubit_calculation + + self._add_rounds(units) + + @classmethod + def try_create( + cls, + units: Sequence[_DistillationUnit], + initial_input_error_rate: float, + *, + failure_probability_requirement: float = 0.01, + physical_qubit_calculation: Callable[[Iterable], int] = sum, + ) -> Optional[_Pipeline]: + """Create a pipeline if the configuration is feasible. + + Returns: + Optional[_Pipeline]: The pipeline, or None if the required + number of units per round is infeasible. 
+ """ + pipeline = cls( + units, + initial_input_error_rate, + failure_probability_requirement=failure_probability_requirement, + physical_qubit_calculation=physical_qubit_calculation, + ) + if not pipeline._compute_units_per_round(): + return None + return pipeline + + def _compute_units_per_round(self) -> bool: + """Adjust the number of units per round to meet output requirements.""" + if len(self.rounds) > 0: + states_needed_next = self.rounds[-1].unit.num_output_states + + for dist_round in reversed(self.rounds): + if not dist_round.adjust_num_units_to(states_needed_next): + return False + states_needed_next = dist_round.num_input_states + + return True + + def _add_rounds(self, units: Sequence[_DistillationUnit]): + """Append distillation rounds from the given units.""" + per_round_failure_prob_req = self.failure_probability_requirement / len(units) + + for unit in units: + self.rounds.append( + _DistillationRound( + unit, + per_round_failure_prob_req, + self.output_error_rate, + ) + ) + # TODO: handle case when output_error_rate is larger than input_error_rate + self.output_error_rate = unit.error_rate(self.output_error_rate) + + @property + def space(self) -> int: + """Total physical-qubit space of the pipeline.""" + return self.physical_qubit_calculation(round.space for round in self.rounds) + + @property + def time(self) -> int: + """Total time of the pipeline in nanoseconds.""" + return sum(round.unit.time for round in self.rounds) + + @property + def error_rate(self) -> float: + """Output error rate of the pipeline.""" + return self.output_error_rate + + @property + def num_output_states(self) -> int: + """Number of output magic states produced by the pipeline.""" + return self.rounds[-1].compute_num_output_states() + + +@dataclass(slots=True) +class _DistillationUnit: + """A single distillation unit with fixed input/output characteristics.""" + + num_input_states: int + time: int + space: int + error_rate_coeffs: Sequence[float] + 
failure_probability_coeffs: Sequence[float] + name: Optional[str] = None + num_output_states: int = 1 + + def error_rate(self, input_error_rate: float) -> float: + """Compute the output error rate for a given input error rate.""" + result = 0.0 + for c in self.error_rate_coeffs: + result = result * input_error_rate + c + return result + + def failure_probability(self, input_error_rate: float) -> float: + """Compute the failure probability for a given input error rate.""" + result = 0.0 + for c in self.failure_probability_coeffs: + result = result * input_error_rate + c + return result + + +@dataclass(slots=True) +class _DistillationRound: + """A single round in a distillation pipeline.""" + + unit: _DistillationUnit + failure_probability_requirement: float + input_error_rate: float + num_units: int = 1 + failure_probability: float = field(init=False) + + def __post_init__(self): + self.failure_probability = self.unit.failure_probability(self.input_error_rate) + + def adjust_num_units_to(self, output_states_needed_next: int) -> bool: + """Adjust the number of units to produce at least the required output states.""" + if self.failure_probability == 0.0: + self.num_units = output_states_needed_next + return True + + # Binary search to find the minimal number of units needed + self.num_units = ceil(output_states_needed_next / self.max_num_output_states) + + while True: + num_output_states = self.compute_num_output_states() + if num_output_states < output_states_needed_next: + self.num_units *= 2 + + # Distillation round requires unreasonably high number of units + if self.num_units >= 1_000_000_000_000_000: + return False + else: + break + + upper = self.num_units + lower = self.num_units // 2 + while lower < upper: + self.num_units = (lower + upper) // 2 + num_output_states = self.compute_num_output_states() + if num_output_states >= output_states_needed_next: + upper = self.num_units + else: + lower = self.num_units + 1 + self.num_units = upper + + return True + + 
@property + def space(self) -> int: + """Total physical-qubit space for this round.""" + return self.num_units * self.unit.space + + @property + def num_input_states(self) -> int: + """Total number of input states consumed by this round.""" + return self.num_units * self.unit.num_input_states + + @property + def max_num_output_states(self) -> int: + """Maximum number of output states this round can produce.""" + return self.num_units * self.unit.num_output_states + + def compute_num_output_states(self) -> int: + """Compute the expected number of output states accounting for failure probability.""" + failure_prob = self.failure_probability + + if failure_prob <= 1e-8: + return self.num_units * self.unit.num_output_states + + # A replacement for SciPy's binom.ppf that is faster + k = _binom_ppf( + self.failure_probability_requirement, + self.num_units, + 1.0 - failure_prob, + ) + + return int(k) * self.unit.num_output_states diff --git a/source/pip/qsharp/qre/models/factories/_utils.py b/source/pip/qsharp/qre/models/factories/_utils.py new file mode 100644 index 0000000000..0fbec26ed7 --- /dev/null +++ b/source/pip/qsharp/qre/models/factories/_utils.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Generator + +from ..._architecture import ISAContext +from ..._qre import ISARequirements, ISA +from ..._instruction import ISATransform +from ...instruction_ids import ( + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + SQRT_SQRT_Z, + SQRT_SQRT_Z_DAG, + CCX, + CCY, + CCZ, +) + + +class MagicUpToClifford(ISATransform): + """ + An ISA transform that adds Clifford equivalent representations of magic + states. For example, if the input ISA contains a T gate, the provided ISA + will also contain ``SQRT_SQRT_X``, ``SQRT_SQRT_X_DAG``, ``SQRT_SQRT_Y``, + ``SQRT_SQRT_Y_DAG``, and ``T_DAG``. The same is applied for ``CCZ`` gates and + their Clifford equivalents. + + Example: + + .. 
code-block:: python + app = SomeApplication() + arch = SomeArchitecture() + + # This will contain CCX states + trace_query = PSSPC.q(ccx_magic_states=True) * LatticeSurgery.q() + + # This will contain CCZ states + isa_query = SurfaceCode.q() * Litinski19Factory.q() + + # There will be no results from the estimation because there is no + # instruction to support CCX magic states in the query + results = estimate(app, arch, isa_query, trace_query) + assert len(results) == 0 + + # We solve this by wrapping the Litinski19Factory with the + # MagicUpToClifford transform, which transforms the CCZ states in the + # provided ISA into CCX states. + isa_query = SurfaceCode.q() * MagicUpToClifford.q(source=Litinski19Factory.q()) + + # Now we will get results + results = estimate(app, arch, isa_query, trace_query) + assert len(results) != 0 + """ + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements() + + def provided_isa(self, impl_isa, ctx: ISAContext) -> Generator[ISA, None, None]: + # Families of equivalent gates under Clifford conjugation. + families = [ + [ + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + SQRT_SQRT_Z, + SQRT_SQRT_Z_DAG, + ], + [CCX, CCY, CCZ], + ] + + # For each family, if any member of the family is present in the input ISA, add all members of the family to the provided ISA. + for family in families: + for id in family: + if id in impl_isa: + instr = impl_isa[id] + for equivalent_id in family: + if equivalent_id != id: + node_idx = ctx.add_instruction( + instr.with_id(equivalent_id), + transform=self, + source=[instr], + ) + impl_isa.add_node(equivalent_id, node_idx) + break # Check next family + + yield impl_isa diff --git a/source/pip/qsharp/qre/models/qec/__init__.py b/source/pip/qsharp/qre/models/qec/__init__.py new file mode 100644 index 0000000000..4e4cf816f7 --- /dev/null +++ b/source/pip/qsharp/qre/models/qec/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT License. + +from ._surface_code import SurfaceCode +from ._three_aux import ThreeAux +from ._yoked import OneDimensionalYokedSurfaceCode, TwoDimensionalYokedSurfaceCode + +__all__ = [ + "SurfaceCode", + "ThreeAux", + "OneDimensionalYokedSurfaceCode", + "TwoDimensionalYokedSurfaceCode", +] diff --git a/source/pip/qsharp/qre/models/qec/_surface_code.py b/source/pip/qsharp/qre/models/qec/_surface_code.py new file mode 100644 index 0000000000..ee5cc8bace --- /dev/null +++ b/source/pip/qsharp/qre/models/qec/_surface_code.py @@ -0,0 +1,138 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations +from dataclasses import KW_ONLY, dataclass, field +from typing import Generator +from ..._instruction import ( + ISA, + ISARequirements, + ISATransform, + constraint, + ConstraintBound, + LOGICAL, +) +from ..._isa_enumeration import ISAContext +from ..._qre import linear_function +from ...instruction_ids import CNOT, H, LATTICE_SURGERY, MEAS_Z +from ...property_keys import ( + SURFACE_CODE_ONE_QUBIT_TIME_FACTOR, + SURFACE_CODE_TWO_QUBIT_TIME_FACTOR, +) + + +@dataclass +class SurfaceCode(ISATransform): + """ + This class models the gate-based rotated surface code. + + Attributes: + crossing_prefactor: float + The prefactor for logical error rate due to error correction + crossings. (Default is 0.03, see Eq. (11) in + [arXiv:1208.0928](https://arxiv.org/abs/1208.0928)) + error_correction_threshold: float + The error correction threshold for the surface code. (Default is + 0.01 (1%), see [arXiv:1009.3686](https://arxiv.org/abs/1009.3686)) + one_qubit_gate_depth: int + The depth of one-qubit gates in each syndrome extraction cycle. + (Default is 1, see Fig. 2 in [arXiv:1009.3686](https://arxiv.org/abs/1009.3686)) + two_qubit_gate_depth: int + The depth of two-qubit gates in each syndrome extraction cycle. + (Default is 4, see Fig. 
2 in [arXiv:1009.3686](https://arxiv.org/abs/1009.3686)) + + Hyper parameters: + distance: int + The code distance of the surface code. + + References: + + - Dominic Horsman, Austin G. Fowler, Simon Devitt, Rodney Van Meter: Surface + code quantum computing by lattice surgery, + [arXiv:1111.4022](https://arxiv.org/abs/1111.4022) + - Austin G. Fowler, Matteo Mariantoni, John M. Martinis, Andrew N. Cleland: + Surface codes: Towards practical large-scale quantum computation, + [arXiv:1208.0928](https://arxiv.org/abs/1208.0928) + - David S. Wang, Austin G. Fowler, Lloyd C. L. Hollenberg: Quantum computing + with nearest neighbor interactions and error rates over 1%, + [arXiv:1009.3686](https://arxiv.org/abs/1009.3686) + """ + + crossing_prefactor: float = 0.03 + error_correction_threshold: float = 0.01 + one_qubit_gate_depth: int = 1 + two_qubit_gate_depth: int = 4 + _: KW_ONLY + distance: int = field(default=3, metadata={"domain": range(3, 26, 2)}) + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + constraint(H, error_rate=ConstraintBound.lt(0.01)), + constraint(CNOT, arity=2, error_rate=ConstraintBound.lt(0.01)), + constraint(MEAS_Z, error_rate=ConstraintBound.lt(0.01)), + ) + + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: + cnot = impl_isa[CNOT] + h = impl_isa[H] + meas_z = impl_isa[MEAS_Z] + + cnot_time = cnot.expect_time() + h_time = h.expect_time() + meas_time = meas_z.expect_time() + + physical_error_rate = max( + cnot.expect_error_rate(), + h.expect_error_rate(), + meas_z.expect_error_rate(), + ) + + # There are d^2 data qubits and (d^2 - 1) ancilla qubits in the rotated + # surface code. (See Section 7.1 in arXiv:1111.4022) + space_formula = linear_function(2 * self.distance**2 - 1) + + # Each syndrome extraction cycle consists of ancilla preparation, 4 + # rounds of CNOTs, and measurement. (See Fig. 
2 in arXiv:1009.3686); + # these may be modified by the one_qubit_gate_depth and + # two_qubit_gate_depth parameters, or scaled by the time factors + # provided in the instruction properties. The syndrome extraction cycle + # is repeated d times for a distance-d code. + one_qubit_gate_depth = self.one_qubit_gate_depth * h.get_property_or( + SURFACE_CODE_ONE_QUBIT_TIME_FACTOR, 1 + ) + two_qubit_gate_depth = self.two_qubit_gate_depth * cnot.get_property_or( + SURFACE_CODE_TWO_QUBIT_TIME_FACTOR, 1 + ) + + code_cycle_time = ( + one_qubit_gate_depth * h_time + two_qubit_gate_depth * cnot_time + meas_time + ) + time_value = code_cycle_time * self.distance + + # See Eqs. (10) and (11) in arXiv:1208.0928 + error_formula = linear_function( + self.crossing_prefactor + * ( + (physical_error_rate / self.error_correction_threshold) + ** ((self.distance + 1) // 2) + ) + ) + + # We provide a generic lattice surgery instruction (See Section 3 in + # arXiv:1111.4022) + yield ctx.make_isa( + ctx.add_instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + space=space_formula, + time=time_value, + error_rate=error_formula, + transform=self, + source=[cnot, h, meas_z], + distance=self.distance, + ), + ) diff --git a/source/pip/qsharp/qre/models/qec/_three_aux.py b/source/pip/qsharp/qre/models/qec/_three_aux.py new file mode 100644 index 0000000000..5f7cff6da3 --- /dev/null +++ b/source/pip/qsharp/qre/models/qec/_three_aux.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +from dataclasses import KW_ONLY, dataclass, field +from typing import Generator + +from ..._architecture import ISAContext +from ..._instruction import ( + LOGICAL, + ISATransform, + constraint, +) +from ..._qre import ( + ISA, + ISARequirements, + linear_function, +) +from ...instruction_ids import ( + LATTICE_SURGERY, + MEAS_X, + MEAS_XX, + MEAS_Z, + MEAS_ZZ, +) + + +@dataclass +class ThreeAux(ISATransform): + """ + This class models the pairwise measurement-based surface code with three + auxiliary qubits per stabilizer measurement. + + Hyper parameters: + distance: int + The code distance of the surface code. + single_rail: bool + Whether to use single-rail encoding. + + References: + + - Linnea Grans-Samuelsson, Ryan V. Mishmash, David Aasen, Christina Knapp, + Bela Bauer, Brad Lackey, Marcus P. da Silva, Parsa Bonderson: Improved + Pairwise Measurement-Based Surface Code, + [arXiv:2310.12981](https://arxiv.org/abs/2310.12981) + """ + + _: KW_ONLY + distance: int = field(default=3, metadata={"domain": range(3, 26, 2)}) + single_rail: bool = field(default=False) + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + constraint(MEAS_X), + constraint(MEAS_Z), + constraint(MEAS_XX, arity=2), + constraint(MEAS_ZZ, arity=2), + ) + + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: + meas_x = impl_isa[MEAS_X] + meas_z = impl_isa[MEAS_Z] + meas_xx = impl_isa[MEAS_XX] + meas_zz = impl_isa[MEAS_ZZ] + + gate_time = max(meas_xx.expect_time(), meas_zz.expect_time()) + + physical_error_rate = max( + meas_x.expect_error_rate(), + meas_z.expect_error_rate(), + meas_xx.expect_error_rate(), + meas_zz.expect_error_rate(), + ) + + # See arXiv:2310.12981, Table 1 and Figs. 2, 3, 4, 6, and 7 + depth = 5 if self.single_rail else 4 + + # See arXiv:2310.12981, Table 1 + error_correction_threshold = 0.0051 if self.single_rail else 0.0066 + + # See arXiv:2310.12981, Fig. 
23 + crossing_prefactor = 0.05 + + # d^2 data qubits and 3 qubits for each of the d^2 - 1 stabilizer + # measurements + space_formula = linear_function(4 * self.distance**2 - 3) + + # The measurement circuits do not overlap perfectly, so there is an + # additional 4 steps that need to be accounted for independent of the + # distance (see Section 2 between Eqs. (2) and (3) in arXiv:2310.12981) + time_value = gate_time * (depth * self.distance + 4) + + # Typical fitting curve for surface code logical error (see + # arXiv:1208.0928) + error_formula = linear_function( + crossing_prefactor + * ( + (physical_error_rate / error_correction_threshold) + ** ((self.distance + 1) // 2) + ) + ) + + yield ctx.make_isa( + ctx.add_instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + space=space_formula, + time=time_value, + error_rate=error_formula, + transform=self, + source=[meas_x, meas_z, meas_xx, meas_zz], + distance=self.distance, + ) + ) diff --git a/source/pip/qsharp/qre/models/qec/_yoked.py b/source/pip/qsharp/qre/models/qec/_yoked.py new file mode 100644 index 0000000000..9cb1b26527 --- /dev/null +++ b/source/pip/qsharp/qre/models/qec/_yoked.py @@ -0,0 +1,243 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import dataclass +from math import ceil +from typing import Generator + +from ..._instruction import ISATransform, constraint, LOGICAL +from ..._qre import ISA, ISARequirements, generic_function +from ..._architecture import ISAContext +from ...instruction_ids import LATTICE_SURGERY, MEMORY +from ...property_keys import DISTANCE + + +@dataclass +class OneDimensionalYokedSurfaceCode(ISATransform): + """ + This class models the Yoked surface code to provide a generic memory + instruction based on lattice surgery instructions from a surface code like + error correction code. 
+
+    Attributes:
+        crossing_prefactor: float
+            The prefactor for logical error rate (Default is 8/15)
+        error_correction_threshold: float
+            The error correction threshold for the surface code (Default is
+            64/10)
+
+    Patch shape:
+        The shape of the surface code patch for a given number of logical
+        qubits is chosen by a fixed minimum-area heuristic; there is no
+        configurable shape heuristic on this class.
+
+    References:
+
+    - Craig Gidney, Michael Newman, Peter Brooks, Cody Jones: Yoked surface
+      codes, [arXiv:2312.04522](https://arxiv.org/abs/2312.04522)
+    """
+
+    # NOTE: The crossing_prefactor is relative to that of the underlying surface
+    # code. That is if the surface code model is p(SC) =
+    # A*(p(phy)/th(SC))^((d+1)/2), then multiplier for its yoked extension is
+    # crossing_prefactor*A
+    crossing_prefactor: float = 8 / 15
+
+    # NOTE: The threshold is relative to that of the underlying surface code.
+    # Namely, as the yoking doubles the distance, one would expect the yoked
+    # surface code to have a threshold of sqrt(th(SC)). However modeling shows
+    # it falls short of this.
+    error_correction_threshold: float = 64 / 10
+
+    @staticmethod
+    def required_isa() -> ISARequirements:
+        # We require a lattice surgery instruction that also provides the code
+        # distance as a property. This is necessary to compute the time
+        # and error rate formulas for the provided memory instruction. 
+ return ISARequirements( + constraint(LATTICE_SURGERY, LOGICAL, arity=None, distance=True), + ) + + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: + lattice_surgery = impl_isa[LATTICE_SURGERY] + distance = lattice_surgery.get_property(DISTANCE) + assert distance is not None + + def space(arity: int) -> int: + a, b = self._min_area_shape(arity) + return lattice_surgery.expect_space(a * b) + + space_fn = generic_function(space) + + def time(arity: int) -> int: + a, b = self._min_area_shape(arity) + s = lattice_surgery.expect_time(a * b) + return s * (8 * distance * (a - 1) + 2 * distance) + + time_fn = generic_function(time) + + def error_rate(arity: int) -> float: + a, b = self._min_area_shape(arity) + rounds = 2 * (a - 2) + # logical error rate on a single surface code patch + p = lattice_surgery.expect_error_rate(1) + return ( + rounds**2 + * (a * b) ** 2 + * self.crossing_prefactor + * p + * (1 / self.error_correction_threshold) ** ((distance + 1) // 2) + ) + + error_rate_fn = generic_function(error_rate) + + yield ctx.make_isa( + ctx.add_instruction( + MEMORY, + arity=None, + encoding=LOGICAL, + space=space_fn, + time=time_fn, + error_rate=error_rate_fn, + transform=self, + source=[lattice_surgery], + distance=distance, + ) + ) + + @staticmethod + def _min_area_shape(num_qubits: int) -> tuple[int, int]: + """ + Given a number of qubits num_qubits, returns numbers (a + 1) and (b + 2) + such that a * b >= num_qubits and a * b is as small as possible. 
+        """
+
+        best_a = None
+        best_b = None
+        best_qubits = num_qubits**2
+
+        for a in range(1, num_qubits):
+            # Compute required number of columns to reach the required number
+            # of logical qubits
+            b = ceil(num_qubits / a)
+
+            qubits = (a + 1) * (b + 2)
+            if qubits < best_qubits:
+                best_qubits = qubits
+                best_a = a
+                best_b = b
+
+        assert best_a is not None
+        assert best_b is not None
+        return best_a + 1, best_b + 2
+
+
+@dataclass
+class TwoDimensionalYokedSurfaceCode(ISATransform):
+    """
+    This class models the Yoked surface code to provide a generic memory
+    instruction based on lattice surgery instructions from a surface code like
+    error correction code.
+
+    Attributes:
+        crossing_prefactor: float
+            The prefactor for logical error rate (Default is 5/600)
+        error_correction_threshold: float
+            The error correction threshold for the surface code (Default is
+            2500/10)
+
+    Patch shape:
+        The shape of the surface code patch for a given number of logical
+        qubits is chosen by a fixed near-square heuristic; there is no
+        configurable shape heuristic on this class.
+
+    References:
+
+    - Craig Gidney, Michael Newman, Peter Brooks, Cody Jones: Yoked surface
+      codes, [arXiv:2312.04522](https://arxiv.org/abs/2312.04522)
+    """
+
+    # NOTE: The crossing_prefactor is relative to that of the underlying surface
+    # code. That is if the surface code model is p(SC) =
+    # A*(p(phy)/th(SC))^((d+1)/2), then multiplier for its yoked extension is
+    # crossing_prefactor*A
+    crossing_prefactor: float = 5 / 600
+
+    # NOTE: The threshold is relative to that of the underlying surface code.
+    # Namely, as the yoking doubles the distance, one would expect the yoked
+    # surface code to have a threshold of sqrt(th(SC)). However modeling shows
+    # it falls short of this.
+    error_correction_threshold: float = 2500 / 10
+
+    @staticmethod
+    def required_isa() -> ISARequirements:
+        # We require a lattice surgery instruction that also provides the code
+        # distance as a property. 
This is necessary to compute the time + # and error rate formulas for the provided memory instruction. + return ISARequirements( + constraint(LATTICE_SURGERY, LOGICAL, arity=None, distance=True), + ) + + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: + lattice_surgery = impl_isa[LATTICE_SURGERY] + distance = lattice_surgery.get_property(DISTANCE) + assert distance is not None + + def space(arity: int) -> int: + a, b = self._square_shape(arity) + return lattice_surgery.expect_space(a * b) + + space_fn = generic_function(space) + + def time(arity: int) -> int: + a, b = self._square_shape(arity) + s = lattice_surgery.expect_time(a * b) + return s * (8 * distance * max(a - 2, b - 2) + 2 * distance) + + time_fn = generic_function(time) + + def error_rate(arity: int) -> float: + a, b = self._square_shape(arity) + rounds = 2 * max(a - 3, b - 3) + # logical error rate on a single surface code patch + p = lattice_surgery.expect_error_rate(1) + return ( + rounds**4 + * (a * b) ** 2 + * self.crossing_prefactor + * p + * (1 / self.error_correction_threshold) ** ((distance + 1) // 2) + ) + + error_rate_fn = generic_function(error_rate) + + yield ctx.make_isa( + ctx.add_instruction( + MEMORY, + arity=None, + encoding=LOGICAL, + space=space_fn, + time=time_fn, + error_rate=error_rate_fn, + transform=self, + source=[lattice_surgery], + distance=distance, + ) + ) + + @staticmethod + def _square_shape(num_qubits: int) -> tuple[int, int]: + """ + Given a number of qubits num_qubits, returns numbers (a + 2) and (b + 2) + such that a * b >= num_qubits and a and b are as close as possible. 
+ """ + + a = int(num_qubits**0.5) + while num_qubits % a != 0: + a -= 1 + b = num_qubits // a + return a + 2, b + 2 diff --git a/source/pip/qsharp/qre/models/qubits/__init__.py b/source/pip/qsharp/qre/models/qubits/__init__.py new file mode 100644 index 0000000000..ab7887faf3 --- /dev/null +++ b/source/pip/qsharp/qre/models/qubits/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from ._gate_based import GateBased +from ._msft import Majorana + +__all__ = ["GateBased", "Majorana"] diff --git a/source/pip/qsharp/qre/models/qubits/_gate_based.py b/source/pip/qsharp/qre/models/qubits/_gate_based.py new file mode 100644 index 0000000000..d9ee589485 --- /dev/null +++ b/source/pip/qsharp/qre/models/qubits/_gate_based.py @@ -0,0 +1,139 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import KW_ONLY, dataclass, field +from typing import Optional + +from ..._architecture import Architecture, ISAContext +from ..._instruction import ISA, Encoding +from ...instruction_ids import ( + CNOT, + CZ, + MEAS_X, + MEAS_Y, + MEAS_Z, + PAULI_I, + PAULI_X, + PAULI_Y, + PAULI_Z, + RX, + RY, + RZ, + S_DAG, + SQRT_X, + SQRT_X_DAG, + SQRT_Y, + SQRT_Y_DAG, + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + T_DAG, + H, + S, + T, +) + + +@dataclass +class GateBased(Architecture): + """ + A generic gate-based architecture. The error rate can be set arbitrarily + and is either 1e-3 or 1e-4 in the reference. + + Args: + error_rate: The error rate for all gates. Defaults to 1e-4. + gate_time: The time (in ns) for single-qubit gates. + measurement_time: The time (in ns) for measurement operations. + two_qubit_gate_time: The time (in ns) for two-qubit gates (CNOT, CZ). + If not provided, defaults to the value of ``gate_time``. + + References: + + - Michael E. Beverland, Prakash Murali, Matthias Troyer, Krysta M. 
Svore, + Torsten Hoefler, Vadym Kliuchnikov, Guang Hao Low, Mathias Soeken, Aarthi + Sundaram, Alexander Vaschillo: Assessing requirements to scale to + practical quantum advantage, + [arXiv:2211.07629](https://arxiv.org/abs/2211.07629) + - Jens Koch, Terri M. Yu, Jay Gambetta, A. A. Houck, D. I. Schuster, J. + Majer, Alexandre Blais, M. H. Devoret, S. M. Girvin, R. J. Schoelkopf: + Charge insensitive qubit design derived from the Cooper pair box, + [arXiv:cond-mat/0703002](https://arxiv.org/abs/cond-mat/0703002) + """ + + _: KW_ONLY + error_rate: float = field(default=1e-4) + gate_time: int + measurement_time: int + two_qubit_gate_time: Optional[int] = field(default=None) + + def __post_init__(self): + if self.two_qubit_gate_time is None: + self.two_qubit_gate_time = self.gate_time + + def provided_isa(self, ctx: ISAContext) -> ISA: + # Value is initialized in __post_init__ + assert self.two_qubit_gate_time is not None + + # NOTE: This can be improved with instruction coercion once implemented. 
+ instructions = [] + + # Single-qubit gates + single = [ + PAULI_I, + PAULI_X, + PAULI_Y, + PAULI_Z, + H, + SQRT_X, + SQRT_X_DAG, + SQRT_Y, + SQRT_Y_DAG, + S, + S_DAG, + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + T, + T_DAG, + RX, + RY, + RZ, + ] + + for instr in single: + instructions.append( + ctx.add_instruction( + instr, + encoding=Encoding.PHYSICAL, + arity=1, + time=self.gate_time, + error_rate=self.error_rate, + ) + ) + + for instr in [MEAS_X, MEAS_Y, MEAS_Z]: + instructions.append( + ctx.add_instruction( + instr, + encoding=Encoding.PHYSICAL, + arity=1, + time=self.measurement_time, + error_rate=self.error_rate, + ) + ) + + # Two-qubit gates + for instr in [CNOT, CZ]: + instructions.append( + ctx.add_instruction( + instr, + encoding=Encoding.PHYSICAL, + arity=2, + time=self.two_qubit_gate_time, + error_rate=self.error_rate, + ) + ) + + return ctx.make_isa(*instructions) diff --git a/source/pip/qsharp/qre/models/qubits/_msft.py b/source/pip/qsharp/qre/models/qubits/_msft.py new file mode 100644 index 0000000000..1d74300e3e --- /dev/null +++ b/source/pip/qsharp/qre/models/qubits/_msft.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import KW_ONLY, dataclass, field + +from ..._architecture import Architecture, ISAContext +from ...instruction_ids import ( + T, + PREP_X, + PREP_Z, + MEAS_XX, + MEAS_ZZ, + MEAS_X, + MEAS_Z, +) +from ..._instruction import ISA + + +@dataclass +class Majorana(Architecture): + """ + This class models physical instructions that may be relevant for future + Majorana qubits. For these qubits, we assume that measurements + and the physical T gate each take 1 µs. Owing to topological protection in + the hardware, we assume single and two-qubit measurement error rates + (Clifford error rates) in $10^{-4}$, $10^{-5}$, and $10^{-6}$ as a range + between realistic and optimistic targets. 
Non-Clifford operations in this + architecture do not have topological protection, so we assume a 5%, 1.5%, + and 1% error rate for non-Clifford physical T gates for the three cases, + respectively. + + References: + + - Torsten Karzig, Christina Knapp, Roman M. Lutchyn, Parsa Bonderson, + Matthew B. Hastings, Chetan Nayak, Jason Alicea, Karsten Flensberg, + Stephan Plugge, Yuval Oreg, Charles M. Marcus, Michael H. Freedman: + Scalable Designs for Quasiparticle-Poisoning-Protected Topological Quantum + Computation with Majorana Zero Modes, + [arXiv:1610.05289](https://arxiv.org/abs/1610.05289) + - Alexei Kitaev: Unpaired Majorana fermions in quantum wires, + [arXiv:cond-mat/0010440](https://arxiv.org/abs/cond-mat/0010440) + - Sankar Das Sarma, Michael Freedman, Chetan Nayak: Majorana Zero Modes and + Topological Quantum Computation, + [arXiv:1501.02813](https://arxiv.org/abs/1501.02813) + """ + + _: KW_ONLY + error_rate: float = field(default=1e-5, metadata={"domain": [1e-4, 1e-5, 1e-6]}) + + def provided_isa(self, ctx: ISAContext) -> ISA: + if abs(self.error_rate - 1e-4) <= 1e-8: + t_error_rate = 0.05 + elif abs(self.error_rate - 1e-5) <= 1e-8: + t_error_rate = 0.015 + elif abs(self.error_rate - 1e-6) <= 1e-8: + t_error_rate = 0.01 + + return ctx.make_isa( + ctx.add_instruction(PREP_X, time=1000, error_rate=self.error_rate), + ctx.add_instruction(PREP_Z, time=1000, error_rate=self.error_rate), + ctx.add_instruction( + MEAS_XX, arity=2, time=1000, error_rate=self.error_rate + ), + ctx.add_instruction( + MEAS_ZZ, arity=2, time=1000, error_rate=self.error_rate + ), + ctx.add_instruction(MEAS_X, time=1000, error_rate=self.error_rate), + ctx.add_instruction(MEAS_Z, time=1000, error_rate=self.error_rate), + ctx.add_instruction(T, time=1000, error_rate=t_error_rate), + ) diff --git a/source/pip/qsharp/qre/property_keys.py b/source/pip/qsharp/qre/property_keys.py new file mode 100644 index 0000000000..917e25ca0d --- /dev/null +++ b/source/pip/qsharp/qre/property_keys.py 
@@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pyright: reportAttributeAccessIssue=false + + +from .._native import property_keys + +for name in property_keys.__all__: + globals()[name] = getattr(property_keys, name) diff --git a/source/pip/qsharp/qre/property_keys.pyi b/source/pip/qsharp/qre/property_keys.pyi new file mode 100644 index 0000000000..0e0e2358b4 --- /dev/null +++ b/source/pip/qsharp/qre/property_keys.pyi @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +DISTANCE: int +SURFACE_CODE_ONE_QUBIT_TIME_FACTOR: int +SURFACE_CODE_TWO_QUBIT_TIME_FACTOR: int +ACCELERATION: int +NUM_TS_PER_ROTATION: int +EXPECTED_SHOTS: int +RUNTIME_SINGLE_SHOT: int +EVALUATION_TIME: int +PHYSICAL_COMPUTE_QUBITS: int +PHYSICAL_FACTORY_QUBITS: int +PHYSICAL_MEMORY_QUBITS: int +MOLECULE: int +LOGICAL_COMPUTE_QUBITS: int +LOGICAL_MEMORY_QUBITS: int +ALGORITHM_COMPUTE_QUBITS: int +ALGORITHM_MEMORY_QUBITS: int +NAME: int diff --git a/source/pip/src/interpreter.rs b/source/pip/src/interpreter.rs index d9272eaa2f..9fff6f80d2 100644 --- a/source/pip/src/interpreter.rs +++ b/source/pip/src/interpreter.rs @@ -30,6 +30,7 @@ use crate::{ }, unbind_noise_config, }, + qre::register_qre_submodule, }; use miette::{Diagnostic, Report}; use num_bigint::BigUint; @@ -139,6 +140,7 @@ fn _native<'a>(py: Python<'a>, m: &Bound<'a, PyModule>) -> PyResult<()> { m.add("QSharpError", py.get_type::())?; register_noisy_simulator_submodule(py, m)?; register_generic_estimator_submodule(m)?; + register_qre_submodule(m)?; // QASM interop m.add("QasmError", py.get_type::())?; m.add_function(wrap_pyfunction!(resource_estimate_qasm_program, m)?)?; diff --git a/source/pip/src/lib.rs b/source/pip/src/lib.rs index 72edafb7ca..b962d97df5 100644 --- a/source/pip/src/lib.rs +++ b/source/pip/src/lib.rs @@ -10,3 +10,4 @@ mod interop; mod interpreter; mod noisy_simulator; mod qir_simulation; +mod qre; diff --git 
a/source/pip/src/qre.rs b/source/pip/src/qre.rs new file mode 100644 index 0000000000..0e9daa1686 --- /dev/null +++ b/source/pip/src/qre.rs @@ -0,0 +1,1624 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::{ + ptr::NonNull, + sync::{Arc, RwLock}, +}; + +use pyo3::{ + IntoPyObjectExt, + exceptions::{PyException, PyKeyError, PyRuntimeError, PyTypeError, PyValueError}, + prelude::*, + types::{PyBool, PyDict, PyFloat, PyInt, PyString, PyTuple, PyType}, +}; +use qre::TraceTransform; +use serde::{Deserialize, Serialize}; + +pub(crate) fn register_qre_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_function(wrap_pyfunction!(constant_function, m)?)?; + m.add_function(wrap_pyfunction!(linear_function, m)?)?; + m.add_function(wrap_pyfunction!(block_linear_function, m)?)?; + m.add_function(wrap_pyfunction!(generic_function, m)?)?; + m.add_function(wrap_pyfunction!(estimate_parallel, m)?)?; + m.add_function(wrap_pyfunction!(estimate_with_graph, m)?)?; + m.add_function(wrap_pyfunction!(binom_ppf, m)?)?; + m.add_function(wrap_pyfunction!(float_to_bits, m)?)?; + m.add_function(wrap_pyfunction!(float_from_bits, m)?)?; + m.add_function(wrap_pyfunction!(instruction_name, m)?)?; + m.add_function(wrap_pyfunction!(property_name_to_key, m)?)?; + m.add_function(wrap_pyfunction!(property_name, m)?)?; + + m.add("EstimationError", m.py().get_type::())?; + + add_instruction_ids(m)?; + add_property_keys(m)?; + + Ok(()) +} + +pyo3::create_exception!(qsharp.qre, EstimationError, PyException); + +fn poisoned_lock_err(_: std::sync::PoisonError) -> PyErr { + PyRuntimeError::new_err("provenance graph lock poisoned") +} + 
+#[allow(clippy::upper_case_acronyms)] +#[pyclass] +pub struct ISA(qre::ISA); + +impl ISA { + pub fn inner(&self) -> &qre::ISA { + &self.0 + } +} + +#[pymethods] +impl ISA { + pub fn __add__(&self, other: &ISA) -> PyResult { + Ok(ISA(self.0.clone() + other.0.clone())) + } + + pub fn __contains__(&self, id: u64) -> bool { + self.0.contains(&id) + } + + pub fn satisfies(&self, requirements: &ISARequirements) -> PyResult { + Ok(self.0.satisfies(&requirements.0)) + } + + pub fn __len__(&self) -> usize { + self.0.len() + } + + pub fn __getitem__(&self, id: u64) -> PyResult { + match self.0.get(&id) { + Some(instr) => Ok(Instruction(instr)), + None => Err(PyKeyError::new_err(format!( + "Instruction with id {id} not found" + ))), + } + } + + #[pyo3(signature = (id, default=None))] + pub fn get(&self, id: u64, default: Option<&Instruction>) -> Option { + match self.0.get(&id) { + Some(instr) => Some(Instruction(instr)), + None => default.cloned(), + } + } + + /// Returns the provenance graph node index for the given instruction ID. + pub fn node_index(&self, id: u64) -> Option { + self.0.node_index(&id) + } + + /// Adds a node (by instruction ID and node index) that already exists in the graph. 
+ pub fn add_node(&mut self, instruction_id: u64, node_index: usize) { + self.0.add_node(instruction_id, node_index); + } + + #[allow(clippy::needless_pass_by_value)] + pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { + let instructions = slf.0.instructions(); + let iter = ISAIterator { + iter: instructions.into_iter(), + }; + Py::new(slf.py(), iter) + } + + fn __str__(&self) -> String { + format!("{}", self.0) + } +} + +#[pyclass] +pub struct ISAIterator { + iter: std::vec::IntoIter, +} + +#[pymethods] +impl ISAIterator { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(mut slf: PyRefMut<'_, Self>) -> Option { + slf.iter.next().map(Instruction) + } +} + +#[pyclass] +pub struct ISARequirements(qre::ISARequirements); + +#[pymethods] +impl ISARequirements { + #[new] + #[pyo3(signature = (*constraints))] + pub fn new(constraints: &Bound<'_, PyTuple>) -> PyResult { + if constraints.len() == 1 { + let item = constraints.get_item(0)?; + if let Ok(seq) = item.cast::() { + let mut instrs = Vec::with_capacity(seq.len()); + for item in seq.iter() { + let instr = item.cast_into::()?; + instrs.push(instr.borrow().0.clone()); + } + return Ok(ISARequirements(instrs.into_iter().collect())); + } + } + + constraints + .into_iter() + .map(|instr| { + let instr = instr.cast_into::()?; + Ok(instr.borrow().0.clone()) + }) + .collect::>() + .map(ISARequirements) + } + + pub fn __len__(&self) -> usize { + self.0.len() + } + + #[allow(clippy::needless_pass_by_value)] + pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { + let constraints: Vec = slf.0.constraints(); + let iter = ISARequirementsIterator { + iter: constraints.into_iter(), + }; + Py::new(slf.py(), iter) + } +} + +#[pyclass] +pub struct ISARequirementsIterator { + iter: std::vec::IntoIter, +} + +#[pymethods] +impl ISARequirementsIterator { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(mut slf: PyRefMut<'_, Self>) -> Option { + 
slf.iter.next().map(Constraint) + } +} + +#[allow(clippy::unsafe_derive_deserialize)] +#[pyclass(from_py_object)] +#[derive(Clone, Serialize, Deserialize)] +#[serde(transparent)] +pub struct Instruction(qre::Instruction); + +#[pymethods] +impl Instruction { + #[staticmethod] + pub fn fixed_arity( + id: u64, + encoding: u64, + arity: u64, + time: u64, + space: Option, + length: Option, + error_rate: f64, + ) -> PyResult { + Ok(Instruction(qre::Instruction::fixed_arity( + id, + convert_encoding(encoding)?, + arity, + time, + space, + length, + error_rate, + ))) + } + + #[staticmethod] + pub fn variable_arity( + id: u64, + encoding: u64, + time_fn: &IntFunction, + space_fn: &IntFunction, + error_rate_fn: &FloatFunction, + length_fn: Option<&IntFunction>, + ) -> PyResult { + Ok(Instruction(qre::Instruction::variable_arity( + id, + convert_encoding(encoding)?, + time_fn.0.clone(), + space_fn.0.clone(), + length_fn.map(|f| f.0.clone()), + error_rate_fn.0.clone(), + ))) + } + + pub fn with_id(&self, id: u64) -> Self { + Instruction(self.0.with_id(id)) + } + + #[getter] + pub fn id(&self) -> u64 { + self.0.id() + } + + #[getter] + pub fn encoding(&self) -> u64 { + match self.0.encoding() { + qre::Encoding::Physical => 0, + qre::Encoding::Logical => 1, + } + } + + #[getter] + pub fn arity(&self) -> Option { + self.0.arity() + } + + #[pyo3(signature = (arity=None))] + pub fn space(&self, arity: Option) -> Option { + self.0.space(arity) + } + + #[pyo3(signature = (arity=None))] + pub fn time(&self, arity: Option) -> Option { + self.0.time(arity) + } + + #[pyo3(signature = (arity=None))] + pub fn error_rate(&self, arity: Option) -> Option { + self.0.error_rate(arity) + } + + #[pyo3(signature = (arity=None))] + pub fn expect_space(&self, arity: Option) -> PyResult { + Ok(self.0.expect_space(arity)) + } + + #[pyo3(signature = (arity=None))] + pub fn expect_time(&self, arity: Option) -> PyResult { + Ok(self.0.expect_time(arity)) + } + + #[pyo3(signature = (arity=None))] + pub fn 
expect_error_rate(&self, arity: Option) -> PyResult { + Ok(self.0.expect_error_rate(arity)) + } + + pub fn set_source(&mut self, index: usize) { + self.0.set_source(index); + } + + #[getter] + pub fn source(&self) -> usize { + self.0.source() + } + + pub fn set_property(&mut self, key: u64, value: u64) { + self.0.set_property(key, value); + } + + pub fn get_property(&self, key: u64) -> Option { + self.0.get_property(&key) + } + + pub fn has_property(&self, key: u64) -> bool { + self.0.has_property(&key) + } + + #[pyo3(signature = (key, default))] + pub fn get_property_or(&self, key: u64, default: u64) -> u64 { + self.0.get_property_or(&key, default) + } + + pub fn __getitem__(&self, key: u64) -> PyResult { + match self.0.get_property(&key) { + Some(value) => Ok(value), + None => Err(PyKeyError::new_err(format!( + "Property with key {key} not found" + ))), + } + } + + fn __str__(&self) -> String { + format!("{}", self.0) + } +} + +impl qre::ParetoItem2D for Instruction { + type Objective1 = u64; + type Objective2 = u64; + + fn objective1(&self) -> Self::Objective1 { + self.0 + .space(None) + .unwrap_or_else(|| self.0.expect_space(Some(1))) + } + + fn objective2(&self) -> Self::Objective2 { + self.0 + .time(None) + .unwrap_or_else(|| self.0.expect_time(Some(1))) + } +} + +impl qre::ParetoItem3D for Instruction { + type Objective1 = u64; + type Objective2 = u64; + type Objective3 = f64; + + fn objective1(&self) -> Self::Objective1 { + self.0 + .space(None) + .unwrap_or_else(|| self.0.expect_space(Some(1))) + } + + fn objective2(&self) -> Self::Objective2 { + self.0 + .time(None) + .unwrap_or_else(|| self.0.expect_time(Some(1))) + } + + fn objective3(&self) -> Self::Objective3 { + self.0 + .error_rate(None) + .unwrap_or_else(|| self.0.expect_error_rate(Some(1))) + } +} + +#[pyclass] +pub struct Constraint(qre::InstructionConstraint); + +#[pymethods] +impl Constraint { + #[new] + pub fn new( + id: u64, + encoding: u64, + arity: Option, + error_rate: 
Option<&ConstraintBound>, + ) -> PyResult { + Ok(Constraint(qre::InstructionConstraint::new( + id, + convert_encoding(encoding)?, + arity, + error_rate.map(|error_rate| error_rate.0), + ))) + } + + #[getter] + pub fn id(&self) -> u64 { + self.0.id() + } + + #[getter] + pub fn encoding(&self) -> u64 { + match self.0.encoding() { + qre::Encoding::Physical => 0, + qre::Encoding::Logical => 1, + } + } + + #[getter] + pub fn arity(&self) -> Option { + self.0.arity() + } + + #[getter] + pub fn error_rate(&self) -> Option { + self.0.error_rate().copied().map(ConstraintBound) + } + + pub fn add_property(&mut self, property: u64) { + self.0.add_property(property); + } + + pub fn has_property(&self, property: u64) -> bool { + self.0.has_property(&property) + } +} + +fn convert_encoding(encoding: u64) -> PyResult { + match encoding { + 0 => Ok(qre::Encoding::Physical), + 1 => Ok(qre::Encoding::Logical), + _ => Err(EstimationError::new_err("Invalid encoding value")), + } +} + +/// Build a `qre::Instruction` from either an existing `Instruction` Python +/// object or from keyword arguments (id + encoding + arity + …). +#[allow(clippy::too_many_arguments)] +fn build_instruction( + id_or_instruction: &Bound<'_, PyAny>, + encoding: u64, + arity: Option, + time: Option<&Bound<'_, PyAny>>, + space: Option<&Bound<'_, PyAny>>, + length: Option<&Bound<'_, PyAny>>, + error_rate: Option<&Bound<'_, PyAny>>, + kwargs: Option<&Bound<'_, PyDict>>, +) -> PyResult { + // If the first argument is already an Instruction, return its inner clone. + if let Ok(instr) = id_or_instruction.cast::() { + return Ok(instr.borrow().0.clone()); + } + + // Otherwise treat the first arg as an instruction ID (int). + let id: u64 = id_or_instruction.extract()?; + let enc = convert_encoding(encoding)?; + + let mut instr = if let Some(arity_val) = arity { + // Fixed-arity path + let time_val: u64 = time + .ok_or_else(|| PyTypeError::new_err("'time' is required"))? 
+ .extract()?; + let space_val: Option = space.map(PyAnyMethods::extract).transpose()?; + let length_val: Option = length.map(PyAnyMethods::extract).transpose()?; + let error_rate_val: f64 = error_rate + .ok_or_else(|| PyTypeError::new_err("'error_rate' is required"))? + .extract()?; + qre::Instruction::fixed_arity( + id, + enc, + arity_val, + time_val, + space_val, + length_val, + error_rate_val, + ) + } else { + // Variable-arity path: time/space/error_rate may be functions + let time_fn = + extract_int_function(time.ok_or_else(|| PyTypeError::new_err("'time' is required"))?)?; + let space_fn = extract_int_function( + space.ok_or_else(|| PyTypeError::new_err("'space' is required for variable arity"))?, + )?; + let length_fn = length.map(extract_int_function).transpose()?; + let error_rate_fn = extract_float_function( + error_rate.ok_or_else(|| PyTypeError::new_err("'error_rate' is required"))?, + )?; + qre::Instruction::variable_arity(id, enc, time_fn, space_fn, length_fn, error_rate_fn) + }; + + // Apply additional properties from kwargs + if let Some(kw) = kwargs { + for (key, value) in kw { + let key_str: String = key.extract()?; + let prop_key = + qre::property_name_to_key(&key_str.to_ascii_uppercase()).ok_or_else(|| { + PyValueError::new_err(format!("Unknown property name: {key_str}")) + })?; + let prop_value: u64 = value.extract()?; + instr.set_property(prop_key, prop_value); + } + } + + Ok(instr) +} + +/// Extract an `_IntFunction` or convert a plain int into a constant function. +fn extract_int_function(obj: &Bound<'_, PyAny>) -> PyResult> { + if let Ok(f) = obj.cast::() { + return Ok(f.borrow().0.clone()); + } + let val: u64 = obj.extract()?; + Ok(qre::VariableArityFunction::constant(val)) +} + +/// Extract a `_FloatFunction` or convert a plain float into a constant function. 
+fn extract_float_function(obj: &Bound<'_, PyAny>) -> PyResult> { + if let Ok(f) = obj.cast::() { + return Ok(f.borrow().0.clone()); + } + let val: f64 = obj.extract()?; + Ok(qre::VariableArityFunction::constant(val)) +} + +#[pyclass] +pub struct ConstraintBound(qre::ConstraintBound); + +#[pymethods] +impl ConstraintBound { + #[staticmethod] + pub fn lt(value: f64) -> ConstraintBound { + ConstraintBound(qre::ConstraintBound::less_than(value)) + } + + #[staticmethod] + pub fn le(value: f64) -> ConstraintBound { + ConstraintBound(qre::ConstraintBound::less_equal(value)) + } + + #[staticmethod] + pub fn eq(value: f64) -> ConstraintBound { + ConstraintBound(qre::ConstraintBound::equal(value)) + } + + #[staticmethod] + pub fn gt(value: f64) -> ConstraintBound { + ConstraintBound(qre::ConstraintBound::greater_than(value)) + } + + #[staticmethod] + pub fn ge(value: f64) -> ConstraintBound { + ConstraintBound(qre::ConstraintBound::greater_equal(value)) + } +} + +#[derive(Clone)] +#[pyclass(name = "_ProvenanceGraph", from_py_object)] +pub struct ProvenanceGraph(Arc>); + +impl Default for ProvenanceGraph { + fn default() -> Self { + Self(Arc::new(RwLock::new(qre::ProvenanceGraph::new()))) + } +} + +#[pymethods] +impl ProvenanceGraph { + #[new] + pub fn new() -> Self { + Self::default() + } + + #[allow(clippy::needless_pass_by_value)] + pub fn add_node( + &mut self, + instruction: &Instruction, + transform: u64, + children: Vec, + ) -> PyResult { + Ok(self.0.write().map_err(poisoned_lock_err)?.add_node( + instruction.0.clone(), + transform, + &children, + )) + } + + pub fn instruction(&self, node_index: usize) -> PyResult { + Ok(Instruction( + self.0 + .read() + .map_err(poisoned_lock_err)? + .instruction(node_index) + .clone(), + )) + } + + pub fn transform_id(&self, node_index: usize) -> PyResult { + Ok(self + .0 + .read() + .map_err(poisoned_lock_err)? 
+ .transform_id(node_index)) + } + + pub fn children(&self, node_index: usize) -> PyResult> { + Ok(self + .0 + .read() + .map_err(poisoned_lock_err)? + .children(node_index) + .to_vec()) + } + + pub fn num_nodes(&self) -> PyResult { + Ok(self.0.read().map_err(poisoned_lock_err)?.num_nodes()) + } + + pub fn num_edges(&self) -> PyResult { + Ok(self.0.read().map_err(poisoned_lock_err)?.num_edges()) + } + + /// Adds an instruction to the provenance graph with no transform or children. + /// Accepts either a pre-existing Instruction or keyword args to create one. + /// Returns the node index. + #[pyo3(signature = (id_or_instruction, encoding=0, *, arity=Some(1), time=None, space=None, length=None, error_rate=None, **kwargs))] + #[allow(clippy::too_many_arguments)] + pub fn add_instruction( + &mut self, + id_or_instruction: &Bound<'_, PyAny>, + encoding: u64, + arity: Option, + time: Option<&Bound<'_, PyAny>>, + space: Option<&Bound<'_, PyAny>>, + length: Option<&Bound<'_, PyAny>>, + error_rate: Option<&Bound<'_, PyAny>>, + kwargs: Option<&Bound<'_, PyDict>>, + ) -> PyResult { + let instr = build_instruction( + id_or_instruction, + encoding, + arity, + time, + space, + length, + error_rate, + kwargs, + )?; + Ok(self + .0 + .write() + .map_err(poisoned_lock_err)? + .add_node(instr, 0, &[])) + } + + /// Creates an ISA backed by this provenance graph from the given node indices. + pub fn make_isa(&self, node_indices: Vec) -> PyResult { + let graph = self.0.read().map_err(poisoned_lock_err)?; + let mut isa = qre::ISA::with_graph(self.0.clone()); + for idx in node_indices { + let id = graph.instruction(idx).id(); + isa.add_node(id, idx); + } + Ok(ISA(isa)) + } + + /// Builds the per-instruction-ID Pareto index. + /// + /// Must be called after all nodes have been added. For each instruction + /// ID, retains only the Pareto-optimal nodes w.r.t. (space, time, + /// error rate) evaluated at arity 1. 
+ pub fn build_pareto_index(&self) -> PyResult<()> { + self.0 + .write() + .map_err(poisoned_lock_err)? + .build_pareto_index(); + Ok(()) + } + + /// Returns the raw node count (including the sentinel at index 0). + pub fn raw_node_count(&self) -> PyResult { + Ok(self.0.read().map_err(poisoned_lock_err)?.raw_node_count()) + } + + /// Computes an upper bound on the possible ISAs that can be formed from + /// this graph. + /// + /// Must be called after `build_pareto_index`. + pub fn total_isa_count(&self) -> PyResult { + Ok(self.0.read().map_err(poisoned_lock_err)?.total_isa_count()) + } + + /// Returns ISAs formed from Pareto-optimal graph nodes satisfying the + /// given requirements. + /// + /// For each constraint in `requirements`, selects matching Pareto-optimal + /// nodes. Returns the Cartesian product of per-constraint matches, + /// augmented with one representative node per unconstrained instruction + /// ID. + /// + /// When ``min_node_idx`` is provided, only Pareto nodes at or above + /// that index are considered for constrained groups (useful for scoping + /// queries to a subset of the graph). + /// + /// Must be called after `build_pareto_index`. 
+ #[pyo3(signature = (requirements, min_node_idx=None))] + pub fn query_satisfying( + &self, + requirements: &ISARequirements, + min_node_idx: Option, + ) -> PyResult> { + let graph = self.0.read().map_err(poisoned_lock_err)?; + Ok(graph + .query_satisfying(&self.0, &requirements.0, min_node_idx) + .into_iter() + .map(ISA) + .collect()) + } +} + +#[pyclass(name = "_IntFunction")] +pub struct IntFunction(qre::VariableArityFunction); + +#[pyclass(name = "_FloatFunction")] +pub struct FloatFunction(qre::VariableArityFunction); + +#[pyfunction] +pub fn constant_function<'py>(value: &Bound<'py, PyAny>) -> PyResult> { + if let Ok(v) = value.extract::() { + IntFunction(qre::VariableArityFunction::Constant { value: v }).into_bound_py_any(value.py()) + } else if let Ok(v) = value.extract::() { + FloatFunction(qre::VariableArityFunction::Constant { value: v }) + .into_bound_py_any(value.py()) + } else { + Err(PyTypeError::new_err( + "Value must be either an integer or a float", + )) + } +} + +#[pyfunction] +pub fn linear_function<'py>(slope: &Bound<'py, PyAny>) -> PyResult> { + if let Ok(s) = slope.extract::() { + IntFunction(qre::VariableArityFunction::linear(s)).into_bound_py_any(slope.py()) + } else if let Ok(s) = slope.extract::() { + FloatFunction(qre::VariableArityFunction::linear(s)).into_bound_py_any(slope.py()) + } else { + Err(PyTypeError::new_err( + "Slope must be either an integer or a float", + )) + } +} + +// TODO: Assign default value to offset? 
+#[pyfunction] +#[pyo3(signature = (block_size, slope, offset))] +pub fn block_linear_function<'py>( + block_size: u64, + slope: &Bound<'py, PyAny>, + offset: &Bound<'py, PyAny>, +) -> PyResult> { + if let Ok(s) = slope.extract() { + let offset = offset.extract::()?; + IntFunction(qre::VariableArityFunction::block_linear( + block_size, s, offset, + )) + .into_bound_py_any(slope.py()) + } else if let Ok(s) = slope.extract::() { + let offset = offset.extract()?; + FloatFunction(qre::VariableArityFunction::block_linear( + block_size, s, offset, + )) + .into_bound_py_any(slope.py()) + } else { + Err(PyTypeError::new_err( + "Slope must be either an integer or a float", + )) + } +} + +#[pyfunction] +pub fn generic_function<'py>( + py: Python<'py>, + func: Bound<'py, PyAny>, +) -> PyResult> { + // Try to get return type annotation from the function + let is_int = if let Ok(annotations) = func.getattr("__annotations__") { + if let Ok(return_type) = annotations.get_item("return") { + // Check if return type is float + let float_type = py.get_type::(); + return_type.eq(float_type).unwrap_or(false) + } else { + false + } + } else { + false + }; + + let func: Py = func.unbind(); + + if is_int { + let closure = move |arity: u64| -> u64 { + Python::attach(|py| { + let result = func.call1(py, (arity,)); + match result { + Ok(value) => value.extract::(py).unwrap_or(0), + Err(_) => 0, + } + }) + }; + + let arc: Arc u64 + Send + Sync> = Arc::new(closure); + IntFunction(qre::VariableArityFunction::generic_from_arc(arc)).into_bound_py_any(py) + } else { + let closure = move |arity: u64| -> f64 { + Python::attach(|py| { + let result = func.call1(py, (arity,)); + match result { + Ok(value) => value.extract::(py).unwrap_or(0.0), + Err(_) => 0.0, + } + }) + }; + + let arc: Arc f64 + Send + Sync> = Arc::new(closure); + FloatFunction(qre::VariableArityFunction::generic_from_arc(arc)).into_bound_py_any(py) + } +} + +#[derive(Default)] +#[pyclass(name = "_EstimationCollection")] +pub struct 
EstimationCollection(qre::EstimationCollection); + +#[pymethods] +impl EstimationCollection { + #[new] + pub fn new() -> Self { + Self::default() + } + + pub fn insert(&mut self, result: &EstimationResult) { + self.0.insert(result.0.clone()); + } + + pub fn __len__(&self) -> usize { + self.0.len() + } + + #[getter] + pub fn total_jobs(&self) -> usize { + self.0.total_jobs() + } + + #[getter] + pub fn successful_estimates(&self) -> usize { + self.0.successful_estimates() + } + + /// Returns lightweight summaries of ALL successful estimates as a list + /// of (trace index, isa index, qubits, runtime) tuples. + #[getter] + pub fn all_summaries(&self) -> Vec<(usize, usize, u64, u64)> { + self.0 + .all_summaries() + .iter() + .map(|s| (s.trace_index, s.isa_index, s.qubits, s.runtime)) + .collect() + } + + /// Returns the set of ISAs for which this collection contains successful + /// estimates. + #[getter] + pub fn isas(&self) -> Vec { + self.0.isas().iter().cloned().map(ISA).collect() + } + + #[allow(clippy::needless_pass_by_value)] + pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { + let iter = EstimationCollectionIterator { + iter: slf.0.iter().cloned().collect::>().into_iter(), + }; + Py::new(slf.py(), iter) + } +} + +#[pyclass] +pub struct EstimationCollectionIterator { + iter: std::vec::IntoIter, +} + +#[pymethods] +impl EstimationCollectionIterator { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(mut slf: PyRefMut<'_, Self>) -> Option { + slf.iter.next().map(EstimationResult) + } +} + +#[pyclass] +pub struct EstimationResult(qre::EstimationResult); + +#[pymethods] +impl EstimationResult { + #[new] + #[pyo3(signature = (*, qubits = 0, runtime = 0, error = 0.0))] + pub fn new(qubits: u64, runtime: u64, error: f64) -> Self { + let mut result = qre::EstimationResult::new(); + result.add_qubits(qubits); + result.add_runtime(runtime); + result.add_error(error); + + EstimationResult(result) + } + + #[getter] + pub fn qubits(&self) 
-> u64 { + self.0.qubits() + } + + #[setter] + pub fn set_qubits(&mut self, qubits: u64) { + self.0.set_qubits(qubits); + } + + #[getter] + pub fn runtime(&self) -> u64 { + self.0.runtime() + } + + #[setter] + pub fn set_runtime(&mut self, runtime: u64) { + self.0.set_runtime(runtime); + } + + #[getter] + pub fn error(&self) -> f64 { + self.0.error() + } + + #[setter] + pub fn set_error(&mut self, error: f64) { + self.0.set_error(error); + } + + #[allow(clippy::needless_pass_by_value)] + #[getter] + pub fn factories(self_: PyRef<'_, Self>) -> PyResult> { + let dict = PyDict::new(self_.py()); + + for (id, factory) in self_.0.factories() { + dict.set_item(id, FactoryResult(factory.clone()))?; + } + + Ok(dict) + } + + #[getter] + pub fn isa(&self) -> ISA { + ISA(self.0.isa().clone()) + } + + #[allow(clippy::needless_pass_by_value)] + #[getter] + pub fn properties(self_: PyRef<'_, Self>) -> PyResult> { + let dict = PyDict::new(self_.py()); + + for (key, value) in self_.0.properties() { + match value { + qre::Property::Bool(b) => dict.set_item(key, *b)?, + qre::Property::Int(i) => dict.set_item(key, *i)?, + qre::Property::Float(f) => dict.set_item(key, *f)?, + qre::Property::Str(s) => dict.set_item(key, s.clone())?, + } + } + + Ok(dict) + } + + pub fn set_property(&mut self, key: u64, value: &Bound<'_, PyAny>) -> PyResult<()> { + let property = if value.is_instance_of::() { + qre::Property::new_bool(value.extract()?) 
+ } else if let Ok(i) = value.extract::() { + qre::Property::new_int(i) + } else if let Ok(f) = value.extract::() { + qre::Property::new_float(f) + } else { + qre::Property::new_str(value.to_string()) + }; + + self.0.set_property(key, property); + + Ok(()) + } + + fn __str__(&self) -> String { + format!("{}", self.0) + } +} + +#[pyclass] +pub struct FactoryResult(qre::FactoryResult); + +#[pymethods] +impl FactoryResult { + #[getter] + pub fn copies(&self) -> u64 { + self.0.copies() + } + + #[getter] + pub fn runs(&self) -> u64 { + self.0.runs() + } + + #[getter] + pub fn states(&self) -> u64 { + self.0.states() + } + + #[getter] + pub fn error_rate(&self) -> f64 { + self.0.error_rate() + } +} + +#[pyclass] +pub struct Trace(qre::Trace); + +#[pymethods] +impl Trace { + #[new] + pub fn new(compute_qubits: u64) -> Self { + Trace(qre::Trace::new(compute_qubits)) + } + + #[pyo3(signature = (compute_qubits = None))] + pub fn clone_empty(&self, compute_qubits: Option) -> Self { + Trace(self.0.clone_empty(compute_qubits)) + } + + #[classmethod] + pub fn from_json(_cls: &Bound<'_, PyType>, json: &str) -> PyResult { + let trace: qre::Trace = serde_json::from_str(json).map_err(|e| { + EstimationError::new_err(format!("Failed to parse trace from JSON: {e}")) + })?; + + Ok(Self(trace)) + } + + pub fn to_json(&self) -> PyResult { + serde_json::to_string(&self.0).map_err(|e| { + EstimationError::new_err(format!("Failed to serialize trace to JSON: {e}")) + }) + } + + #[getter] + pub fn compute_qubits(&self) -> u64 { + self.0.compute_qubits() + } + + #[setter] + pub fn set_compute_qubits(&mut self, qubits: u64) { + self.0.set_compute_qubits(qubits); + } + + #[getter] + pub fn base_error(&self) -> f64 { + self.0.base_error() + } + + pub fn increment_base_error(&mut self, amount: f64) { + self.0.increment_base_error(amount); + } + + pub fn set_property(&mut self, key: u64, value: &Bound<'_, PyAny>) -> PyResult<()> { + let property = if value.is_instance_of::() { + 
qre::Property::new_bool(value.extract()?) + } else if let Ok(i) = value.extract::() { + qre::Property::new_int(i) + } else if let Ok(f) = value.extract::() { + qre::Property::new_float(f) + } else { + qre::Property::new_str(value.to_string()) + }; + + self.0.set_property(key, property); + + Ok(()) + } + + #[allow(clippy::needless_pass_by_value)] + pub fn get_property(self_: PyRef<'_, Self>, key: u64) -> Option> { + if let Some(value) = self_.0.get_property(key) { + match value { + qre::Property::Bool(b) => PyBool::new(self_.py(), *b) + .into_bound_py_any(self_.py()) + .ok(), + qre::Property::Int(i) => PyInt::new(self_.py(), *i) + .into_bound_py_any(self_.py()) + .ok(), + qre::Property::Float(f) => PyFloat::new(self_.py(), *f) + .into_bound_py_any(self_.py()) + .ok(), + qre::Property::Str(s) => PyString::new(self_.py(), s) + .into_bound_py_any(self_.py()) + .ok(), + } + } else { + None + } + } + + pub fn has_property(&self, key: u64) -> bool { + self.0.has_property(key) + } + + #[allow(clippy::needless_pass_by_value)] + #[getter] + pub fn resource_states(self_: PyRef<'_, Self>) -> PyResult> { + let dict = PyDict::new(self_.py()); + if let Some(resource_states) = self_.0.get_resource_states() { + for (resource_id, count) in resource_states { + if *count != 0 { + dict.set_item(resource_id, *count)?; + } + } + } + Ok(dict) + } + + #[getter] + pub fn total_qubits(&self) -> u64 { + self.0.total_qubits() + } + + #[getter] + pub fn depth(&self) -> u64 { + self.0.depth() + } + + #[getter] + pub fn num_gates(&self) -> u64 { + self.0.num_gates() + } + + #[pyo3(signature = (isa, max_error = None))] + pub fn estimate(&self, isa: &ISA, max_error: Option) -> Option { + self.0 + .estimate(&isa.0, max_error) + .map(|mut r| { + r.set_isa(isa.0.clone()); + EstimationResult(r) + }) + .ok() + } + + #[pyo3(signature = (id, qubits, params = vec![]))] + pub fn add_operation(&mut self, id: u64, qubits: Vec, params: Vec) { + self.0.add_operation(id, qubits, params); + } + + pub fn 
root_block(mut slf: PyRefMut<'_, Self>) -> Block { + let block = slf.0.root_block_mut(); + let ptr = NonNull::from(block); + Block { + ptr, + parent: slf.into(), + } + } + + #[pyo3(signature = (repetitions = 1))] + pub fn add_block(mut slf: PyRefMut<'_, Self>, repetitions: u64) -> Block { + let block = slf.0.add_block(repetitions); + let ptr = NonNull::from(block); + Block { + ptr, + parent: slf.into(), + } + } + + #[getter] + pub fn memory_qubits(&self) -> Option { + self.0.memory_qubits() + } + + pub fn has_memory_qubits(&self) -> bool { + self.0.has_memory_qubits() + } + + #[setter] + pub fn set_memory_qubits(&mut self, qubits: u64) { + self.0.set_memory_qubits(qubits); + } + + pub fn increment_memory_qubits(&mut self, amount: u64) { + self.0.increment_memory_qubits(amount); + } + + pub fn increment_resource_state(&mut self, resource_id: u64, amount: u64) { + self.0.increment_resource_state(resource_id, amount); + } + + fn __str__(&self) -> String { + format!("{}", self.0) + } + + #[getter] + pub fn required_isa(&self) -> ISARequirements { + ISARequirements(self.0.required_instruction_ids(None)) + } +} + +#[pyclass(unsendable)] +pub struct Block { + ptr: NonNull, + #[allow(dead_code)] + parent: Py, +} + +#[pymethods] +impl Block { + #[pyo3(signature = (id, qubits, params = vec![]))] + pub fn add_operation(&mut self, id: u64, qubits: Vec, params: Vec) { + unsafe { self.ptr.as_mut() }.add_operation(id, qubits, params); + } + + #[pyo3(signature = (repetitions = 1))] + pub fn add_block(&mut self, py: Python<'_>, repetitions: u64) -> PyResult { + let block = unsafe { self.ptr.as_mut() }.add_block(repetitions); + let ptr = NonNull::from(block); + Ok(Block { + ptr, + parent: self.parent.clone_ref(py), + }) + } + + fn __str__(&self) -> String { + format!("{}", unsafe { self.ptr.as_ref() }) + } +} + +#[allow(clippy::upper_case_acronyms)] +#[pyclass] +pub struct PSSPC(qre::PSSPC); + +#[pymethods] +impl PSSPC { + #[new] + pub fn new(num_ts_per_rotation: u64, 
ccx_magic_states: bool) -> Self { + PSSPC(qre::PSSPC::new(num_ts_per_rotation, ccx_magic_states)) + } + + pub fn transform(&self, trace: &Trace) -> PyResult { + self.0 + .transform(&trace.0) + .map(Trace) + .map_err(|e| EstimationError::new_err(format!("{e}"))) + } +} + +#[derive(Default)] +#[pyclass] +pub struct LatticeSurgery(qre::LatticeSurgery); + +#[pymethods] +impl LatticeSurgery { + #[new] + pub fn new(slow_down_factor: f64) -> Self { + Self(qre::LatticeSurgery::new(slow_down_factor)) + } + + pub fn transform(&self, trace: &Trace) -> PyResult { + self.0 + .transform(&trace.0) + .map(Trace) + .map_err(|e| EstimationError::new_err(format!("{e}"))) + } +} + +/// Dispatches a method call to the inner frontier variant, avoiding +/// repetitive match arms. Use as: +/// +/// ```ignore +/// dispatch_frontier!(self, f => f.len()) +/// dispatch_frontier!(mut self, f => f.insert(point.clone())) +/// ``` +macro_rules! dispatch_frontier { + ($self:ident, $f:ident => $body:expr) => { + match &$self.0 { + InstructionFrontierInner::Frontier2D($f) => $body, + InstructionFrontierInner::Frontier3D($f) => $body, + } + }; + (mut $self:ident, $f:ident => $body:expr) => { + match &mut $self.0 { + InstructionFrontierInner::Frontier2D($f) => $body, + InstructionFrontierInner::Frontier3D($f) => $body, + } + }; +} + +#[derive(Clone)] +enum InstructionFrontierInner { + Frontier2D(qre::ParetoFrontier2D), + Frontier3D(qre::ParetoFrontier3D), +} + +#[pyclass] +pub struct InstructionFrontier(InstructionFrontierInner); + +impl Default for InstructionFrontier { + fn default() -> Self { + Self(InstructionFrontierInner::Frontier3D( + qre::ParetoFrontier3D::new(), + )) + } +} + +#[pymethods] +impl InstructionFrontier { + #[new] + #[pyo3(signature = (*, with_error_objective = true))] + pub fn new(with_error_objective: bool) -> Self { + if with_error_objective { + Self(InstructionFrontierInner::Frontier3D( + qre::ParetoFrontier3D::new(), + )) + } else { + 
Self(InstructionFrontierInner::Frontier2D( + qre::ParetoFrontier2D::new(), + )) + } + } + + pub fn insert(&mut self, point: &Instruction) { + dispatch_frontier!(mut self, f => f.insert(point.clone())); + } + + #[allow(clippy::needless_pass_by_value)] + pub fn extend(&mut self, points: Vec>) { + dispatch_frontier!(mut self, f => f.extend(points.iter().map(|p| Instruction(p.0.clone())))); + } + + pub fn __len__(&self) -> usize { + dispatch_frontier!(self, f => f.len()) + } + + #[allow(clippy::needless_pass_by_value)] + pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { + let items: Vec = dispatch_frontier!(slf, f => f.iter().cloned().collect()); + Py::new( + slf.py(), + InstructionFrontierIterator { + iter: items.into_iter(), + }, + ) + } + + #[staticmethod] + #[pyo3(signature = (filename, *, with_error_objective = true))] + pub fn load(filename: &str, with_error_objective: bool) -> PyResult { + let content = std::fs::read_to_string(filename)?; + let err = |e: serde_json::Error| EstimationError::new_err(format!("{e}")); + + let inner = if with_error_objective { + InstructionFrontierInner::Frontier3D(serde_json::from_str(&content).map_err(err)?) + } else { + InstructionFrontierInner::Frontier2D(serde_json::from_str(&content).map_err(err)?) + }; + Ok(InstructionFrontier(inner)) + } + + pub fn dump(&self, filename: &str) -> PyResult<()> { + let content = dispatch_frontier!(self, f => + serde_json::to_string(f).map_err(|e| EstimationError::new_err(format!("{e}")))? + ); + Ok(std::fs::write(filename, content)?) 
+ } +} + +#[pyclass] +pub struct InstructionFrontierIterator { + iter: std::vec::IntoIter, +} + +#[pymethods] +impl InstructionFrontierIterator { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(mut slf: PyRefMut<'_, Self>) -> Option { + slf.iter.next() + } +} + +#[allow(clippy::needless_pass_by_value)] +#[pyfunction(name = "_estimate_parallel", signature = (traces, isas, max_error = 1.0, post_process = false))] +pub fn estimate_parallel( + py: Python<'_>, + traces: Vec>, + isas: Vec>, + max_error: f64, + post_process: bool, +) -> EstimationCollection { + let traces: Vec<_> = traces.iter().map(|t| &t.0).collect(); + let isas: Vec<_> = isas.iter().map(|i| &i.0).collect(); + + // Release the GIL before entering the parallel section. + // Worker threads spawned by qre::estimate_parallel may need to acquire + // the GIL to evaluate Python callbacks (via generic_function closures). + // If the calling thread holds the GIL while blocked in + // std::thread::scope, the worker threads deadlock. + let collection = release_gil(py, || { + qre::estimate_parallel(&traces, &isas, Some(max_error), post_process) + }); + EstimationCollection(collection) +} + +#[allow(clippy::needless_pass_by_value)] +#[pyfunction(name = "_estimate_with_graph", signature = (traces, graph, max_error = 1.0, post_process = false))] +pub fn estimate_with_graph( + py: Python<'_>, + traces: Vec>, + graph: &ProvenanceGraph, + max_error: f64, + post_process: bool, +) -> PyResult { + let traces: Vec<_> = traces.iter().map(|t| &t.0).collect(); + + let collection = release_gil(py, || { + qre::estimate_with_graph(&traces, &graph.0, Some(max_error), post_process) + }); + Ok(EstimationCollection(collection)) +} + +/// Releases the GIL for the duration of the closure `f`, allowing other +/// threads to acquire it. Delegates to `py.detach()` so that pyo3's internal +/// attach-count is properly reset; this ensures that any `Python::attach` +/// calls inside `f` (e.g. 
from `generic_function` callbacks) will correctly +/// call `PyGILState_Ensure` to re-acquire the GIL. +/// +/// The caller must ensure that no `Bound<'_, _>` or `Python<'_>` references +/// are used inside `f`. GIL-independent `Py` handles are fine because +/// they re-acquire the GIL via `Python::attach` when needed. +fn release_gil(py: Python<'_>, f: F) -> R +where + F: FnOnce() -> R + Send, + R: Send, +{ + py.detach(f) +} + +#[pyfunction(name = "_binom_ppf")] +pub fn binom_ppf(q: f64, n: usize, p: f64) -> usize { + qre::binom_ppf(q, n, p) +} + +#[pyfunction(name = "_float_to_bits")] +pub fn float_to_bits(f: f64) -> u64 { + qre::float_to_bits(f) +} + +#[pyfunction(name = "_float_from_bits")] +pub fn float_from_bits(bits: u64) -> f64 { + qre::float_from_bits(bits) +} + +#[pyfunction] +pub fn instruction_name(id: u64) -> Option { + qre::instruction_name(id).map(String::from) +} + +fn add_instruction_ids(m: &Bound<'_, PyModule>) -> PyResult<()> { + #[allow(clippy::wildcard_imports)] + use qre::instruction_ids::*; + + let instruction_ids = PyModule::new(m.py(), "instruction_ids")?; + + macro_rules! add_ids { + ($($name:ident),* $(,)?) 
=> { + $(instruction_ids.add(stringify!($name), $name)?;)* + }; + } + + add_ids!( + PAULI_I, + PAULI_X, + PAULI_Y, + PAULI_Z, + H, + H_XZ, + H_XY, + H_YZ, + SQRT_X, + SQRT_X_DAG, + SQRT_Y, + SQRT_Y_DAG, + S, + SQRT_Z, + S_DAG, + SQRT_Z_DAG, + CNOT, + CX, + CY, + CZ, + SWAP, + PREP_X, + PREP_Y, + PREP_Z, + ONE_QUBIT_CLIFFORD, + TWO_QUBIT_CLIFFORD, + N_QUBIT_CLIFFORD, + MEAS_X, + MEAS_Y, + MEAS_Z, + MEAS_RESET_X, + MEAS_RESET_Y, + MEAS_RESET_Z, + MEAS_XX, + MEAS_YY, + MEAS_ZZ, + MEAS_XZ, + MEAS_XY, + MEAS_YZ, + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + SQRT_SQRT_Z, + T, + SQRT_SQRT_Z_DAG, + T_DAG, + CCX, + CCY, + CCZ, + CSWAP, + AND, + AND_DAG, + RX, + RY, + RZ, + CRX, + CRY, + CRZ, + RXX, + RYY, + RZZ, + ONE_QUBIT_UNITARY, + TWO_QUBIT_UNITARY, + MULTI_PAULI_MEAS, + LATTICE_SURGERY, + READ_FROM_MEMORY, + WRITE_TO_MEMORY, + MEMORY, + CYCLIC_SHIFT, + GENERIC + ); + + m.add_submodule(&instruction_ids)?; + + Ok(()) +} + +#[pyfunction] +pub fn property_name_to_key(name: &str) -> Option { + qre::property_name_to_key(&name.to_ascii_uppercase()) +} + +#[pyfunction] +pub fn property_name(id: u64) -> Option { + qre::property_name(id).map(String::from) +} + +fn add_property_keys(m: &Bound<'_, PyModule>) -> PyResult<()> { + #[allow(clippy::wildcard_imports)] + use qre::property_keys::*; + + let property_keys = PyModule::new(m.py(), "property_keys")?; + + macro_rules! add_ids { + ($($name:ident),* $(,)?) 
=> { + $(property_keys.add(stringify!($name), $name)?;)* + }; + } + + add_ids!( + DISTANCE, + SURFACE_CODE_ONE_QUBIT_TIME_FACTOR, + SURFACE_CODE_TWO_QUBIT_TIME_FACTOR, + ACCELERATION, + NUM_TS_PER_ROTATION, + EXPECTED_SHOTS, + RUNTIME_SINGLE_SHOT, + EVALUATION_TIME, + PHYSICAL_COMPUTE_QUBITS, + PHYSICAL_FACTORY_QUBITS, + PHYSICAL_MEMORY_QUBITS, + MOLECULE, + LOGICAL_COMPUTE_QUBITS, + LOGICAL_MEMORY_QUBITS, + ALGORITHM_COMPUTE_QUBITS, + ALGORITHM_MEMORY_QUBITS, + NAME + ); + + m.add_submodule(&property_keys)?; + + Ok(()) +} diff --git a/source/pip/test_requirements.txt b/source/pip/test_requirements.txt index ae3f808636..a0f74006aa 100644 --- a/source/pip/test_requirements.txt +++ b/source/pip/test_requirements.txt @@ -1,3 +1,5 @@ pytest expecttest==0.3.0 pyqir>=0.11.1,<0.12 +cirq==1.6.1 +pandas>=2.1 diff --git a/source/pip/tests/magnets/__init__.py b/source/pip/tests/magnets/__init__.py new file mode 100644 index 0000000000..4540e70bc2 --- /dev/null +++ b/source/pip/tests/magnets/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Unit tests for the magnets library.""" + +try: + # pylint: disable=unused-import + # flake8: noqa E401 + import cirq + + CIRQ_AVAILABLE = True +except ImportError: + CIRQ_AVAILABLE = False + +SKIP_REASON = "cirq is not available" diff --git a/source/pip/tests/magnets/test_complete.py b/source/pip/tests/magnets/test_complete.py new file mode 100644 index 0000000000..38052dc668 --- /dev/null +++ b/source/pip/tests/magnets/test_complete.py @@ -0,0 +1,255 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Unit tests for complete graph data structures.""" + +from qsharp.magnets.geometry.complete import ( + CompleteBipartiteGraph, + CompleteGraph, +) +from qsharp.magnets.utilities import HypergraphEdgeColoring + + +# CompleteGraph tests + + +def test_complete_graph_init_basic(): + """Test basic CompleteGraph initialization.""" + graph = CompleteGraph(4) + assert graph.nvertices == 4 + assert graph.nedges == 6 # 4 * 3 / 2 = 6 + assert graph.n == 4 + + +def test_complete_graph_single_vertex(): + """Test CompleteGraph with a single vertex (no edges).""" + graph = CompleteGraph(1) + assert graph.nvertices == 0 + assert graph.nedges == 0 + assert graph.n == 1 + + +def test_complete_graph_two_vertices(): + """Test CompleteGraph with two vertices (one edge).""" + graph = CompleteGraph(2) + assert graph.nvertices == 2 + assert graph.nedges == 1 + + +def test_complete_graph_three_vertices(): + """Test CompleteGraph with three vertices (triangle).""" + graph = CompleteGraph(3) + assert graph.nvertices == 3 + assert graph.nedges == 3 + + +def test_complete_graph_five_vertices(): + """Test CompleteGraph with five vertices.""" + graph = CompleteGraph(5) + assert graph.nvertices == 5 + assert graph.nedges == 10 # 5 * 4 / 2 = 10 + + +def test_complete_graph_edges(): + """Test that CompleteGraph creates correct edges.""" + graph = CompleteGraph(4) + edges = list(graph.edges()) + assert len(edges) == 6 + # All pairs should be present + edge_sets = [set(e.vertices) for e in edges] + assert {0, 1} in edge_sets + assert {0, 2} in edge_sets + assert {0, 3} in edge_sets + assert {1, 2} in edge_sets + assert {1, 3} in edge_sets + assert {2, 3} in edge_sets + + +def test_complete_graph_vertices(): + """Test that CompleteGraph vertices are correct.""" + graph = CompleteGraph(5) + vertices = list(graph.vertices()) + assert vertices == [0, 1, 2, 3, 4] + + +def test_complete_graph_with_self_loops(): + """Test CompleteGraph with self-loops enabled.""" + graph = CompleteGraph(4, 
self_loops=True) + assert graph.nvertices == 4 + # 4 self-loops + 6 edges = 10 + assert graph.nedges == 10 + + +def test_complete_graph_self_loops_edges(): + """Test that self-loop edges are created correctly.""" + graph = CompleteGraph(3, self_loops=True) + edge_vertices = {edge.vertices for edge in graph.edges()} + assert {(0,), (1,), (2,)}.issubset(edge_vertices) + + +def test_complete_graph_edge_count_formula(): + """Test that edge count follows n(n-1)/2 formula.""" + for n in range(1, 10): + graph = CompleteGraph(n) + expected_edges = n * (n - 1) // 2 + assert graph.nedges == expected_edges + + +def test_complete_graph_str(): + """Test string representation.""" + graph = CompleteGraph(4) + assert "4 vertices" in str(graph) + assert "6 edges" in str(graph) + + +def test_complete_graph_inherits_hypergraph(): + """Test that CompleteGraph is a Hypergraph subclass with all methods.""" + from qsharp.magnets.utilities import Hypergraph + + graph = CompleteGraph(4) + assert isinstance(graph, Hypergraph) + assert hasattr(graph, "edges") + assert hasattr(graph, "vertices") + + +# CompleteBipartiteGraph tests + + +def test_complete_bipartite_graph_init_basic(): + """Test basic CompleteBipartiteGraph initialization.""" + graph = CompleteBipartiteGraph(2, 3) + assert graph.nvertices == 5 + assert graph.nedges == 6 # 2 * 3 = 6 + assert graph.m == 2 + assert graph.n == 3 + + +def test_complete_bipartite_graph_single_each(): + """Test CompleteBipartiteGraph with one vertex in each set.""" + graph = CompleteBipartiteGraph(1, 1) + assert graph.nvertices == 2 + assert graph.nedges == 1 + + +def test_complete_bipartite_graph_one_and_many(): + """Test CompleteBipartiteGraph with one vertex in first set.""" + graph = CompleteBipartiteGraph(1, 5) + assert graph.nvertices == 6 + assert graph.nedges == 5 # 1 * 5 = 5 + + +def test_complete_bipartite_graph_square(): + """Test CompleteBipartiteGraph with equal set sizes.""" + graph = CompleteBipartiteGraph(3, 3) + assert graph.nvertices 
== 6 + assert graph.nedges == 9 # 3 * 3 = 9 + + +def test_complete_bipartite_graph_edges(): + """Test that CompleteBipartiteGraph creates correct edges.""" + graph = CompleteBipartiteGraph(2, 3) + edges = list(graph.edges()) + assert len(edges) == 6 + # Vertices 0, 1 in first set; 2, 3, 4 in second set + edge_sets = [set(e.vertices) for e in edges] + # All pairs between sets should be present + assert {0, 2} in edge_sets + assert {0, 3} in edge_sets + assert {0, 4} in edge_sets + assert {1, 2} in edge_sets + assert {1, 3} in edge_sets + assert {1, 4} in edge_sets + # No edges within sets + assert {0, 1} not in edge_sets + assert {2, 3} not in edge_sets + assert {2, 4} not in edge_sets + assert {3, 4} not in edge_sets + + +def test_complete_bipartite_graph_vertices(): + """Test that CompleteBipartiteGraph vertices are correct.""" + graph = CompleteBipartiteGraph(2, 3) + vertices = list(graph.vertices()) + assert vertices == [0, 1, 2, 3, 4] + + +def test_complete_bipartite_graph_with_self_loops(): + """Test CompleteBipartiteGraph with self-loops enabled.""" + graph = CompleteBipartiteGraph(2, 3, self_loops=True) + assert graph.nvertices == 5 + # 5 self-loops + 6 edges = 11 + assert graph.nedges == 11 + + +def test_complete_bipartite_graph_self_loops_edges(): + """Test that self-loop edges are created correctly.""" + graph = CompleteBipartiteGraph(2, 2, self_loops=True) + edge_vertices = {edge.vertices for edge in graph.edges()} + assert {(0,), (1,), (2,), (3,)}.issubset(edge_vertices) + + +def test_complete_bipartite_graph_edge_count_formula(): + """Test that edge count follows m * n formula.""" + for m in range(1, 6): + for n in range(m, 6): + graph = CompleteBipartiteGraph(m, n) + expected_edges = m * n + assert graph.nedges == expected_edges + + +def test_complete_bipartite_graph_coloring_without_self_loops(): + """Test edge coloring without self-loops.""" + graph = CompleteBipartiteGraph(3, 4) + coloring = graph.edge_coloring() + # Should have n colors for 
bipartite coloring + assert coloring.ncolors == 4 + + +def test_complete_bipartite_graph_coloring_with_self_loops(): + """Test edge coloring with self-loops.""" + graph = CompleteBipartiteGraph(3, 4, self_loops=True) + coloring = graph.edge_coloring() + # Self-loops get color -1, bipartite edges get n colors (0 to n-1) + # ncolors counts nonnegative colors only. + assert coloring.ncolors == 4 + + +def test_complete_bipartite_graph_coloring_non_overlapping(): + """Test that edges with the same color don't share vertices.""" + graph = CompleteBipartiteGraph(3, 4) + coloring = graph.edge_coloring() + # Group edges by color + colors = {} + for edge in graph.edges(): + color = coloring.color(edge.vertices) + assert color is not None + edge_vertices = edge.vertices + if color not in colors: + colors[color] = [] + colors[color].append(edge_vertices) + # Check each color group + for color, edge_list in colors.items(): + used_vertices = set() + for vertices in edge_list: + assert not any(v in used_vertices for v in vertices) + used_vertices.update(vertices) + + +def test_complete_bipartite_graph_str(): + """Test string representation.""" + graph = CompleteBipartiteGraph(2, 3) + assert "5 vertices" in str(graph) + assert "6 edges" in str(graph) + + +def test_complete_bipartite_graph_inherits_hypergraph(): + """Test that CompleteBipartiteGraph is a Hypergraph subclass with all methods.""" + from qsharp.magnets.utilities import Hypergraph + + graph = CompleteBipartiteGraph(2, 3) + assert isinstance(graph, Hypergraph) + assert hasattr(graph, "edges") + assert hasattr(graph, "vertices") + coloring = graph.edge_coloring() + assert isinstance(coloring, HypergraphEdgeColoring) + assert hasattr(coloring, "edges_of_color") diff --git a/source/pip/tests/magnets/test_hypergraph.py b/source/pip/tests/magnets/test_hypergraph.py new file mode 100755 index 0000000000..2c28289824 --- /dev/null +++ b/source/pip/tests/magnets/test_hypergraph.py @@ -0,0 +1,434 @@ +# Copyright (c) Microsoft 
Corporation. +# Licensed under the MIT License. + +"""Unit tests for hypergraph data structures.""" + +import pytest + +from qsharp.magnets.utilities import ( + Hyperedge, + Hypergraph, + HypergraphEdgeColoring, +) + + +# Hyperedge tests + + +def test_hyperedge_init_basic(): + """Test basic Hyperedge initialization.""" + edge = Hyperedge([0, 1]) + assert edge.vertices == (0, 1) + + +def test_hyperedge_vertices_sorted(): + """Test that vertices are automatically sorted.""" + edge = Hyperedge([3, 1, 2]) + assert edge.vertices == (1, 2, 3) + + +def test_hyperedge_single_vertex(): + """Test hyperedge with single vertex (self-loop).""" + edge = Hyperedge([5]) + assert edge.vertices == (5,) + assert len(edge.vertices) == 1 + + +def test_hyperedge_multiple_vertices(): + """Test hyperedge with multiple vertices (multi-body interaction).""" + edge = Hyperedge([0, 1, 2, 3]) + assert edge.vertices == (0, 1, 2, 3) + assert len(edge.vertices) == 4 + + +def test_hyperedge_repr(): + """Test string representation.""" + edge = Hyperedge([1, 0]) + assert repr(edge) == "Hyperedge([0, 1])" + + +def test_hyperedge_empty_vertices(): + """Test hyperedge with empty vertex list.""" + edge = Hyperedge([]) + assert edge.vertices == () + assert len(edge.vertices) == 0 + + +def test_hyperedge_duplicate_vertices(): + """Test that duplicate vertices are removed.""" + edge = Hyperedge([1, 2, 2, 1, 3]) + assert edge.vertices == (1, 2, 3) + + +# Hypergraph tests + + +def test_hypergraph_init_basic(): + """Test basic Hypergraph initialization.""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2])] + graph = Hypergraph(edges) + assert graph.nedges == 2 + assert graph.nvertices == 3 + + +def test_hypergraph_empty_graph(): + """Test hypergraph with no edges.""" + graph = Hypergraph([]) + assert graph.nedges == 0 + assert graph.nvertices == 0 + + +def test_hypergraph_nedges(): + """Test edge count.""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3])] + graph = Hypergraph(edges) + assert 
graph.nedges == 3 + + +def test_hypergraph_nvertices(): + """Test vertex count with unique vertices.""" + edges = [Hyperedge([0, 1]), Hyperedge([2, 3])] + graph = Hypergraph(edges) + assert graph.nvertices == 4 + + +def test_hypergraph_nvertices_with_shared_vertices(): + """Test vertex count when edges share vertices.""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([0, 2])] + graph = Hypergraph(edges) + assert graph.nvertices == 3 + + +def test_hypergraph_vertices_iterator(): + """Test vertices iterator returns sorted vertices.""" + edges = [Hyperedge([3, 1]), Hyperedge([0, 2])] + graph = Hypergraph(edges) + vertices = list(graph.vertices()) + assert vertices == [0, 1, 2, 3] + + +def test_hypergraph_edges_iterator(): + """Test edges iterator returns all edges.""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2])] + graph = Hypergraph(edges) + edge_list = list(graph.edges()) + assert len(edge_list) == 2 + + +def test_hypergraph_edges_of_color(): + """Test HypergraphEdgeColoring returns edges with a specific color.""" + edges = [Hyperedge([0, 1]), Hyperedge([2, 3])] + graph = Hypergraph(edges) + coloring = HypergraphEdgeColoring(graph) + coloring.add_edge(edges[0], 0) + coloring.add_edge(edges[1], 0) + edge_list = list(coloring.edges_of_color(0)) + assert len(edge_list) == 2 + + +def test_hypergraph_edge_coloring_method_returns_coloring(): + """Test Hypergraph.edge_coloring returns a HypergraphEdgeColoring.""" + graph = Hypergraph([Hyperedge([0, 1]), Hyperedge([2, 3])]) + coloring = graph.edge_coloring(seed=42) + assert isinstance(coloring, HypergraphEdgeColoring) + + +def test_hypergraph_edge_coloring_stores_supporting_hypergraph(): + """Test HypergraphEdgeColoring keeps a reference to its Hypergraph.""" + graph = Hypergraph([Hyperedge([0, 1])]) + coloring = HypergraphEdgeColoring(graph) + assert coloring.hypergraph is graph + + +def test_hypergraph_edge_coloring_rejects_non_hyperedge(): + """Test add_edge rejects non-Hyperedge values.""" + graph = 
Hypergraph([Hyperedge([0, 1])]) + coloring = HypergraphEdgeColoring(graph) + + with pytest.raises(TypeError, match="edge must be Hyperedge"): + coloring.add_edge((0, 1), 0) + + +def test_hypergraph_edge_coloring_rejects_edge_not_in_hypergraph(): + """Test add_edge rejects Hyperedge values not in the supporting Hypergraph.""" + graph = Hypergraph([Hyperedge([0, 1])]) + coloring = HypergraphEdgeColoring(graph) + + with pytest.raises( + ValueError, match="edge must belong to the supporting Hypergraph" + ): + coloring.add_edge(Hyperedge([1, 2]), 0) + + +def test_hypergraph_edge_coloring_rejects_equivalent_edge_not_in_hypergraph(): + """Test add_edge requires an edge instance from the supporting Hypergraph.""" + edge = Hyperedge([0, 1]) + graph = Hypergraph([edge]) + coloring = HypergraphEdgeColoring(graph) + + with pytest.raises( + ValueError, match="edge must belong to the supporting Hypergraph" + ): + coloring.add_edge(Hyperedge([0, 1]), 0) + + +def test_hypergraph_edge_coloring_color_matches_equivalent_vertices(): + """Test color lookup uses edge vertex tuples as keys.""" + edge = Hyperedge([0, 1]) + graph = Hypergraph([edge]) + coloring = HypergraphEdgeColoring(graph) + + coloring.add_edge(edge, 3) + assert coloring.color((0, 1)) == 3 + + +def test_hypergraph_edge_coloring_rejects_negative_color_for_nontrivial_edge(): + """Test add_edge raises ValueError for negative color on nontrivial edges.""" + edge = Hyperedge([0, 1]) + graph = Hypergraph([edge]) + coloring = HypergraphEdgeColoring(graph) + + with pytest.raises( + ValueError, match="Color index must be nonnegative for multi-vertex edges" + ): + coloring.add_edge(edge, -1) + + +def test_hypergraph_edge_coloring_rejects_conflicting_edge(): + """Test add_edge raises RuntimeError when same-color edges share a vertex.""" + edge1 = Hyperedge([0, 1]) + edge2 = Hyperedge([1, 2]) + graph = Hypergraph([edge1, edge2]) + coloring = HypergraphEdgeColoring(graph) + + coloring.add_edge(edge1, 0) + + with pytest.raises( + 
RuntimeError, match="Edge conflicts with existing edge of same color" + ): + coloring.add_edge(edge2, 0) + + +def test_hypergraph_add_edge(): + """Test adding an edge to the hypergraph.""" + graph = Hypergraph([]) + graph.add_edge(Hyperedge([0, 1])) + assert graph.nedges == 1 + assert graph.nvertices == 2 + + +def test_hypergraph_add_edge_with_color(): + """Test assigning colors via HypergraphEdgeColoring.""" + graph = Hypergraph([Hyperedge([0, 1])]) + edge = Hyperedge([2, 3]) + graph.add_edge(edge) + coloring = HypergraphEdgeColoring(graph) + coloring.add_edge(edge, color=1) + assert graph.nedges == 2 + assert coloring.color(edge.vertices) == 1 + + +def test_hypergraph_color_default(): + """Test that Hypergraph has no built-in color mapping.""" + graph = Hypergraph([Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3])]) + assert not hasattr(graph, "color") + + +def test_hypergraph_str(): + """Test string representation.""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3])] + graph = Hypergraph(edges) + expected = "Hypergraph with 4 vertices and 3 edges." 
+ assert str(graph) == expected + + +def test_hypergraph_repr(): + """Test repr representation.""" + edges = [Hyperedge([0, 1])] + graph = Hypergraph(edges) + result = repr(graph) + assert "Hypergraph" in result + assert "Hyperedge" in result + + +def test_hypergraph_single_vertex_edges(): + """Test hypergraph with self-loop edges.""" + edges = [Hyperedge([0]), Hyperedge([1]), Hyperedge([2])] + graph = Hypergraph(edges) + assert graph.nedges == 3 + assert graph.nvertices == 3 + + +def test_hypergraph_mixed_edge_sizes(): + """Test hypergraph with edges of different sizes.""" + edges = [ + Hyperedge([0]), # 1 vertex (self-loop) + Hyperedge([1, 2]), # 2 vertices (pair) + Hyperedge([3, 4, 5]), # 3 vertices (triple) + ] + graph = Hypergraph(edges) + assert graph.nedges == 3 + assert graph.nvertices == 6 + + +def test_hypergraph_non_contiguous_vertices(): + """Test hypergraph with non-contiguous vertex indices.""" + edges = [Hyperedge([0, 10]), Hyperedge([5, 20])] + graph = Hypergraph(edges) + assert graph.nvertices == 4 + vertices = list(graph.vertices()) + assert vertices == [0, 5, 10, 20] + + +# greedyEdgeColoring tests + + +def test_greedy_edge_coloring_empty(): + """Test greedy edge coloring on empty hypergraph.""" + graph = Hypergraph([]) + colored = graph.edge_coloring() + assert isinstance(colored, HypergraphEdgeColoring) + assert len(list(colored.colors())) == 0 + assert colored.ncolors == 0 + + +def test_greedy_edge_coloring_single_edge(): + """Test greedy edge coloring with a single edge.""" + edge = Hyperedge([0, 1]) + graph = Hypergraph([edge]) + colored = graph.edge_coloring(seed=42) + assert colored.color(edge.vertices) == 0 + assert colored.ncolors == 1 + + +def test_greedy_edge_coloring_non_overlapping(): + """Test coloring of non-overlapping edges (can share color).""" + edges = [Hyperedge([0, 1]), Hyperedge([2, 3])] + graph = Hypergraph(edges) + colored = graph.edge_coloring(seed=42) + # Non-overlapping edges can be in the same color + assert 
colored.color(edges[0].vertices) is not None + assert colored.color(edges[1].vertices) is not None + assert colored.ncolors == 1 + + +def test_greedy_edge_coloring_overlapping(): + """Test coloring of overlapping edges (need different colors).""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2])] + graph = Hypergraph(edges) + colored = graph.edge_coloring(seed=42) + # Overlapping edges need different colors + assert colored.color(edges[0].vertices) is not None + assert colored.color(edges[1].vertices) is not None + assert colored.ncolors == 2 + + +def test_greedy_edge_coloring_triangle(): + """Test coloring of a triangle (3 edges, all pairwise overlapping).""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([0, 2])] + graph = Hypergraph(edges) + colored = graph.edge_coloring(seed=42) + # All edges share vertices pairwise, so need 3 colors + assert colored.color(edges[0].vertices) is not None + assert colored.color(edges[1].vertices) is not None + assert colored.color(edges[2].vertices) is not None + assert colored.ncolors == 3 + + +def test_greedy_edge_coloring_validity(): + """Test that coloring is valid (no two edges with same color share a vertex).""" + edges = [ + Hyperedge([0, 1]), + Hyperedge([1, 2]), + Hyperedge([2, 3]), + Hyperedge([3, 4]), + Hyperedge([0, 4]), + ] + graph = Hypergraph(edges) + colored = graph.edge_coloring(seed=42) + + # Group edges by color + colors = {} + for edge in edges: + color = colored.color(edge.vertices) + assert color is not None + if color not in colors: + colors[color] = [] + colors[color].append(edge.vertices) + + # Verify each color group has no overlapping edges + for color, edge_list in colors.items(): + used_vertices = set() + for vertices in edge_list: + # No vertex should already be used in this color + assert not any(v in used_vertices for v in vertices) + used_vertices.update(vertices) + + +def test_greedy_edge_coloring_all_edges_colored(): + """Test that all edges are assigned a color.""" + edges = 
[Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3])] + graph = Hypergraph(edges) + colored = graph.edge_coloring(seed=42) + + # All edges should have a color assigned + assert colored.color(edges[0].vertices) is not None + assert colored.color(edges[1].vertices) is not None + assert colored.color(edges[2].vertices) is not None + + +def test_greedy_edge_coloring_reproducible_with_seed(): + """Test that coloring is reproducible with the same seed.""" + edges = [Hyperedge([0, 1]), Hyperedge([1, 2]), Hyperedge([2, 3]), Hyperedge([0, 3])] + graph = Hypergraph(edges) + + colored1 = graph.edge_coloring(seed=123) + colored2 = graph.edge_coloring(seed=123) + + color_map_1 = {edge.vertices: colored1.color(edge.vertices) for edge in edges} + color_map_2 = {edge.vertices: colored2.color(edge.vertices) for edge in edges} + assert color_map_1 == color_map_2 + + +def test_greedy_edge_coloring_multiple_trials(): + """Test that multiple trials can find better colorings.""" + edges = [ + Hyperedge([0, 1]), + Hyperedge([1, 2]), + Hyperedge([2, 3]), + Hyperedge([3, 0]), + ] + graph = Hypergraph(edges) + colored = graph.edge_coloring(seed=42, trials=10) + # A cycle of 4 edges can be 2-colored + assert colored.ncolors <= 3 # Greedy may not always find optimal + + +def test_greedy_edge_coloring_hyperedges(): + """Test coloring with multi-vertex hyperedges.""" + edges = [ + Hyperedge([0, 1, 2]), + Hyperedge([2, 3, 4]), + Hyperedge([5, 6, 7]), + ] + graph = Hypergraph(edges) + colored = graph.edge_coloring(seed=42) + + # First two share vertex 2, third is independent + assert colored.color(edges[0].vertices) is not None + assert colored.color(edges[1].vertices) is not None + assert colored.color(edges[2].vertices) is not None + assert colored.ncolors >= 2 + + +def test_greedy_edge_coloring_self_loops(): + """Test coloring with self-loop edges.""" + edges = [Hyperedge([0]), Hyperedge([1]), Hyperedge([2])] + graph = Hypergraph(edges) + colored = graph.edge_coloring(seed=42) + + # 
Self-loops use the special -1 color and do not contribute to ncolors. + assert colored.color(edges[0].vertices) == -1 + assert colored.color(edges[1].vertices) == -1 + assert colored.color(edges[2].vertices) == -1 + assert colored.ncolors == 0 diff --git a/source/pip/tests/magnets/test_lattice1d.py b/source/pip/tests/magnets/test_lattice1d.py new file mode 100644 index 0000000000..8117ee3617 --- /dev/null +++ b/source/pip/tests/magnets/test_lattice1d.py @@ -0,0 +1,268 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Unit tests for 1D lattice data structures.""" + +from qsharp.magnets.geometry.lattice1d import Chain1D, Ring1D +from qsharp.magnets.utilities import Hyperedge, HypergraphEdgeColoring + + +def _vertex_color_map(graph) -> dict[tuple[int, ...], int | None]: + coloring = graph.edge_coloring() + return {edge.vertices: coloring.color(edge.vertices) for edge in graph.edges()} + + +# Chain1D tests + + +def test_chain1d_init_basic(): + """Test basic Chain1D initialization.""" + chain = Chain1D(4) + assert chain.nvertices == 4 + assert chain.nedges == 3 + assert chain.length == 4 + + +def test_chain1d_single_vertex(): + """Test Chain1D with a single vertex (no edges).""" + chain = Chain1D(1) + assert chain.nvertices == 0 + assert chain.nedges == 0 + assert chain.length == 1 + + +def test_chain1d_two_vertices(): + """Test Chain1D with two vertices (one edge).""" + chain = Chain1D(2) + assert chain.nvertices == 2 + assert chain.nedges == 1 + + +def test_chain1d_edges(): + """Test that Chain1D creates correct nearest-neighbor edges.""" + chain = Chain1D(4) + edge_vertices = {edge.vertices for edge in chain.edges()} + assert edge_vertices == {(0, 1), (1, 2), (2, 3)} + + +def test_chain1d_vertices(): + """Test that Chain1D vertices are correct.""" + chain = Chain1D(5) + vertices = list(chain.vertices()) + assert vertices == [0, 1, 2, 3, 4] + + +def test_chain1d_with_self_loops(): + """Test Chain1D with self-loops enabled.""" + chain 
= Chain1D(4, self_loops=True) + assert chain.nvertices == 4 + # 4 self-loops + 3 nearest-neighbor edges = 7 + assert chain.nedges == 7 + + +def test_chain1d_self_loops_edges(): + """Test that self-loop edges are created correctly.""" + chain = Chain1D(3, self_loops=True) + edge_vertices = {edge.vertices for edge in chain.edges()} + assert edge_vertices == {(0,), (1,), (2,), (0, 1), (1, 2)} + + +def test_chain1d_coloring_without_self_loops(): + """Test edge coloring without self-loops.""" + chain = Chain1D(5) + color = _vertex_color_map(chain) + # Even edges (0-1, 2-3) should have color 0 + assert color[(0, 1)] == 0 + assert color[(2, 3)] == 0 + # Odd edges (1-2, 3-4) should have color 1 + assert color[(1, 2)] == 1 + assert color[(3, 4)] == 1 + + +def test_chain1d_coloring_with_self_loops(): + """Test edge coloring with self-loops.""" + chain = Chain1D(4, self_loops=True) + color = _vertex_color_map(chain) + # Self-loops should have color -1 + assert color[(0,)] == -1 + assert color[(1,)] == -1 + assert color[(2,)] == -1 + assert color[(3,)] == -1 + # Even edges should have color 0, odd edges color 1 + assert color[(0, 1)] == 0 + assert color[(1, 2)] == 1 + assert color[(2, 3)] == 0 + + +def test_chain1d_coloring_non_overlapping(): + """Test that edges with the same color don't share vertices.""" + chain = Chain1D(6) + coloring = chain.edge_coloring() + # Group edges by color + colors = {} + for edge in chain.edges(): + color = coloring.color(edge.vertices) + assert color is not None + edge_vertices = edge.vertices + if color not in colors: + colors[color] = [] + colors[color].append(edge_vertices) + # Check each color group + for color, edge_list in colors.items(): + used_vertices = set() + for vertices in edge_list: + assert not any(v in used_vertices for v in vertices) + used_vertices.update(vertices) + + +def test_chain1d_str(): + """Test string representation.""" + chain = Chain1D(4) + assert "4 vertices" in str(chain) + assert "3 edges" in str(chain) + + +# 
Ring1D tests + + +def test_ring1d_init_basic(): + """Test basic Ring1D initialization.""" + ring = Ring1D(4) + assert ring.nvertices == 4 + assert ring.nedges == 4 + assert ring.length == 4 + + +def test_ring1d_two_vertices(): + """Test Ring1D with two vertices (two edges, same pair).""" + ring = Ring1D(2) + assert ring.nvertices == 2 + # Edge 0-1 and edge 1-0 (wrapping), but both are [0,1] after sorting + assert ring.nedges == 2 + + +def test_ring1d_three_vertices(): + """Test Ring1D with three vertices (triangle).""" + ring = Ring1D(3) + assert ring.nvertices == 3 + assert ring.nedges == 3 + + +def test_ring1d_edges(): + """Test that Ring1D creates correct edges including wrap-around.""" + ring = Ring1D(4) + edge_vertices = {edge.vertices for edge in ring.edges()} + assert edge_vertices == {(0, 1), (1, 2), (2, 3), (0, 3)} + + +def test_ring1d_vertices(): + """Test that Ring1D vertices are correct.""" + ring = Ring1D(5) + vertices = list(ring.vertices()) + assert vertices == [0, 1, 2, 3, 4] + + +def test_ring1d_with_self_loops(): + """Test Ring1D with self-loops enabled.""" + ring = Ring1D(4, self_loops=True) + assert ring.nvertices == 4 + # 4 self-loops + 4 nearest-neighbor edges = 8 + assert ring.nedges == 8 + + +def test_ring1d_self_loops_edges(): + """Test that self-loop edges are created correctly.""" + ring = Ring1D(3, self_loops=True) + edge_vertices = {edge.vertices for edge in ring.edges()} + assert edge_vertices == {(0,), (1,), (2,), (0, 1), (1, 2), (0, 2)} + + +def test_ring1d_coloring_without_self_loops(): + """Test edge coloring without self-loops.""" + ring = Ring1D(4) + color = _vertex_color_map(ring) + # Even edges should have color 0, odd edges color 1 + assert color[(0, 1)] == 0 + assert color[(1, 2)] == 1 + assert color[(2, 3)] == 0 + assert color[(0, 3)] == 1 # Wrap-around edge + + +def test_ring1d_coloring_with_self_loops(): + """Test edge coloring with self-loops.""" + ring = Ring1D(4, self_loops=True) + color = _vertex_color_map(ring) + # 
Self-loops should have color -1 + assert color[(0,)] == -1 + assert color[(1,)] == -1 + assert color[(2,)] == -1 + assert color[(3,)] == -1 + # Even edges should have color 0, odd edges color 1 + assert color[(0, 1)] == 0 + assert color[(1, 2)] == 1 + assert color[(2, 3)] == 0 + assert color[(0, 3)] == 1 + + +def test_ring1d_coloring_non_overlapping(): + """Test that edges with the same color don't share vertices.""" + ring = Ring1D(6) + coloring = ring.edge_coloring() + # Group edges by color + colors = {} + for edge in ring.edges(): + color = coloring.color(edge.vertices) + assert color is not None + edge_vertices = edge.vertices + if color not in colors: + colors[color] = [] + colors[color].append(edge_vertices) + # Check each color group + for color, edge_list in colors.items(): + used_vertices = set() + for vertices in edge_list: + assert not any(v in used_vertices for v in vertices) + used_vertices.update(vertices) + + +def test_ring1d_str(): + """Test string representation.""" + ring = Ring1D(4) + assert "4 vertices" in str(ring) + assert "4 edges" in str(ring) + + +def test_ring1d_vs_chain1d_edge_count(): + """Test that ring has one more edge than chain of same length.""" + for length in range(2, 10): + chain = Chain1D(length) + ring = Ring1D(length) + assert ring.nedges == chain.nedges + 1 + + +def test_chain1d_inherits_hypergraph(): + """Test that Chain1D is a Hypergraph subclass with all methods.""" + from qsharp.magnets.utilities import Hypergraph + + chain = Chain1D(4) + assert isinstance(chain, Hypergraph) + # Test inherited methods work + assert hasattr(chain, "edges") + assert hasattr(chain, "vertices") + coloring = chain.edge_coloring() + assert isinstance(coloring, HypergraphEdgeColoring) + assert hasattr(coloring, "edges_of_color") + + +def test_ring1d_inherits_hypergraph(): + """Test that Ring1D is a Hypergraph subclass with all methods.""" + from qsharp.magnets.utilities import Hypergraph + + ring = Ring1D(4) + assert isinstance(ring, 
Hypergraph) + # Test inherited methods work + assert hasattr(ring, "edges") + assert hasattr(ring, "vertices") + coloring = ring.edge_coloring() + assert isinstance(coloring, HypergraphEdgeColoring) + assert hasattr(coloring, "edges_of_color") diff --git a/source/pip/tests/magnets/test_lattice2d.py b/source/pip/tests/magnets/test_lattice2d.py new file mode 100644 index 0000000000..6a1291e9b4 --- /dev/null +++ b/source/pip/tests/magnets/test_lattice2d.py @@ -0,0 +1,301 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Unit tests for 2D lattice data structures.""" + +from qsharp.magnets.geometry.lattice2d import Patch2D, Torus2D +from qsharp.magnets.utilities import HypergraphEdgeColoring + + +def _vertex_color_map(graph) -> dict[tuple[int, ...], int | None]: + coloring = graph.edge_coloring() + return {edge.vertices: coloring.color(edge.vertices) for edge in graph.edges()} + + +# Patch2D tests + + +def test_patch2d_init_basic(): + """Test basic Patch2D initialization.""" + patch = Patch2D(3, 2) + assert patch.nvertices == 6 + # 2 horizontal edges per row * 2 rows + 3 vertical edges per column * 1 = 7 + assert patch.nedges == 7 + assert patch.width == 3 + assert patch.height == 2 + + +def test_patch2d_single_vertex(): + """Test Patch2D with a single vertex (no edges).""" + patch = Patch2D(1, 1) + assert patch.nvertices == 0 + assert patch.nedges == 0 + assert patch.width == 1 + assert patch.height == 1 + + +def test_patch2d_single_row(): + """Test Patch2D with a single row (like Chain1D).""" + patch = Patch2D(4, 1) + assert patch.nvertices == 4 + assert patch.nedges == 3 # Only horizontal edges + + +def test_patch2d_single_column(): + """Test Patch2D with a single column.""" + patch = Patch2D(1, 4) + assert patch.nvertices == 4 + assert patch.nedges == 3 # Only vertical edges + + +def test_patch2d_square(): + """Test Patch2D with a square lattice.""" + patch = Patch2D(3, 3) + assert patch.nvertices == 9 + # 2 horizontal * 3 rows + 3 
vertical * 2 = 12 + assert patch.nedges == 12 + + +def test_patch2d_edges(): + """Test that Patch2D creates correct nearest-neighbor edges.""" + patch = Patch2D(2, 2) + edges = list(patch.edges()) + # Should have 4 edges: 2 horizontal + 2 vertical + assert len(edges) == 4 + # Vertices: 0=(0,0), 1=(1,0), 2=(0,1), 3=(1,1) + # Horizontal: [0,1], [2,3] + # Vertical: [0,2], [1,3] + edge_sets = [set(e.vertices) for e in edges] + assert {0, 1} in edge_sets + assert {2, 3} in edge_sets + assert {0, 2} in edge_sets + assert {1, 3} in edge_sets + + +def test_patch2d_vertices(): + """Test that Patch2D vertices are correct.""" + patch = Patch2D(3, 2) + vertices = list(patch.vertices()) + assert vertices == [0, 1, 2, 3, 4, 5] + + +def test_patch2d_with_self_loops(): + """Test Patch2D with self-loops enabled.""" + patch = Patch2D(3, 2, self_loops=True) + assert patch.nvertices == 6 + # 6 self-loops + 7 nearest-neighbor edges = 13 + assert patch.nedges == 13 + + +def test_patch2d_self_loops_edges(): + """Test that self-loop edges are created correctly.""" + patch = Patch2D(2, 2, self_loops=True) + edge_vertices = {edge.vertices for edge in patch.edges()} + assert {(0,), (1,), (2,), (3,)}.issubset(edge_vertices) + + +def test_patch2d_coloring_without_self_loops(): + """Test edge coloring without self-loops.""" + patch = Patch2D(4, 4) + coloring = patch.edge_coloring() + # Should have 4 colors: horizontal even/odd (0,1), vertical even/odd (2,3) + assert coloring.ncolors == 4 + + +def test_patch2d_coloring_with_self_loops(): + """Test edge coloring with self-loops.""" + patch = Patch2D(3, 3, self_loops=True) + coloring = patch.edge_coloring() + # Self-loops are -1 and do not contribute to ncolors. 
+ assert coloring.ncolors == 4 + + +def test_patch2d_coloring_non_overlapping(): + """Test that edges with the same color don't share vertices.""" + patch = Patch2D(4, 4) + coloring = patch.edge_coloring() + # Group edges by color + colors = {} + for edge in patch.edges(): + color = coloring.color(edge.vertices) + assert color is not None + edge_vertices = edge.vertices + if color not in colors: + colors[color] = [] + colors[color].append(edge_vertices) + # Check each color group + for color, edge_list in colors.items(): + used_vertices = set() + for vertices in edge_list: + assert not any(v in used_vertices for v in vertices) + used_vertices.update(vertices) + + +def test_patch2d_str(): + """Test string representation.""" + patch = Patch2D(3, 2) + assert str(patch) == "3x2 lattice patch with 6 vertices and 7 edges" + + +# Torus2D tests + + +def test_torus2d_init_basic(): + """Test basic Torus2D initialization.""" + torus = Torus2D(3, 2) + assert torus.nvertices == 6 + # 3 horizontal edges per row * 2 rows + 3 vertical edges per column * 2 = 12 + assert torus.nedges == 12 + assert torus.width == 3 + assert torus.height == 2 + + +def test_torus2d_single_vertex(): + """Test Torus2D with a single vertex (self-edge in both directions).""" + torus = Torus2D(1, 1) + assert torus.nvertices == 1 + # One horizontal wrap + one vertical wrap, both connect vertex 0 to itself + assert torus.nedges == 2 + + +def test_torus2d_single_row(): + """Test Torus2D with a single row (like Ring1D + vertical wraps).""" + torus = Torus2D(4, 1) + assert torus.nvertices == 4 + # 4 horizontal + 4 vertical wraps + assert torus.nedges == 8 + + +def test_torus2d_single_column(): + """Test Torus2D with a single column.""" + torus = Torus2D(1, 4) + assert torus.nvertices == 4 + # 4 horizontal wraps + 4 vertical + assert torus.nedges == 8 + + +def test_torus2d_square(): + """Test Torus2D with a square lattice.""" + torus = Torus2D(3, 3) + assert torus.nvertices == 9 + # 3 horizontal * 3 rows + 3 
vertical * 3 columns = 18 + assert torus.nedges == 18 + + +def test_torus2d_edges(): + """Test that Torus2D creates correct edges including wrap-around.""" + torus = Torus2D(2, 2) + edges = list(torus.edges()) + # Should have 8 edges: 4 horizontal + 4 vertical + assert len(edges) == 8 + # Vertices: 0=(0,0), 1=(1,0), 2=(0,1), 3=(1,1) + edge_sets = [set(e.vertices) for e in edges] + # Horizontal edges (including wraps) + assert {0, 1} in edge_sets # (0,0)-(1,0) + assert {2, 3} in edge_sets # (0,1)-(1,1) + # Vertical edges (including wraps) + assert {0, 2} in edge_sets # (0,0)-(0,1) + assert {1, 3} in edge_sets # (1,0)-(1,1) + + +def test_torus2d_vertices(): + """Test that Torus2D vertices are correct.""" + torus = Torus2D(3, 2) + vertices = list(torus.vertices()) + assert vertices == [0, 1, 2, 3, 4, 5] + + +def test_torus2d_with_self_loops(): + """Test Torus2D with self-loops enabled.""" + torus = Torus2D(3, 2, self_loops=True) + assert torus.nvertices == 6 + # 6 self-loops + 12 nearest-neighbor edges = 18 + assert torus.nedges == 18 + + +def test_torus2d_self_loops_edges(): + """Test that self-loop edges are created correctly.""" + torus = Torus2D(2, 2, self_loops=True) + edge_vertices = {edge.vertices for edge in torus.edges()} + assert {(0,), (1,), (2,), (3,)}.issubset(edge_vertices) + + +def test_torus2d_coloring_without_self_loops(): + """Test edge coloring without self-loops.""" + torus = Torus2D(4, 4) + coloring = torus.edge_coloring() + # Should have 4 colors: horizontal even/odd (0,1), vertical even/odd (2,3) + assert coloring.ncolors == 4 + + +def test_torus2d_coloring_with_self_loops(): + """Test edge coloring with self-loops.""" + torus = Torus2D(3, 3, self_loops=True) + coloring = torus.edge_coloring() + # Odd periodic dimensions require dedicated wrap colors. 
+ assert coloring.ncolors == 6 + + +def test_torus2d_coloring_non_overlapping(): + """Test that edges with the same color don't share vertices.""" + torus = Torus2D(4, 4) + coloring = torus.edge_coloring() + # Group edges by color + colors = {} + for edge in torus.edges(): + color = coloring.color(edge.vertices) + assert color is not None + edge_vertices = edge.vertices + if color not in colors: + colors[color] = [] + colors[color].append(edge_vertices) + # Check each color group + for color, edge_list in colors.items(): + used_vertices = set() + for vertices in edge_list: + assert not any(v in used_vertices for v in vertices) + used_vertices.update(vertices) + + +def test_torus2d_str(): + """Test string representation.""" + torus = Torus2D(3, 2) + assert str(torus) == "3x2 lattice torus with 6 vertices and 12 edges" + + +def test_torus2d_vs_patch2d_edge_count(): + """Test that torus has more edges than patch of same dimensions.""" + for width in range(2, 5): + for height in range(2, 5): + patch = Patch2D(width, height) + torus = Torus2D(width, height) + # Torus has width + height extra edges (wrapping) + assert torus.nedges == patch.nedges + width + height + + +def test_patch2d_inherits_hypergraph(): + """Test that Patch2D is a Hypergraph subclass with all methods.""" + from qsharp.magnets.utilities import Hypergraph + + patch = Patch2D(3, 3) + assert isinstance(patch, Hypergraph) + # Test inherited methods work + assert hasattr(patch, "edges") + assert hasattr(patch, "vertices") + coloring = patch.edge_coloring() + assert isinstance(coloring, HypergraphEdgeColoring) + assert hasattr(coloring, "edges_of_color") + + +def test_torus2d_inherits_hypergraph(): + """Test that Torus2D is a Hypergraph subclass with all methods.""" + from qsharp.magnets.utilities import Hypergraph + + torus = Torus2D(3, 3) + assert isinstance(torus, Hypergraph) + # Test inherited methods work + assert hasattr(torus, "edges") + assert hasattr(torus, "vertices") + coloring = 
torus.edge_coloring() + assert isinstance(coloring, HypergraphEdgeColoring) + assert hasattr(coloring, "edges_of_color") diff --git a/source/pip/tests/magnets/test_model.py b/source/pip/tests/magnets/test_model.py new file mode 100755 index 0000000000..31913b62ad --- /dev/null +++ b/source/pip/tests/magnets/test_model.py @@ -0,0 +1,254 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pyright: reportPrivateImportUsage=false + +"""Unit tests for the Model classes.""" + +from __future__ import annotations + +import pytest + +from qsharp.magnets.models import IsingModel, Model +from qsharp.magnets.models.model import HeisenbergModel +from qsharp.magnets.utilities import ( + Hyperedge, + Hypergraph, + PauliString, +) + + +def make_chain(length: int) -> Hypergraph: + edges = [Hyperedge([i, i + 1]) for i in range(length - 1)] + return Hypergraph(edges) + + +def make_chain_with_vertices(length: int) -> Hypergraph: + edges = [Hyperedge([i, i + 1]) for i in range(length - 1)] + edges.extend([Hyperedge([i]) for i in range(length)]) + return Hypergraph(edges) + + +class CountingColoringHypergraph(Hypergraph): + def __init__(self, edges: list[Hyperedge]): + super().__init__(edges) + self.edge_coloring_calls = 0 + + def edge_coloring(self): + self.edge_coloring_calls += 1 + return super().edge_coloring() + + +def test_model_init_basic(): + geometry = Hypergraph([Hyperedge([0, 1]), Hyperedge([1, 2])]) + model = Model(geometry) + assert model.geometry is geometry + assert model.nqubits == 3 + assert model.nterms == 0 + assert model._ops == [] + assert model._terms == {} + + +def test_model_init_empty_geometry(): + model = Model(Hypergraph([])) + assert model.nqubits == 0 + assert model.nterms == 0 + + +def test_model_add_interaction_basic(): + edge = Hyperedge([0, 1]) + model = Model(Hypergraph([edge])) + model.add_interaction(edge, "ZZ", -1.5) + + assert len(model._ops) == 1 + assert model._ops[0] == PauliString.from_qubits((0, 1), "ZZ", -1.5) + 
assert model.nterms == 0 + + +def test_model_add_interaction_with_term(): + edge = Hyperedge([0, 1]) + model = Model(Hypergraph([edge])) + model.add_interaction(edge, "ZZ", -2.0, term=3) + + assert model.nterms == 1 + assert 3 in model._terms + assert model._terms[3] == {0: [0]} + + +def test_model_term_color_query_methods(): + edge = Hyperedge([0, 1]) + model = Model(Hypergraph([edge])) + model.add_interaction(edge, "ZZ", -1.0, term=1, color=2) + model.add_interaction(edge, "XX", -0.5, term=1, color=2) + model.add_interaction(edge, "YY", -0.25, term=1, color=3) + + assert model.terms == [1] + assert model.ncolors(1) == 2 + assert set(model.colors(1)) == {2, 3} + assert model.nops(1, 2) == 2 + assert model.nops(1, 3) == 1 + assert model.ops(1, 2) == [ + PauliString.from_qubits((0, 1), "ZZ", -1.0), + PauliString.from_qubits((0, 1), "XX", -0.5), + ] + assert model.ops(1, 3) == [PauliString.from_qubits((0, 1), "YY", -0.25)] + + +def test_model_query_methods_raise_for_missing_term_and_color(): + edge = Hyperedge([0, 1]) + model = Model(Hypergraph([edge])) + model.add_interaction(edge, "ZZ", -1.0, term=0, color=0) + + with pytest.raises(ValueError, match="Term 99 does not exist in the model"): + model.ncolors(99) + with pytest.raises(ValueError, match="Term 99 does not exist in the model"): + model.colors(99) + with pytest.raises(ValueError, match="Term 99 does not exist in the model"): + model.nops(99, 0) + with pytest.raises(ValueError, match="Term 99 does not exist in the model"): + model.ops(99, 0) + + with pytest.raises(ValueError, match="Color 7 does not exist in term 0"): + model.nops(0, 7) + with pytest.raises(ValueError, match="Color 7 does not exist in term 0"): + model.ops(0, 7) + + +def test_model_add_interaction_rejects_edge_not_in_geometry(): + model = Model(Hypergraph([Hyperedge([0, 1])])) + with pytest.raises(ValueError, match="Edge is not part of the model geometry"): + model.add_interaction(Hyperedge([1, 2]), "ZZ", -1.0) + + +def 
test_model_str_and_repr(): + model = Model(make_chain(3)) + assert "0 terms" in str(model) + assert "3 qubits" in str(model) + assert repr(model) == str(model) + + +def test_ising_model_basic(): + geometry = make_chain_with_vertices(3) + model = IsingModel(geometry, h=1.0, J=1.0) + + assert isinstance(model, Model) + assert model.geometry is geometry + assert model.nterms == 2 + assert set(model._terms.keys()) == {0, 1} + + +def test_ising_model_str_and_repr(): + geometry = make_chain_with_vertices(3) + model = IsingModel(geometry, h=0.5, J=2.0) + + assert str(model) == "Ising model with 2 terms on 3 qubits (h=0.5, J=2.0)." + assert repr(model) == "IsingModel(nqubits=3, nterms=2, h=0.5, J=2.0)" + + +def test_heisenberg_model_str_and_repr(): + geometry = make_chain(3) + model = HeisenbergModel(geometry, J=1.5) + + assert str(model) == "Heisenberg model with 3 terms on 3 qubits (J=1.5)." + assert repr(model) == "HeisenbergModel(nqubits=3, nterms=3, J=1.5)" + + +def test_ising_model_coloring_matches_geometry_coloring(): + geometry = make_chain_with_vertices(4) + model = IsingModel(geometry, h=1.0, J=1.0) + geometry_coloring = geometry.edge_coloring() + + for color, indices in model._terms[1].items(): + for index in indices: + op = model._ops[index] + edge_vertices = tuple(sorted(op.qubits)) + assert geometry_coloring.color(edge_vertices) == color + + +def test_ising_model_initialization_calls_geometry_edge_coloring_once(): + geometry = CountingColoringHypergraph( + [ + Hyperedge([0, 1]), + Hyperedge([1, 2]), + Hyperedge([0]), + Hyperedge([1]), + Hyperedge([2]), + ] + ) + + model = IsingModel(geometry, h=1.0, J=1.0) + + assert isinstance(model, IsingModel) + assert geometry.edge_coloring_calls == 1 + + +def test_ising_model_coefficients_and_paulis(): + geometry = make_chain_with_vertices(3) + model = IsingModel(geometry, h=0.5, J=2.0) + + ops_by_qubits = {tuple(sorted(op.qubits)): op for op in model._ops} + + assert ops_by_qubits[(0, 1)] == PauliString.from_qubits((0, 
1), "ZZ", -2.0) + assert ops_by_qubits[(1, 2)] == PauliString.from_qubits((1, 2), "ZZ", -2.0) + assert ops_by_qubits[(0,)] == PauliString.from_qubits((0,), "X", -0.5) + assert ops_by_qubits[(1,)] == PauliString.from_qubits((1,), "X", -0.5) + assert ops_by_qubits[(2,)] == PauliString.from_qubits((2,), "X", -0.5) + + +def test_ising_model_term_grouping_indices(): + geometry = make_chain_with_vertices(4) + model = IsingModel(geometry, h=1.0, J=1.0) + + assert set(model._terms.keys()) == {0, 1} + assert all( + len(model._ops[index].qubits) == 1 + for indices in model._terms[0].values() + for index in indices + ) + assert all( + len(model._ops[index].qubits) == 2 + for indices in model._terms[1].values() + for index in indices + ) + + +def test_heisenberg_model_basic(): + geometry = make_chain(3) + model = HeisenbergModel(geometry, J=1.0) + + assert isinstance(model, Model) + assert model.geometry is geometry + assert model.nterms == 3 + assert set(model._terms.keys()) == {0, 1, 2} + + +def test_heisenberg_model_coefficients_and_paulis(): + geometry = make_chain(3) + model = HeisenbergModel(geometry, J=2.5) + + expected = [ + PauliString.from_qubits((0, 1), "XX", -2.5), + PauliString.from_qubits((1, 2), "XX", -2.5), + PauliString.from_qubits((0, 1), "YY", -2.5), + PauliString.from_qubits((1, 2), "YY", -2.5), + PauliString.from_qubits((0, 1), "ZZ", -2.5), + PauliString.from_qubits((1, 2), "ZZ", -2.5), + ] + for pauli in expected: + assert pauli in model._ops + + +def test_heisenberg_model_term_grouping_colors_and_paulis(): + geometry = make_chain(4) + model = HeisenbergModel(geometry, J=1.0) + + paulis_by_term = {0: "XX", 1: "YY", 2: "ZZ"} + for term, pauli in paulis_by_term.items(): + for color, indices in model._terms[term].items(): + for index in indices: + op = model._ops[index] + expected = PauliString.from_qubits( + tuple(sorted(op.qubits)), pauli, -1.0 + ) + assert op == expected + assert model.coloring.color(tuple(sorted(op.qubits))) == color diff --git 
a/source/pip/tests/magnets/test_pauli.py b/source/pip/tests/magnets/test_pauli.py new file mode 100644 index 0000000000..03d29e8a96 --- /dev/null +++ b/source/pip/tests/magnets/test_pauli.py @@ -0,0 +1,125 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Unit tests for Pauli and PauliString utilities.""" + +import pytest + +cirq = pytest.importorskip("cirq") + +from qsharp.magnets.utilities import Pauli, PauliString, PauliX, PauliY, PauliZ + + +def test_pauli_init_from_int_and_string(): + """Test Pauli initialization from int and case-insensitive string labels.""" + p_i = Pauli(0, qubit=1) + p_x = Pauli("x", qubit=2) + p_z = Pauli(2, qubit=3) + p_y = Pauli("Y", qubit=4) + + assert p_i.op == 0 and p_i.qubit == 1 + assert p_x.op == 1 and p_x.qubit == 2 + assert p_z.op == 2 and p_z.qubit == 3 + assert p_y.op == 3 and p_y.qubit == 4 + + +@pytest.mark.parametrize("value", [-1, 4, 42]) +def test_pauli_invalid_int_raises(value: int): + """Test invalid integer Pauli identifiers raise ValueError.""" + with pytest.raises(ValueError, match="Integer value must be 0-3"): + Pauli(value) + + +def test_pauli_invalid_string_raises(): + """Test invalid string Pauli identifiers raise ValueError.""" + with pytest.raises(ValueError, match="String value must be one of"): + Pauli("A") + + +def test_pauli_invalid_type_raises(): + """Test non-int/non-str Pauli identifiers raise ValueError.""" + with pytest.raises(ValueError, match="Expected int or str"): + Pauli(1.5) + + +def test_pauli_helpers_create_expected_operator(): + """Test PauliX/PauliY/PauliZ helper constructors.""" + assert PauliX(0) == Pauli("X", 0) + assert PauliY(1) == Pauli("Y", 1) + assert PauliZ(2) == Pauli("Z", 2) + + +def test_pauli_cirq_property_returns_operation_on_line_qubit(): + """Test Pauli.cirq returns a Cirq operation on the target qubit.""" + q = cirq.LineQubit(3) + assert Pauli("I", 3).cirq == cirq.I.on(q) + assert Pauli("X", 3).cirq == cirq.X.on(q) + assert Pauli("Y", 3).cirq 
== cirq.Y.on(q) + assert Pauli("Z", 3).cirq == cirq.Z.on(q) + + +def test_pauli_string_init_requires_pauli_instances(): + """Test PauliString initializer validates element types.""" + with pytest.raises(TypeError, match="Expected Pauli instance"): + PauliString([PauliX(0), "Z"]) + + +def test_pauli_string_from_qubits_accepts_string_and_int_values(): + """Test PauliString.from_qubits accepts both string and int identifiers.""" + from_string = PauliString.from_qubits((0, 1, 2), "XZY", coefficient=-1j) + from_ints = PauliString.from_qubits((0, 1, 2), [1, 2, 3], coefficient=-1j) + + assert from_string == from_ints + assert len(from_string) == 3 + assert from_string.qubits == (0, 1, 2) + + +def test_pauli_string_from_qubits_length_mismatch_raises(): + """Test from_qubits raises when qubit/value lengths differ.""" + with pytest.raises(ValueError, match="Length mismatch"): + PauliString.from_qubits((0, 1), "XYZ") + + +def test_pauli_string_sequence_protocol_and_indexing(): + """Test iteration, len, and indexing behavior.""" + ps = PauliString([PauliX(0), PauliZ(2)], coefficient=2.0) + + assert ps.qubits == (0, 2) + assert len(ps) == 2 + assert ps[0] == PauliX(0) + assert list(ps) == [PauliX(0), PauliZ(2)] + + +def test_pauli_string_equality_and_hash_include_coefficient(): + """Test equality/hash depend on Pauli terms and coefficient.""" + p1 = PauliString.from_qubits((0, 1), "XZ", coefficient=1.0) + p2 = PauliString.from_qubits((0, 1), "XZ", coefficient=1.0) + p3 = PauliString.from_qubits((0, 1), "XZ", coefficient=-1.0) + + assert p1 == p2 + assert hash(p1) == hash(p2) + assert p1 != p3 + + +def test_pauli_string_mul_scales_coefficient_and_preserves_terms(): + """Test PauliString.__mul__ returns scaled coefficient with same operators.""" + ps = PauliString.from_qubits((0, 2), "XZ", coefficient=2.0) + + scaled = ps * (-0.25j) + + assert scaled.qubits == ps.qubits + assert list(scaled) == list(ps) + assert scaled.coefficient == -0.5j + assert ps.coefficient == 2.0 + + +def 
test_pauli_string_cirq_property_preserves_terms_and_coefficient(): + """Test PauliString.cirq conversion with coefficient.""" + ps = PauliString.from_qubits((0, 2), "XZ", coefficient=-0.5j) + + expected = cirq.PauliString( + {cirq.LineQubit(0): cirq.X, cirq.LineQubit(2): cirq.Z}, + coefficient=-0.5j, + ) + + assert ps.cirq == expected diff --git a/source/pip/tests/magnets/test_trotter.py b/source/pip/tests/magnets/test_trotter.py new file mode 100644 index 0000000000..2bab4fa3c8 --- /dev/null +++ b/source/pip/tests/magnets/test_trotter.py @@ -0,0 +1,450 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Unit tests for Trotter-Suzuki decomposition classes and factory functions.""" + +import pytest + +from qsharp.magnets.models import Model +from qsharp.magnets.utilities import Hyperedge, Hypergraph, PauliString + +from qsharp.magnets.trotter import ( + TrotterStep, + TrotterExpansion, + strang_splitting, + suzuki_recursion, + yoshida_recursion, + fourth_order_trotter_suzuki, +) + + +def make_two_term_model() -> Model: + edge = Hyperedge([0, 1]) + model = Model(Hypergraph([edge])) + model.add_interaction(edge, "ZZ", -2.0, term=0, color=0) + model.add_interaction(edge, "XX", -0.5, term=1, color=0) + return model + + +# TrotterStep tests + + +def test_trotter_step_empty_init(): + """Test that TrotterStep initializes as empty.""" + trotter = TrotterStep() + assert trotter.nterms == 0 + assert trotter.time_step == 0.0 + assert trotter.order == 0 + assert list(trotter.step()) == [] + + +def test_trotter_step_reduce_combines_consecutive(): + """Test that reduce combines consecutive same-term entries.""" + trotter = TrotterStep() + trotter.terms = [(0.5, 0), (0.5, 0), (0.5, 1)] + trotter.reduce() + assert list(trotter.step()) == [(1.0, 0), (0.5, 1)] + + +def test_trotter_step_reduce_no_change_when_different(): + """Test that reduce does not change non-consecutive same terms.""" + trotter = TrotterStep() + trotter.terms = [(0.5, 0), (0.5, 1), 
(0.5, 0)] + trotter.reduce() + assert list(trotter.step()) == [(0.5, 0), (0.5, 1), (0.5, 0)] + + +def test_trotter_step_reduce_empty(): + """Test that reduce handles empty terms.""" + trotter = TrotterStep() + trotter.reduce() + assert list(trotter.step()) == [] + + +# first-order TrotterStep constructor tests + + +def test_trotter_step_from_explicit_terms_basic(): + """Test basic TrotterStep creation from explicit term indices.""" + trotter = TrotterStep(terms=[0, 1, 2], time_step=0.5) + assert trotter.nterms == 3 + assert trotter.time_step == 0.5 + assert trotter.order == 1 + + +def test_trotter_step_first_order_single_term(): + """Test TrotterStep with a single explicit term.""" + trotter = TrotterStep(terms=[7], time_step=1.0) + result = list(trotter.step()) + assert result == [(1.0, 7)] + + +def test_trotter_step_first_order_multiple_terms(): + """Test TrotterStep with multiple explicit terms.""" + trotter = TrotterStep(terms=[0, 1, 2], time_step=0.5) + result = list(trotter.step()) + assert result == [(0.5, 0), (0.5, 1), (0.5, 2)] + + +def test_trotter_step_first_order_zero_time(): + """Test TrotterStep with zero time.""" + trotter = TrotterStep(terms=[0, 1], time_step=0.0) + result = list(trotter.step()) + assert result == [(0.0, 0), (0.0, 1)] + + +def test_trotter_step_first_order_returns_all_terms(): + """Test that TrotterStep returns all provided term indices in order.""" + terms = [2, 4, 7, 11, 15] + trotter = TrotterStep(terms=terms, time_step=1.0) + result = list(trotter.step()) + assert len(result) == len(terms) + term_indices = [idx for _, idx in result] + assert term_indices == terms + + +def test_trotter_step_first_order_uniform_time(): + """Test that all entries have the same configured time.""" + time = 0.25 + trotter = TrotterStep(terms=[0, 1, 2, 3], time_step=time) + result = list(trotter.step()) + for t, _ in result: + assert t == time + + +def test_trotter_step_first_order_str(): + """Test string representation of TrotterStep.""" + trotter = 
TrotterStep(terms=[0, 1, 2], time_step=0.5) + result = str(trotter) + assert "order" in result.lower() or "1" in result + + +def test_trotter_step_first_order_repr(): + """Test repr representation of TrotterStep.""" + trotter = TrotterStep(terms=[0, 1, 2], time_step=0.5) + assert "TrotterStep" in repr(trotter) + + +# strang_splitting factory tests + + +def test_strang_splitting_basic(): + """Test basic strang_splitting creation.""" + strang = strang_splitting(terms=[0, 1, 2], time=0.5) + assert strang.nterms == 3 + assert strang.time_step == 0.5 + assert strang.order == 2 + + +def test_strang_splitting_single_term(): + """Test strang_splitting with a single term.""" + strang = strang_splitting(terms=[0], time=1.0) + result = list(strang.step()) + # Single term: just full time on term 0 + assert result == [(1.0, 0)] + + +def test_strang_splitting_two_terms(): + """Test strang_splitting with two terms.""" + strang = strang_splitting(terms=[0, 1], time=1.0) + result = list(strang.step()) + # Forward: half on term 0, full on term 1, backward: half on term 0 + assert result == [(0.5, 0), (1.0, 1), (0.5, 0)] + + +def test_strang_splitting_three_terms(): + """Test strang_splitting with three terms (example from docstring).""" + strang = strang_splitting(terms=[0, 1, 2], time=0.5) + result = list(strang.step()) + expected = [(0.25, 0), (0.25, 1), (0.5, 2), (0.25, 1), (0.25, 0)] + assert result == expected + + +def test_strang_splitting_symmetric(): + """Test that strang_splitting produces symmetric sequence.""" + strang = strang_splitting(terms=[0, 1, 2, 3], time=1.0) + result = list(strang.step()) + # Check symmetry: term indices should be palindromic + term_indices = [idx for _, idx in result] + assert term_indices == term_indices[::-1] + + +def test_strang_splitting_time_sum(): + """Test that total time in strang_splitting equals expected value.""" + time = 1.0 + terms = [0, 1, 2] + strang = strang_splitting(terms=terms, time=time) + result = list(strang.step()) + 
total_time = sum(t for t, _ in result) + # Each term appears once with full time equivalent + # (half + half for outer terms, full for middle) + assert abs(total_time - time * len(terms)) < 1e-10 + + +def test_strang_splitting_middle_term_full_time(): + """Test that the middle term gets full time step.""" + strang = strang_splitting(terms=[0, 1, 2, 3, 4], time=2.0) + result = list(strang.step()) + # Middle term (index 4, the last term) should have full time + middle_entries = [(t, idx) for t, idx in result if idx == 4] + assert len(middle_entries) == 1 + assert middle_entries[0][0] == 2.0 + + +def test_strang_splitting_outer_terms_half_time(): + """Test that outer terms get half time steps.""" + strang = strang_splitting(terms=[0, 1, 2, 3], time=2.0) + result = list(strang.step()) + # Term 0 should appear twice with half time each + term_0_entries = [(t, idx) for t, idx in result if idx == 0] + assert len(term_0_entries) == 2 + for t, _ in term_0_entries: + assert t == 1.0 + + +def test_strang_splitting_repr(): + """Test repr representation of strang_splitting result.""" + strang = strang_splitting(terms=[0, 1, 2], time=0.5) + assert "StrangSplitting" in repr(strang) + + +# suzuki_recursion tests + + +def test_suzuki_recursion_from_strang(): + """Test Suzuki recursion applied to Strang splitting produces 4th order.""" + strang = strang_splitting(terms=[0, 1], time=1.0) + suzuki = suzuki_recursion(strang) + assert suzuki.order == 4 + assert suzuki.nterms == 2 + assert suzuki.time_step == 1.0 + + +def test_suzuki_recursion_from_first_order(): + """Test Suzuki recursion applied to first-order Trotter produces 3rd order.""" + trotter = TrotterStep(terms=[0, 1], time_step=1.0) + suzuki = suzuki_recursion(trotter) + assert suzuki.order == 3 + assert suzuki.nterms == 2 + + +def test_suzuki_recursion_preserves_nterms(): + """Test that Suzuki recursion preserves number of terms.""" + base = strang_splitting(terms=[0, 1, 2, 3, 4], time=0.5) + suzuki = suzuki_recursion(base) 
+ assert suzuki.nterms == base.nterms + + +def test_suzuki_recursion_preserves_time_step(): + """Test that Suzuki recursion preserves time step.""" + base = strang_splitting(terms=[0, 1, 2], time=0.75) + suzuki = suzuki_recursion(base) + assert suzuki.time_step == base.time_step + + +def test_suzuki_recursion_repr(): + """Test repr of Suzuki recursion result.""" + base = strang_splitting(terms=[0, 1], time=1.0) + suzuki = suzuki_recursion(base) + assert "SuzukiRecursion" in repr(suzuki) + + +def test_suzuki_recursion_time_weights_sum(): + """Test that time weights in Suzuki recursion sum correctly.""" + base = TrotterStep(terms=[0, 1], time_step=1.0) + suzuki = suzuki_recursion(base) + # The total scaled time should equal the base schedule's total time + # because we're scaling times, not adding them + result = list(suzuki.step()) + total_time = sum(t for t, _ in result) + # For Suzuki: 5 copies scaled by p, p, (1-4p), p, p + # where weights sum to 4p + (1-4p) = 1, so total = base total + base_total = sum(t for t, _ in base.step()) + assert abs(total_time - base_total) < 1e-10 + + +# yoshida_recursion tests + + +def test_yoshida_recursion_from_strang(): + """Test Yoshida recursion applied to Strang splitting produces 4th order.""" + strang = strang_splitting(terms=[0, 1], time=1.0) + yoshida = yoshida_recursion(strang) + assert yoshida.order == 4 + assert yoshida.nterms == 2 + assert yoshida.time_step == 1.0 + + +def test_yoshida_recursion_from_first_order(): + """Test Yoshida recursion applied to first-order Trotter produces 3rd order.""" + trotter = TrotterStep(terms=[0, 1], time_step=1.0) + yoshida = yoshida_recursion(trotter) + assert yoshida.order == 3 + assert yoshida.nterms == 2 + + +def test_yoshida_recursion_preserves_nterms(): + """Test that Yoshida recursion preserves number of terms.""" + base = strang_splitting(terms=[0, 1, 2, 3, 4], time=0.5) + yoshida = yoshida_recursion(base) + assert yoshida.nterms == base.nterms + + +def
test_yoshida_recursion_preserves_time_step(): + """Test that Yoshida recursion preserves time step.""" + base = strang_splitting(terms=[0, 1, 2], time=0.75) + yoshida = yoshida_recursion(base) + assert yoshida.time_step == base.time_step + + +def test_yoshida_recursion_repr(): + """Test repr of Yoshida recursion result.""" + base = strang_splitting(terms=[0, 1], time=1.0) + yoshida = yoshida_recursion(base) + assert "YoshidaRecursion" in repr(yoshida) + + +def test_yoshida_recursion_time_weights_sum(): + """Test that time weights in Yoshida recursion sum correctly.""" + base = TrotterStep(terms=[0, 1], time_step=1.0) + yoshida = yoshida_recursion(base) + # The total scaled time should equal the base schedule's total time + # because weights w1 + w0 + w1 = 2*w1 + w0 = 2*w1 + (1 - 2*w1) = 1 + result = list(yoshida.step()) + total_time = sum(t for t, _ in result) + base_total = sum(t for t, _ in base.step()) + assert abs(total_time - base_total) < 1e-10 + + +def test_yoshida_fewer_terms_than_suzuki(): + """Test that Yoshida produces fewer terms than Suzuki (3x vs 5x).""" + base = strang_splitting(terms=[0, 1, 2], time=1.0) + suzuki = suzuki_recursion(base) + yoshida = yoshida_recursion(base) + # Yoshida uses 3 copies, Suzuki uses 5 copies + # After reduction, Yoshida should generally have fewer terms + assert len(list(yoshida.step())) <= len(list(suzuki.step())) + + +# fourth_order_trotter_suzuki tests + + +def test_fourth_order_trotter_suzuki_basic(): + """Test fourth_order_trotter_suzuki factory function.""" + fourth = fourth_order_trotter_suzuki(terms=[0, 1], time=1.0) + assert fourth.order == 4 + assert fourth.nterms == 2 + assert fourth.time_step == 1.0 + + +def test_fourth_order_trotter_suzuki_equals_suzuki_of_strang(): + """Test that fourth_order_trotter_suzuki equals suzuki_recursion(strang_splitting).""" + fourth = fourth_order_trotter_suzuki(terms=[0, 1, 2], time=0.5) + manual = suzuki_recursion(strang_splitting(terms=[0, 1, 2], time=0.5)) + assert
list(fourth.step()) == list(manual.step()) + assert fourth.order == manual.order + + +# TrotterExpansion tests + + +def test_trotter_expansion_order_property(): + """Test TrotterExpansion order property.""" + model = make_two_term_model() + expansion = TrotterExpansion(strang_splitting, model, time=1.0, num_steps=4) + assert expansion.order == 2 + + +def test_trotter_expansion_nterms_property(): + """Test TrotterExpansion nterms property.""" + model = make_two_term_model() + expansion = TrotterExpansion(TrotterStep, model, time=1.0, num_steps=4) + assert expansion.nterms == 2 + + +def test_trotter_expansion_num_steps_property(): + """Test TrotterExpansion num_steps property.""" + model = make_two_term_model() + expansion = TrotterExpansion(TrotterStep, model, time=1.0, num_steps=8) + assert expansion.nsteps == 8 + + +def test_trotter_expansion_total_time_property(): + """Test TrotterExpansion total_time property.""" + model = make_two_term_model() + expansion = TrotterExpansion(TrotterStep, model, time=1.0, num_steps=4) + assert expansion.total_time == 1.0 + + +def test_trotter_expansion_step_iterator(): + """Test TrotterExpansion.step() yields scaled PauliStrings.""" + model = make_two_term_model() + expansion = TrotterExpansion(TrotterStep, model, time=1.2, num_steps=3) + result = list(expansion.step()) + + # dt = 1.2 / 3 = 0.4 and model terms are 0->ZZ(-2.0), 1->XX(-0.5) + expected = [ + ((0, 1), "ZZ", -0.8), + ((0, 1), "XX", -0.2), + ((0, 1), "ZZ", -0.8), + ((0, 1), "XX", -0.2), + ((0, 1), "ZZ", -0.8), + ((0, 1), "XX", -0.2), + ] + assert len(result) == len(expected) + for op, (qubits, paulis, coefficient) in zip(result, expected): + assert op.qubits == qubits + assert op.paulis == paulis + assert op.coefficient == pytest.approx(coefficient) + + +def test_trotter_expansion_step_iterator_with_strang(): + """Test TrotterExpansion.step() with Strang splitting schedule.""" + model = make_two_term_model() + expansion = TrotterExpansion(strang_splitting, model, 
time=2.0, num_steps=2) + result = list(expansion.step()) + + # dt = 1.0; one Strang step over terms [0,1] is: + # (0.5,0), (1.0,1), (0.5,0) + expected_single = [ + PauliString.from_qubits((0, 1), "ZZ", -1.0), + PauliString.from_qubits((0, 1), "XX", -0.5), + PauliString.from_qubits((0, 1), "ZZ", -1.0), + ] + expected = expected_single * 2 + assert result == expected + + +def test_trotter_expansion_str(): + """Test TrotterExpansion string representation.""" + model = make_two_term_model() + expansion = TrotterExpansion(strang_splitting, model, time=1.0, num_steps=4) + result = str(expansion) + assert "order=2" in result + assert "num_steps=4" in result + assert "total_time=1.0" in result + assert "num_terms=2" in result + + +def test_trotter_expansion_repr(): + """Test TrotterExpansion repr representation.""" + model = make_two_term_model() + expansion = TrotterExpansion(TrotterStep, model, time=1.0, num_steps=4) + result = repr(expansion) + assert "TrotterExpansion" in result + assert "num_steps=4" in result + + +def test_trotter_expansion_cirq_repetitions(): + """Test that TrotterExpansion.cirq repeats one-step circuit num_steps times.""" + model = make_two_term_model() + expansion = TrotterExpansion(TrotterStep, model, time=1.0, num_steps=5) + + op = expansion.cirq() + assert op.repetitions == 5 + + +def test_strang_splitting_rejects_empty_terms(): + """Test strang_splitting raises for empty term list.""" + with pytest.raises(IndexError): + strang_splitting([], time=1.0) diff --git a/source/pip/tests/qre/__init__.py b/source/pip/tests/qre/__init__.py new file mode 100644 index 0000000000..59e481eb93 --- /dev/null +++ b/source/pip/tests/qre/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
diff --git a/source/pip/tests/qre/conftest.py b/source/pip/tests/qre/conftest.py new file mode 100644 index 0000000000..c779e6ff31 --- /dev/null +++ b/source/pip/tests/qre/conftest.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import KW_ONLY, dataclass, field +from typing import Generator + +from qsharp.qre import ( + ISA, + LOGICAL, + ISARequirements, + ISATransform, + constraint, +) +from qsharp.qre._architecture import ISAContext +from qsharp.qre.instruction_ids import LATTICE_SURGERY, T + + +# NOTE These classes will be generalized as part of the QRE API in the following +# pull requests and then moved out of the tests. + + +@dataclass +class ExampleFactory(ISATransform): + _: KW_ONLY + level: int = field(default=1, metadata={"domain": range(1, 4)}) + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + constraint(T), + ) + + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: + yield ctx.make_isa( + ctx.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8), + ) + + +@dataclass +class ExampleLogicalFactory(ISATransform): + _: KW_ONLY + level: int = field(default=1, metadata={"domain": range(1, 4)}) + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + constraint(LATTICE_SURGERY, encoding=LOGICAL), + constraint(T, encoding=LOGICAL), + ) + + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: + yield ctx.make_isa( + ctx.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-10), + ) diff --git a/source/pip/tests/qre/test_application.py b/source/pip/tests/qre/test_application.py new file mode 100644 index 0000000000..c62537ee88 --- /dev/null +++ b/source/pip/tests/qre/test_application.py @@ -0,0 +1,222 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from dataclasses import dataclass, field + +import qsharp + +from qsharp.qre import ( + Application, + ISA, + LOGICAL, + PSSPC, + EstimationResult, + LatticeSurgery, + Trace, + linear_function, +) +from qsharp.qre._qre import _ProvenanceGraph +from qsharp.qre._enumeration import _enumerate_instances +from qsharp.qre.application import QSharpApplication +from qsharp.qre.instruction_ids import CCX, LATTICE_SURGERY, T, RZ +from qsharp.qre.property_keys import ( + ALGORITHM_COMPUTE_QUBITS, + ALGORITHM_MEMORY_QUBITS, + LOGICAL_COMPUTE_QUBITS, + LOGICAL_MEMORY_QUBITS, +) + + +def _assert_estimation_result(trace: Trace, result: EstimationResult, isa: ISA): + """Assert that an estimation result matches expected qubit, runtime, and error values.""" + actual_qubits = ( + isa[LATTICE_SURGERY].expect_space(trace.compute_qubits) + + isa[T].expect_space() * result.factories[T].copies + ) + if CCX in trace.resource_states: + actual_qubits += isa[CCX].expect_space() * result.factories[CCX].copies + assert result.qubits == actual_qubits + + assert ( + result.runtime + == isa[LATTICE_SURGERY].expect_time(trace.compute_qubits) * trace.depth + ) + + actual_error = ( + trace.base_error + + isa[LATTICE_SURGERY].expect_error_rate(trace.compute_qubits) * trace.depth + + isa[T].expect_error_rate() * result.factories[T].states + ) + if CCX in trace.resource_states: + actual_error += isa[CCX].expect_error_rate() * result.factories[CCX].states + assert abs(result.error - actual_error) <= 1e-8 + + +def test_trace_properties(): + """Test setting and getting typed properties on a Trace.""" + trace = Trace(42) + + INT = 0 + FLOAT = 1 + BOOL = 2 + STR = 3 + + trace.set_property(INT, 42) + assert trace.get_property(INT) == 42 + assert isinstance(trace.get_property(INT), int) + + trace.set_property(FLOAT, 3.14) + assert trace.get_property(FLOAT) == 3.14 + assert isinstance(trace.get_property(FLOAT), float) + + trace.set_property(BOOL, True) + assert trace.get_property(BOOL) is True + assert 
isinstance(trace.get_property(BOOL), bool) + + trace.set_property(STR, "hello") + assert trace.get_property(STR) == "hello" + assert isinstance(trace.get_property(STR), str) + + +def test_qsharp_application(): + """Test QSharpApplication trace generation and estimation from a Q# program.""" + code = """ + {{ + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + }} + """ + + app = QSharpApplication(code) + trace = app.get_trace() + + assert trace.compute_qubits == 3 + assert trace.depth == 3 + assert trace.resource_states == {} + + assert {c.id for c in trace.required_isa} == {CCX, T, RZ} + + graph = _ProvenanceGraph() + isa = graph.make_isa( + [ + graph.add_instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + time=1000, + space=linear_function(50), + error_rate=linear_function(1e-6), + ), + graph.add_instruction( + T, encoding=LOGICAL, time=1000, space=400, error_rate=1e-8 + ), + graph.add_instruction( + CCX, encoding=LOGICAL, time=2000, space=800, error_rate=1e-10 + ), + ] + ) + + # Properties from the program + counts = qsharp.logical_counts(code) + num_ts = counts["tCount"] + num_ccx = counts["cczCount"] + num_rotations = counts["rotationCount"] + rotation_depth = counts["rotationDepth"] + + lattice_surgery = LatticeSurgery() + + counter = 0 + for psspc in _enumerate_instances(PSSPC): + counter += 1 + trace2 = psspc.transform(trace) + assert trace2 is not None + trace2 = lattice_surgery.transform(trace2) + assert trace2 is not None + assert trace2.compute_qubits == 12 + assert ( + trace2.depth + == num_ts + + num_ccx * 3 + + num_rotations + + rotation_depth * psspc.num_ts_per_rotation + ) + if psspc.ccx_magic_states: + assert trace2.resource_states == { + T: num_ts + psspc.num_ts_per_rotation * num_rotations, + CCX: num_ccx, + } + assert {c.id for c in trace2.required_isa} == {CCX, T, LATTICE_SURGERY} + else: + assert trace2.resource_states == { + T: num_ts + psspc.num_ts_per_rotation * num_rotations + 4 * 
num_ccx + } + assert {c.id for c in trace2.required_isa} == {T, LATTICE_SURGERY} + assert trace2.get_property(ALGORITHM_COMPUTE_QUBITS) == 3 + assert trace2.get_property(ALGORITHM_MEMORY_QUBITS) == 0 + result = trace2.estimate(isa, max_error=float("inf")) + assert result is not None + assert result.properties[ALGORITHM_COMPUTE_QUBITS] == 3 + assert result.properties[ALGORITHM_MEMORY_QUBITS] == 0 + assert result.properties[LOGICAL_COMPUTE_QUBITS] == 12 + assert result.properties[LOGICAL_MEMORY_QUBITS] == 0 + _assert_estimation_result(trace2, result, isa) + assert counter == 32 + + +def test_application_enumeration(): + """Test that Application.q() enumerates the correct number of traces.""" + @dataclass(kw_only=True) + class _Params: + size: int = field(default=1, metadata={"domain": range(1, 4)}) + + class TestApp(Application[_Params]): + def get_trace(self, parameters: _Params) -> Trace: + return Trace(parameters.size) + + app = TestApp() + assert sum(1 for _ in TestApp.q().enumerate(app.context())) == 3 + assert sum(1 for _ in TestApp.q(size=1).enumerate(app.context())) == 1 + assert sum(1 for _ in TestApp.q(size=[4, 5]).enumerate(app.context())) == 2 + + +def test_trace_enumeration(): + """Test trace query enumeration with PSSPC and LatticeSurgery transforms.""" + code = """ + {{ + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + }} + """ + + app = QSharpApplication(code) + + ctx = app.context() + assert sum(1 for _ in QSharpApplication.q().enumerate(ctx)) == 1 + + assert sum(1 for _ in PSSPC.q().enumerate(ctx)) == 32 + + assert sum(1 for _ in LatticeSurgery.q().enumerate(ctx)) == 1 + + q = PSSPC.q() * LatticeSurgery.q() + assert sum(1 for _ in q.enumerate(ctx)) == 32 + + +def test_rotation_error_psspc(): + """Test that PSSPC base error stays below 1.0 for a single rotation gate.""" + # This test helps to bound the variables for the number of rotations in PSSPC + + # Create a trace with a single rotation gate and ensure 
that the base error + # after PSSPC transformation is less than 1. + trace = Trace(1) + trace.add_operation(RZ, [0]) + + for psspc in _enumerate_instances(PSSPC, ccx_magic_states=False): + transformed = psspc.transform(trace) + assert transformed is not None + assert ( + transformed.base_error < 1.0 + ), f"Base error too high: {transformed.base_error} for {psspc.num_ts_per_rotation} T states per rotation" diff --git a/source/pip/tests/qre/test_cirq_interop.py b/source/pip/tests/qre/test_cirq_interop.py new file mode 100644 index 0000000000..826ac54c7e --- /dev/null +++ b/source/pip/tests/qre/test_cirq_interop.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import cirq +from qsharp.qre.application import CirqApplication + + +def test_with_qft(): + """Test trace generation from a 1025-qubit QFT circuit.""" + _test_one_circuit(cirq.qft(*cirq.LineQubit.range(1025)), 1025, 212602, 266007) + + +def test_h(): + """Test trace generation from Hadamard and fractional Hadamard gates.""" + _test_one_circuit(cirq.H, 1, 1, 1) + _test_one_circuit(cirq.H**0.5, 1, 3, 3) + + +def test_cx(): + """Test trace generation from CX and fractional CX gates.""" + _test_one_circuit(cirq.CX, 2, 1, 1) + _test_one_circuit(cirq.CX**0.5, 2, 6, 7) + _test_one_circuit(cirq.CX**0.25, 2, 6, 7) + + +def test_cz(): + """Test trace generation from CZ and fractional CZ gates.""" + _test_one_circuit(cirq.CZ, 2, 1, 1) + _test_one_circuit(cirq.CZ**0.5, 2, 4, 5) + _test_one_circuit(cirq.CZ**0.25, 2, 4, 5) + + +def test_swap(): + """Test trace generation from SWAP and fractional SWAP gates.""" + _test_one_circuit(cirq.SWAP, 2, 1, 1) + _test_one_circuit(cirq.SWAP**0.5, 2, 8, 9) + + +def test_ccx(): + """Test trace generation from CCX and fractional CCX gates.""" + _test_one_circuit(cirq.CCX, 3, 1, 1) + _test_one_circuit(cirq.CCX**0.5, 3, 11, 17) + + +def test_ccz(): + """Test trace generation from CCZ and fractional CCZ gates.""" + _test_one_circuit(cirq.CCZ, 3, 
1, 1) + _test_one_circuit(cirq.CCZ**0.5, 3, 10, 15) + + +def test_circuit_with_block(): + """Test trace generation from a circuit with a custom decomposable gate.""" + class CustomGate(cirq.Gate): + def num_qubits(self) -> int: + return 2 + + def _decompose_(self, qubits): + a, b = qubits + yield cirq.CX(a, b) + yield cirq.CX(b, a) + yield cirq.CX(a, b) + + q0, q1 = cirq.LineQubit.range(2) + _test_one_circuit( + [ + cirq.H.on_each(q0, q1), + CustomGate().on(q0, q1), + ], + 2, + 4, + 5, + ) + + +def _test_one_circuit( + circuit: cirq.CIRCUIT_LIKE, + expected_qubits: int, + expected_depth: int, + expected_gates: int, +): + """Assert that a Cirq circuit produces a trace with the expected qubits, depth, and gates.""" + app = CirqApplication(circuit) + trace = app.get_trace() + + assert trace.total_qubits == expected_qubits, "unexpected number of qubits in trace" + assert trace.depth == expected_depth, "unexpected depth of trace" + assert trace.num_gates == expected_gates, "unexpected number of gates in trace" diff --git a/source/pip/tests/qre/test_enumeration.py b/source/pip/tests/qre/test_enumeration.py new file mode 100644 index 0000000000..636f982699 --- /dev/null +++ b/source/pip/tests/qre/test_enumeration.py @@ -0,0 +1,542 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from dataclasses import KW_ONLY, dataclass, field +from enum import Enum +from typing import cast + +import pytest + +from qsharp.qre import LOGICAL +from qsharp.qre.models import SurfaceCode, GateBased +from qsharp.qre._isa_enumeration import ( + ISARefNode, + _ComponentQuery, + _ProductNode, + _SumNode, +) + +from .conftest import ExampleFactory, ExampleLogicalFactory + + +def test_enumerate_instances(): + """Test enumeration of SurfaceCode instances with default and custom domains.""" + from qsharp.qre._enumeration import _enumerate_instances + + instances = list(_enumerate_instances(SurfaceCode)) + + # There are 12 instances with distances from 3 to 25 + assert len(instances) == 12 + expected_distances = list(range(3, 26, 2)) + for instance, expected_distance in zip(instances, expected_distances): + assert instance.distance == expected_distance + + # Test with specific distances + instances = list(_enumerate_instances(SurfaceCode, distance=[3, 5, 7])) + assert len(instances) == 3 + expected_distances = [3, 5, 7] + for instance, expected_distance in zip(instances, expected_distances): + assert instance.distance == expected_distance + + # Test with fixed distance + instances = list(_enumerate_instances(SurfaceCode, distance=9)) + assert len(instances) == 1 + assert instances[0].distance == 9 + + +def test_enumerate_instances_bool(): + """Test that boolean dataclass fields enumerate both True and False.""" + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class BoolConfig: + _: KW_ONLY + flag: bool + + instances = list(_enumerate_instances(BoolConfig)) + assert len(instances) == 2 + assert instances[0].flag is True + assert instances[1].flag is False + + +def test_enumerate_instances_enum(): + """Test that Enum dataclass fields enumerate all members.""" + from qsharp.qre._enumeration import _enumerate_instances + + class Color(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + @dataclass + class EnumConfig: + _: KW_ONLY + color: Color + + 
instances = list(_enumerate_instances(EnumConfig)) + assert len(instances) == 3 + assert instances[0].color == Color.RED + assert instances[1].color == Color.GREEN + assert instances[2].color == Color.BLUE + + +def test_enumerate_instances_failure(): + """Test that a field with no domain and no default raises ValueError.""" + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class InvalidConfig: + _: KW_ONLY + # This field has no domain, is not bool/enum, and has no default + value: int + + with pytest.raises(ValueError, match="Cannot enumerate field value"): + list(_enumerate_instances(InvalidConfig)) + + +def test_enumerate_instances_single(): + """Test enumeration of a dataclass with a single non-kw-only field.""" + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class SingleConfig: + value: int = 42 + + instances = list(_enumerate_instances(SingleConfig)) + assert len(instances) == 1 + assert instances[0].value == 42 + + +def test_enumerate_instances_literal(): + """Test that Literal-typed fields enumerate their allowed values.""" + from qsharp.qre._enumeration import _enumerate_instances + + from typing import Literal + + @dataclass + class LiteralConfig: + _: KW_ONLY + mode: Literal["fast", "slow"] + + instances = list(_enumerate_instances(LiteralConfig)) + assert len(instances) == 2 + assert instances[0].mode == "fast" + assert instances[1].mode == "slow" + + +def test_enumerate_instances_nested(): + """Test enumeration of nested dataclass fields.""" + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class InnerConfig: + _: KW_ONLY + option: bool + + @dataclass + class OuterConfig: + _: KW_ONLY + inner: InnerConfig + + instances = list(_enumerate_instances(OuterConfig)) + assert len(instances) == 2 + assert instances[0].inner.option is True + assert instances[1].inner.option is False + + +def test_enumerate_instances_union(): + """Test enumeration of union-typed dataclass fields.""" + 
from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB + + instances = list(_enumerate_instances(UnionConfig)) + assert len(instances) == 5 + assert isinstance(instances[0].option, OptionA) + assert instances[0].option.value is True + assert isinstance(instances[2].option, OptionB) + assert instances[2].option.number == 1 + + +def test_enumerate_instances_nested_with_constraints(): + """Test constraining nested dataclass fields via a dict.""" + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class InnerConfig: + _: KW_ONLY + option: bool + + @dataclass + class OuterConfig: + _: KW_ONLY + inner: InnerConfig + + # Constrain nested field via dict + instances = list(_enumerate_instances(OuterConfig, inner={"option": True})) + assert len(instances) == 1 + assert instances[0].inner.option is True + + +def test_enumerate_instances_union_single_type(): + """Test restricting a union field to a single member type.""" + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB + + # Restrict to OptionB only - uses its default domain + instances = list(_enumerate_instances(UnionConfig, option=OptionB)) + assert len(instances) == 3 + assert all(isinstance(i.option, OptionB) for i in instances) + assert [cast(OptionB, i.option).number for i in instances] == [1, 2, 3] + + # Restrict to OptionA only + instances = list(_enumerate_instances(UnionConfig, option=OptionA)) + assert len(instances) == 2 + assert all(isinstance(i.option, OptionA) for i in instances) + assert 
cast(OptionA, instances[0].option).value is True + assert cast(OptionA, instances[1].option).value is False + + +def test_enumerate_instances_union_list_of_types(): + """Test restricting a union field to a subset of member types.""" + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class OptionC: + _: KW_ONLY + flag: bool + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB | OptionC + + # Select a subset: only OptionA and OptionB + instances = list(_enumerate_instances(UnionConfig, option=[OptionA, OptionB])) + assert len(instances) == 5 # 2 from OptionA + 3 from OptionB + assert all(isinstance(i.option, (OptionA, OptionB)) for i in instances) + + +def test_enumerate_instances_union_constraint_dict(): + """Test constraining union field members via a type-to-kwargs dict.""" + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB + + # Constrain OptionA, enumerate only that member + instances = list( + _enumerate_instances(UnionConfig, option={OptionA: {"value": True}}) + ) + assert len(instances) == 1 + assert isinstance(instances[0].option, OptionA) + assert instances[0].option.value is True + + # Constrain OptionB with a domain, enumerate only that member + instances = list( + _enumerate_instances(UnionConfig, option={OptionB: {"number": [2, 3]}}) + ) + assert len(instances) == 2 + assert all(isinstance(i.option, OptionB) for i in instances) + assert cast(OptionB, instances[0].option).number == 2 + assert cast(OptionB, instances[1].option).number == 3 + + # Constrain one member and keep another with 
defaults + instances = list( + _enumerate_instances( + UnionConfig, + option={OptionA: {"value": True}, OptionB: {}}, + ) + ) + assert len(instances) == 4 # 1 from OptionA + 3 from OptionB + assert isinstance(instances[0].option, OptionA) + assert instances[0].option.value is True + assert all(isinstance(i.option, OptionB) for i in instances[1:]) + assert [cast(OptionB, i.option).number for i in instances[1:]] == [1, 2, 3] + + +def test_enumerate_isas(): + """Test ISA enumeration with products, sums, and hierarchical factories.""" + ctx = GateBased(gate_time=50, measurement_time=100).context() + + # This will enumerate the 4 ISAs for the error correction code + count = sum(1 for _ in SurfaceCode.q().enumerate(ctx)) + assert count == 12 + + # This will enumerate the 2 ISAs for the error correction code when + # restricting the domain + count = sum(1 for _ in SurfaceCode.q(distance=[3, 4]).enumerate(ctx)) + assert count == 2 + + # This will enumerate the 3 ISAs for the factory + count = sum(1 for _ in ExampleFactory.q().enumerate(ctx)) + assert count == 3 + + # This will enumerate 36 ISAs for all products between the 12 error + # correction code ISAs and the 3 factory ISAs + count = sum(1 for _ in (SurfaceCode.q() * ExampleFactory.q()).enumerate(ctx)) + assert count == 36 + + # When providing a list, components are chained (OR operation). This + # enumerates ISAs from first factory instance OR second factory instance + count = sum( + 1 + for _ in ( + SurfaceCode.q() * (ExampleFactory.q() + ExampleFactory.q()) + ).enumerate(ctx) + ) + assert count == 72 + + # When providing separate arguments, components are combined via product + # (AND). 
This enumerates ISAs from first factory instance AND second + # factory instance + count = sum( + 1 + for _ in (SurfaceCode.q() * ExampleFactory.q() * ExampleFactory.q()).enumerate( + ctx + ) + ) + assert count == 108 + + # Hierarchical factory using from_components: the component receives ISAs + # from the product of other components as its source + count = sum( + 1 + for _ in ( + SurfaceCode.q() + * ExampleLogicalFactory.q(source=(SurfaceCode.q() * ExampleFactory.q())) + ).enumerate(ctx) + ) + assert count == 1296 + + +def test_binding_node(): + """Test binding nodes with ISARefNode for component bindings""" + ctx = GateBased(gate_time=50, measurement_time=100).context() + + # Test basic binding: same code used twice + # Without binding: 12 codes × 12 codes = 144 combinations + count_without = sum(1 for _ in (SurfaceCode.q() * SurfaceCode.q()).enumerate(ctx)) + assert count_without == 144 + + # With binding: 12 codes (same instance used twice) + count_with = sum( + 1 + for _ in SurfaceCode.bind("c", ISARefNode("c") * ISARefNode("c")).enumerate(ctx) + ) + assert count_with == 12 + + # Verify the binding works: with binding, both should use same params + for isa in SurfaceCode.bind("c", ISARefNode("c") * ISARefNode("c")).enumerate(ctx): + logical_gates = [g for g in isa if g.encoding == LOGICAL] + # Should have 1 logical gate (LATTICE_SURGERY) + assert len(logical_gates) == 1 + + # Test binding with factories (nested bindings) + count_without = sum( + 1 + for _ in ( + SurfaceCode.q() * ExampleFactory.q() * SurfaceCode.q() * ExampleFactory.q() + ).enumerate(ctx) + ) + assert count_without == 1296 # 12 * 3 * 12 * 3 + + count_with = sum( + 1 + for _ in SurfaceCode.bind( + "c", + ExampleFactory.bind( + "f", + ISARefNode("c") * ISARefNode("f") * ISARefNode("c") * ISARefNode("f"), + ), + ).enumerate(ctx) + ) + assert count_with == 36 # 12 * 3 + + # Test binding with from_components equivalent (hierarchical) + # Without binding: 4 outer codes × (4 inner codes × 3 
factories × 3 levels) + count_without = sum( + 1 + for _ in ( + SurfaceCode.q() + * ExampleLogicalFactory.q( + source=(SurfaceCode.q() * ExampleFactory.q()), + ) + ).enumerate(ctx) + ) + assert count_without == 1296 # 12 * 12 * 3 * 3 + + # With binding: 4 codes (same used twice) × 3 factories × 3 levels + count_with = sum( + 1 + for _ in SurfaceCode.bind( + "c", + ISARefNode("c") + * ExampleLogicalFactory.q( + source=(ISARefNode("c") * ExampleFactory.q()), + ), + ).enumerate(ctx) + ) + assert count_with == 108 # 12 * 3 * 3 + + # Test binding with kwargs + count_with_kwargs = sum( + 1 + for _ in SurfaceCode.q(distance=5) + .bind("c", ISARefNode("c") * ISARefNode("c")) + .enumerate(ctx) + ) + assert count_with_kwargs == 1 # Only distance=5 + + # Verify kwargs are applied + for isa in ( + SurfaceCode.q(distance=5) + .bind("c", ISARefNode("c") * ISARefNode("c")) + .enumerate(ctx) + ): + logical_gates = [g for g in isa if g.encoding == LOGICAL] + assert all(g.space(1) == 49 for g in logical_gates) + + # Test multiple independent bindings (nested) + count = sum( + 1 + for _ in SurfaceCode.bind( + "c1", + ExampleFactory.bind( + "c2", + ISARefNode("c1") + * ISARefNode("c1") + * ISARefNode("c2") + * ISARefNode("c2"), + ), + ).enumerate(ctx) + ) + # 12 codes for c1 × 3 factories for c2 + assert count == 36 + + +def test_binding_node_errors(): + """Test error handling for binding nodes""" + ctx = GateBased(gate_time=50, measurement_time=100).context() + + # Test ISARefNode enumerate with undefined binding raises ValueError + try: + list(ISARefNode("test").enumerate(ctx)) + assert False, "Should have raised ValueError" + except ValueError as e: + assert "Undefined component reference: 'test'" in str(e) + + +def test_product_isa_enumeration_nodes(): + """Test that multiplying ISAQuery nodes produces flattened ProductNodes.""" + terminal = SurfaceCode.q() + query = terminal * terminal + + # Multiplication should create ProductNode + assert isinstance(query, _ProductNode) + 
assert len(query.sources) == 2 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Multiplying again should extend the sources + query = query * terminal + assert isinstance(query, _ProductNode) + assert len(query.sources) == 3 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Also from the other side + query = terminal * query + assert isinstance(query, _ProductNode) + assert len(query.sources) == 4 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Also for two ProductNodes + query = query * query + assert isinstance(query, _ProductNode) + assert len(query.sources) == 8 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + +def test_sum_isa_enumeration_nodes(): + """Test that adding ISAQuery nodes produces flattened SumNodes.""" + terminal = SurfaceCode.q() + query = terminal + terminal + + # Multiplication should create SumNode + assert isinstance(query, _SumNode) + assert len(query.sources) == 2 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Multiplying again should extend the sources + query = query + terminal + assert isinstance(query, _SumNode) + assert len(query.sources) == 3 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Also from the other side + query = terminal + query + assert isinstance(query, _SumNode) + assert len(query.sources) == 4 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Also for two SumNodes + query = query + query + assert isinstance(query, _SumNode) + assert len(query.sources) == 8 + for source in query.sources: + assert isinstance(source, _ComponentQuery) diff --git a/source/pip/tests/qre/test_estimation.py b/source/pip/tests/qre/test_estimation.py new file mode 100644 index 0000000000..2947db2075 --- /dev/null +++ b/source/pip/tests/qre/test_estimation.py @@ -0,0 +1,100 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT License. + +import os + +import pytest + +from qsharp.estimator import LogicalCounts +from qsharp.qre import ( + PSSPC, + LatticeSurgery, + estimate, +) +from qsharp.qre.application import QSharpApplication +from qsharp.qre.models import ( + SurfaceCode, + GateBased, + RoundBasedFactory, + TwoDimensionalYokedSurfaceCode, +) + +from .conftest import ExampleFactory + + +def test_estimation_max_error(): + """Test that estimation results respect the max_error constraint.""" + app = QSharpApplication(LogicalCounts({"numQubits": 100, "measurementCount": 100})) + arch = GateBased(gate_time=50, measurement_time=100) + + for max_error in [1e-1, 1e-2, 1e-3, 1e-4]: + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=max_error, + ) + + assert len(results) == 1 + assert next(iter(results)).error <= max_error + + +@pytest.mark.skipif( + "SLOW_TESTS" not in os.environ, + reason="turn on slow tests by setting SLOW_TESTS=1 in the environment", +) +@pytest.mark.parametrize( + "post_process, use_graph", + [ + (False, False), + (True, False), + (False, True), + (True, True), + ], +) +def test_estimation_methods(post_process, use_graph): + """Test all combinations of post_process and use_graph estimation paths.""" + counts = LogicalCounts( + { + "numQubits": 1000, + "tCount": 1_500_000, + "rotationCount": 0, + "rotationDepth": 0, + "cczCount": 1_000_000_000, + "ccixCount": 0, + "measurementCount": 25_000_000, + "numComputeQubits": 200, + "readFromMemoryCount": 30_000_000, + "writeToMemoryCount": 30_000_000, + } + ) + + trace_query = PSSPC.q() * LatticeSurgery.q(slow_down_factor=[1.0, 2.0]) + isa_query = ( + SurfaceCode.q() + * RoundBasedFactory.q() + * TwoDimensionalYokedSurfaceCode.q(source=SurfaceCode.q()) + ) + + app = QSharpApplication(counts) + arch = GateBased(gate_time=50, measurement_time=100) + + results = estimate( + app, + arch, + isa_query, + trace_query, + max_error=1 / 3, + 
post_process=post_process, + use_graph=use_graph, + ) + results.add_factory_summary_column() + + assert [(result.qubits, result.runtime) for result in results] == [ + (238707, 23997050000000), + (240407, 11998525000000), + ] + + print() + print(results.stats) diff --git a/source/pip/tests/qre/test_estimation_table.py b/source/pip/tests/qre/test_estimation_table.py new file mode 100644 index 0000000000..744a3dc607 --- /dev/null +++ b/source/pip/tests/qre/test_estimation_table.py @@ -0,0 +1,367 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pytest +import pandas as pd + +from qsharp.qre import ( + PSSPC, + LatticeSurgery, + estimate, +) +from qsharp.qre.application import QSharpApplication +from qsharp.qre.models import SurfaceCode, GateBased +from qsharp.qre._estimation import ( + EstimationTable, + EstimationTableEntry, +) +from qsharp.qre._instruction import InstructionSource +from qsharp.qre.instruction_ids import LATTICE_SURGERY +from qsharp.qre.property_keys import DISTANCE, NUM_TS_PER_ROTATION + +from .conftest import ExampleFactory + + +def _make_entry(qubits, runtime, error, properties=None): + """Helper to create an EstimationTableEntry with a dummy InstructionSource.""" + return EstimationTableEntry( + qubits=qubits, + runtime=runtime, + error=error, + source=InstructionSource(), + properties=properties or {}, + ) + + +def test_estimation_table_default_columns(): + """Test that a new EstimationTable has the three default columns.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01)) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error"] + assert frame["qubits"][0] == 100 + assert frame["runtime"][0] == pd.Timedelta(5000, unit="ns") + assert frame["error"][0] == 0.01 + + +def test_estimation_table_multiple_rows(): + """Test as_frame with multiple entries.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01)) + 
table.append(_make_entry(200, 10000, 0.02)) + table.append(_make_entry(300, 15000, 0.03)) + + frame = table.as_frame() + assert len(frame) == 3 + assert list(frame["qubits"]) == [100, 200, 300] + assert list(frame["error"]) == [0.01, 0.02, 0.03] + + +def test_estimation_table_empty(): + """Test as_frame with no entries produces an empty DataFrame.""" + table = EstimationTable() + frame = table.as_frame() + assert len(frame) == 0 + + +def test_estimation_table_add_column(): + """Test adding a column to the table.""" + VAL = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={VAL: 42})) + table.append(_make_entry(200, 10000, 0.02, properties={VAL: 84})) + + table.add_column("val", lambda e: e.properties[VAL]) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error", "val"] + assert list(frame["val"]) == [42, 84] + + +def test_estimation_table_add_column_with_formatter(): + """Test adding a column with a formatter.""" + NS = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={NS: 1000})) + + table.add_column( + "duration", + lambda e: e.properties[NS], + formatter=lambda x: pd.Timedelta(x, unit="ns"), + ) + + frame = table.as_frame() + assert frame["duration"][0] == pd.Timedelta(1000, unit="ns") + + +def test_estimation_table_add_multiple_columns(): + """Test adding multiple columns preserves order.""" + A = 0 + B = 1 + C = 2 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={A: 1, B: 2, C: 3})) + + table.add_column("a", lambda e: e.properties[A]) + table.add_column("b", lambda e: e.properties[B]) + table.add_column("c", lambda e: e.properties[C]) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error", "a", "b", "c"] + assert frame["a"][0] == 1 + assert frame["b"][0] == 2 + assert frame["c"][0] == 3 + + +def test_estimation_table_insert_column_at_beginning(): + """Test inserting a column at 
index 0.""" + NAME = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={NAME: "test"})) + + table.insert_column(0, "name", lambda e: e.properties[NAME]) + + frame = table.as_frame() + assert list(frame.columns) == ["name", "qubits", "runtime", "error"] + assert frame["name"][0] == "test" + + +def test_estimation_table_insert_column_in_middle(): + """Test inserting a column between existing default columns.""" + EXTRA = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={EXTRA: 99})) + + # Insert between qubits and runtime (index 1) + table.insert_column(1, "extra", lambda e: e.properties[EXTRA]) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "extra", "runtime", "error"] + assert frame["extra"][0] == 99 + + +def test_estimation_table_insert_column_at_end(): + """Test inserting a column at the end (same effect as add_column).""" + LAST = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={LAST: True})) + + # 3 default columns, inserting at index 3 = end + table.insert_column(3, "last", lambda e: e.properties[LAST]) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error", "last"] + assert frame["last"][0] + + +def test_estimation_table_insert_column_with_formatter(): + """Test inserting a column with a formatter.""" + NS = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={NS: 2000})) + + table.insert_column( + 0, + "custom_time", + lambda e: e.properties[NS], + formatter=lambda x: pd.Timedelta(x, unit="ns"), + ) + + frame = table.as_frame() + assert frame["custom_time"][0] == pd.Timedelta(2000, unit="ns") + assert list(frame.columns)[0] == "custom_time" + + +def test_estimation_table_insert_and_add_columns(): + """Test combining insert_column and add_column.""" + A = 0 + B = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={A: 
1, B: 2})) + + table.add_column("b", lambda e: e.properties[B]) + table.insert_column(0, "a", lambda e: e.properties[A]) + + frame = table.as_frame() + assert list(frame.columns) == ["a", "qubits", "runtime", "error", "b"] + + +def test_estimation_table_factory_summary_no_factories(): + """Test factory summary column when entries have no factories.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01)) + + table.add_factory_summary_column() + + frame = table.as_frame() + assert "factories" in frame.columns + assert frame["factories"][0] == "None" + + +def test_estimation_table_factory_summary_with_estimation(): + """Test factory summary column with real estimation results.""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = GateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=0.5, + ) + + assert len(results) >= 1 + + results.add_factory_summary_column() + frame = results.as_frame() + + assert "factories" in frame.columns + # Each result should mention T in the factory summary + for val in frame["factories"]: + assert "T" in val + + +def test_estimation_table_add_column_from_source(): + """Test adding a column that accesses the InstructionSource (like distance).""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = GateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=0.5, + ) + + assert len(results) >= 1 + + results.add_column( + "compute_distance", + lambda entry: entry.source[LATTICE_SURGERY].instruction[DISTANCE], + ) + + frame = results.as_frame() + assert "compute_distance" in frame.columns + for d in 
frame["compute_distance"]: + assert isinstance(d, int) + assert d >= 3 + + +def test_estimation_table_add_column_from_properties(): + """Test adding columns that access trace properties from estimation.""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = GateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=0.5, + ) + + assert len(results) >= 1 + + results.add_column( + "num_ts_per_rotation", + lambda entry: entry.properties[NUM_TS_PER_ROTATION], + ) + + frame = results.as_frame() + assert "num_ts_per_rotation" in frame.columns + for val in frame["num_ts_per_rotation"]: + assert isinstance(val, int) + assert val >= 1 + + +def test_estimation_table_insert_column_before_defaults(): + """Test inserting a name column before all default columns, similar to the factoring notebook.""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = GateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=0.5, + name="test_experiment", + ) + + assert len(results) >= 1 + + # Add a factory summary at the end + results.add_factory_summary_column() + + frame = results.as_frame() + assert frame.columns[0] == "name" + assert frame.columns[-1] == "factories" + # Default columns should still be in order + assert list(frame.columns[1:4]) == ["qubits", "runtime", "error"] + + +def test_estimation_table_as_frame_sortable(): + """Test that the DataFrame from as_frame can be sorted, as done in the factoring tests.""" + table = EstimationTable() + table.append(_make_entry(300, 15000, 0.03)) + table.append(_make_entry(100, 5000, 0.01)) + table.append(_make_entry(200, 
10000, 0.02)) + + frame = table.as_frame() + sorted_frame = frame.sort_values(by=["qubits", "runtime"]).reset_index(drop=True) + + assert list(sorted_frame["qubits"]) == [100, 200, 300] + assert list(sorted_frame["error"]) == [0.01, 0.02, 0.03] + + +def test_estimation_table_computed_column(): + """Test adding a column that computes a derived value from the entry.""" + table = EstimationTable() + table.append(_make_entry(100, 5_000_000, 0.01)) + table.append(_make_entry(200, 10_000_000, 0.02)) + + # Compute qubits * error as a derived metric + table.add_column("qubit_error_product", lambda e: e.qubits * e.error) + + frame = table.as_frame() + assert frame["qubit_error_product"][0] == pytest.approx(1.0) + assert frame["qubit_error_product"][1] == pytest.approx(4.0) diff --git a/source/pip/tests/qre/test_interop.py b/source/pip/tests/qre/test_interop.py new file mode 100644 index 0000000000..4e94d6f549 --- /dev/null +++ b/source/pip/tests/qre/test_interop.py @@ -0,0 +1,216 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from pathlib import Path + +import pytest + +from qsharp.qre.interop import trace_from_qir + + +def _ll_files(): + """Return the list of QIR .ll test files.""" + ll_dir = ( + Path(__file__).parent.parent.parent + / "tests-integration" + / "resources" + / "adaptive_ri" + / "output" + ) + return sorted(ll_dir.glob("*.ll")) + + +@pytest.mark.parametrize("ll_file", _ll_files(), ids=lambda p: p.stem) +def test_trace_from_qir(ll_file): + """Test that trace_from_qir can parse real QIR output files.""" + # NOTE: This test is primarily to ensure that the function can parse real + # QIR output without errors, rather than checking specific properties of the + # trace. 
+ try: + trace_from_qir(ll_file.read_text()) + except ValueError as e: + # The only reason of failure is presence of control flow + assert ( + str(e) + == "simulation of programs with branching control flow is not supported" + ) + + +def test_trace_from_qir_handles_all_instruction_ids(): + """Verify that trace_from_qir handles every QirInstructionId except CorrelatedNoise. + + Generates a synthetic QIR program containing one instance of each gate + intrinsic recognised by AggregateGatesPass and asserts that trace_from_qir + processes all of them without error. + """ + import pyqir + import pyqir.qis as qis + from qsharp._native import QirInstructionId + from qsharp.qre.interop._qir import _GATE_MAP, _MEAS_MAP, _SKIP + + # -- Completeness check: every QirInstructionId must be covered -------- + handled_ids = ( + [qir_id for qir_id, _, _ in _GATE_MAP] + + [qir_id for qir_id, _ in _MEAS_MAP] + + list(_SKIP) + ) + # Exhaustive list of all QirInstructionId variants (pyo3 enums are not iterable) + all_ids = [ + QirInstructionId.I, + QirInstructionId.H, + QirInstructionId.X, + QirInstructionId.Y, + QirInstructionId.Z, + QirInstructionId.S, + QirInstructionId.SAdj, + QirInstructionId.SX, + QirInstructionId.SXAdj, + QirInstructionId.T, + QirInstructionId.TAdj, + QirInstructionId.CNOT, + QirInstructionId.CX, + QirInstructionId.CY, + QirInstructionId.CZ, + QirInstructionId.CCX, + QirInstructionId.SWAP, + QirInstructionId.RX, + QirInstructionId.RY, + QirInstructionId.RZ, + QirInstructionId.RXX, + QirInstructionId.RYY, + QirInstructionId.RZZ, + QirInstructionId.RESET, + QirInstructionId.M, + QirInstructionId.MResetZ, + QirInstructionId.MZ, + QirInstructionId.Move, + QirInstructionId.ReadResult, + QirInstructionId.ResultRecordOutput, + QirInstructionId.BoolRecordOutput, + QirInstructionId.IntRecordOutput, + QirInstructionId.DoubleRecordOutput, + QirInstructionId.TupleRecordOutput, + QirInstructionId.ArrayRecordOutput, + QirInstructionId.CorrelatedNoise, + ] + unhandled = [ + i + 
for i in all_ids + if i not in handled_ids and i != QirInstructionId.CorrelatedNoise + ] + assert unhandled == [], ( + f"QirInstructionId values not covered by _GATE_MAP, _MEAS_MAP, or _SKIP: " + f"{', '.join(str(i) for i in unhandled)}" + ) + + # -- Generate a QIR program with every producible gate ----------------- + simple = pyqir.SimpleModule("test_all_gates", num_qubits=4, num_results=3) + builder = simple.builder + ctx = simple.context + q = simple.qubits + r = simple.results + + void_ty = pyqir.Type.void(ctx) + qubit_ty = pyqir.qubit_type(ctx) + result_ty = pyqir.result_type(ctx) + double_ty = pyqir.Type.double(ctx) + i64_ty = pyqir.IntType(ctx, 64) + + def declare(name, param_types): + return simple.add_external_function( + name, pyqir.FunctionType(void_ty, param_types) + ) + + # Single-qubit gates (pyqir.qis builtins) + qis.h(builder, q[0]) + qis.x(builder, q[0]) + qis.y(builder, q[0]) + qis.z(builder, q[0]) + qis.s(builder, q[0]) + qis.s_adj(builder, q[0]) + qis.t(builder, q[0]) + qis.t_adj(builder, q[0]) + + # SX — not in pyqir.qis + sx_fn = declare("__quantum__qis__sx__body", [qubit_ty]) + builder.call(sx_fn, [q[0]]) + + # Two-qubit gates (qis.cx emits __quantum__qis__cnot__body which the + # pass does not handle, so use builder.call with the correct name) + cx_fn = declare("__quantum__qis__cx__body", [qubit_ty, qubit_ty]) + builder.call(cx_fn, [q[0], q[1]]) + qis.cz(builder, q[0], q[1]) + qis.swap(builder, q[0], q[1]) + + cy_fn = declare("__quantum__qis__cy__body", [qubit_ty, qubit_ty]) + builder.call(cy_fn, [q[0], q[1]]) + + # Three-qubit gate + qis.ccx(builder, q[0], q[1], q[2]) + + # Single-qubit rotations + qis.rx(builder, 1.0, q[0]) + qis.ry(builder, 1.0, q[0]) + qis.rz(builder, 1.0, q[0]) + + # Two-qubit rotations — not in pyqir.qis + rot2_ty = [double_ty, qubit_ty, qubit_ty] + angle = pyqir.const(double_ty, 1.0) + for name in ("rxx", "ryy", "rzz"): + fn = declare(f"__quantum__qis__{name}__body", rot2_ty) + builder.call(fn, [angle, q[0], q[1]]) + 
+ # Measurements + qis.mz(builder, q[0], r[0]) + + m_fn = declare("__quantum__qis__m__body", [qubit_ty, result_ty]) + builder.call(m_fn, [q[1], r[1]]) + + mresetz_fn = declare("__quantum__qis__mresetz__body", [qubit_ty, result_ty]) + builder.call(mresetz_fn, [q[2], r[2]]) + + # Reset / Move + qis.reset(builder, q[0]) + + move_fn = declare("__quantum__qis__move__body", [qubit_ty]) + builder.call(move_fn, [q[0]]) + + # Output recording + tag = simple.add_byte_string(b"tag") + arr_fn = declare("__quantum__rt__array_record_output", [i64_ty, tag.type]) + builder.call(arr_fn, [pyqir.const(i64_ty, 1), tag]) + + rec_fn = declare("__quantum__rt__result_record_output", [result_ty, tag.type]) + builder.call(rec_fn, [r[0], tag]) + + tup_fn = declare("__quantum__rt__tuple_record_output", [i64_ty, tag.type]) + builder.call(tup_fn, [pyqir.const(i64_ty, 1), tag]) + + # -- Run trace_from_qir and verify it succeeds ------------------------- + trace = trace_from_qir(simple.ir()) + assert trace is not None + + +def test_rotation_buckets(): + """Test that rotation bucketization preserves total count and depth.""" + from qsharp.qre.interop._qsharp import _bucketize_rotation_counts + + print() + + r_count = 15066 + r_depth = 14756 + q_count = 291 + + result = _bucketize_rotation_counts(r_count, r_depth) + + a_count = 0 + a_depth = 0 + for c, d in result: + print(c, d) + assert c <= q_count + assert c > 0 + a_count += c * d + a_depth += d + + assert a_count == r_count + assert a_depth == r_depth diff --git a/source/pip/tests/qre/test_isa.py b/source/pip/tests/qre/test_isa.py new file mode 100644 index 0000000000..95809968b6 --- /dev/null +++ b/source/pip/tests/qre/test_isa.py @@ -0,0 +1,187 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import pytest + +from qsharp.qre import ( + LOGICAL, + ISARequirements, + constraint, + generic_function, + property_name, + property_name_to_key, +) +from qsharp.qre._qre import _ProvenanceGraph +from qsharp.qre.models import SurfaceCode, GateBased +from qsharp.qre._architecture import _make_instruction +from qsharp.qre.instruction_ids import CCX, CCZ, LATTICE_SURGERY, T +from qsharp.qre.property_keys import DISTANCE + + +def test_isa(): + """Test ISA creation, instruction lookup, and dynamic node addition.""" + graph = _ProvenanceGraph() + isa = graph.make_isa( + [ + graph.add_instruction( + T, encoding=LOGICAL, time=1000, space=400, error_rate=1e-8 + ), + graph.add_instruction( + CCX, encoding=LOGICAL, arity=3, time=2000, space=800, error_rate=1e-10 + ), + ] + ) + + assert T in isa + assert CCX in isa + assert LATTICE_SURGERY not in isa + + t_instr = isa[T] + assert t_instr.time() == 1000 + assert t_instr.error_rate() == 1e-8 + assert t_instr.space() == 400 + + assert len(isa) == 2 + ccz_instr = isa[CCX].with_id(CCZ) + assert ccz_instr.arity == 3 + assert ccz_instr.time() == 2000 + assert ccz_instr.error_rate() == 1e-10 + assert ccz_instr.space() == 800 + + # Add another instruction to the graph and register it in the ISA + ccz_node = graph.add_instruction(ccz_instr) + isa.add_node(CCZ, ccz_node) + assert CCZ in isa + assert len(isa) == 3 + + # Adding the same instruction ID should not increase the count + isa.add_node(CCZ, ccz_node) + assert len(isa) == 3 + + +def test_instruction_properties(): + """Test getting and setting instruction properties.""" + # Test instruction with no properties + instr_no_props = _make_instruction(T, 1, 1, 1000, None, None, 1e-8, {}) + assert instr_no_props.get_property(DISTANCE) is None + assert instr_no_props.has_property(DISTANCE) is False + assert instr_no_props.get_property_or(DISTANCE, 5) == 5 + + # Test instruction with valid property (distance) + instr_with_distance = _make_instruction( + T, 1, 1, 1000, None, None, 1e-8, 
{"distance": 9} + ) + assert instr_with_distance.get_property(DISTANCE) == 9 + assert instr_with_distance.has_property(DISTANCE) is True + assert instr_with_distance.get_property_or(DISTANCE, 5) == 9 + + # Test instruction with invalid property name + with pytest.raises(ValueError, match="Unknown property 'invalid_prop'"): + _make_instruction(T, 1, 1, 1000, None, None, 1e-8, {"invalid_prop": 42}) + + +def test_instruction_constraints(): + """Test constraint property filtering and ISA.satisfies behavior.""" + # Test constraint without properties + c_no_props = constraint(T, encoding=LOGICAL) + assert c_no_props.has_property(DISTANCE) is False + + # Test constraint with valid property (distance=True) + c_with_distance = constraint(T, encoding=LOGICAL, distance=True) + assert c_with_distance.has_property(DISTANCE) is True + + # Test constraint with distance=False (should not add the property) + c_distance_false = constraint(T, encoding=LOGICAL, distance=False) + assert c_distance_false.has_property(DISTANCE) is False + + # Test constraint with invalid property name + with pytest.raises(ValueError, match="Unknown property 'invalid_prop'"): + constraint(T, encoding=LOGICAL, invalid_prop=True) + + # Test ISA.satisfies with property constraints + graph = _ProvenanceGraph() + isa_no_dist = graph.make_isa( + [ + graph.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8), + ] + ) + isa_with_dist = graph.make_isa( + [ + graph.add_instruction( + T, encoding=LOGICAL, time=1000, error_rate=1e-8, distance=9 + ), + ] + ) + + reqs_no_prop = ISARequirements(constraint(T, encoding=LOGICAL)) + reqs_with_prop = ISARequirements(constraint(T, encoding=LOGICAL, distance=True)) + + # ISA without distance property + assert isa_no_dist.satisfies(reqs_no_prop) is True + assert isa_no_dist.satisfies(reqs_with_prop) is False + + # ISA with distance property + assert isa_with_dist.satisfies(reqs_no_prop) is True + assert isa_with_dist.satisfies(reqs_with_prop) is True + + +def 
test_property_names(): + """Test property name lookup and case-insensitive key resolution.""" + assert property_name(DISTANCE) == "DISTANCE" + + # An unregistered property + UNKNOWN = 10_000 + assert property_name(UNKNOWN) is None + + # But using an existing property key with a different variable name will + # still return something + UNKNOWN = 0 + assert property_name(UNKNOWN) == "DISTANCE" + + assert property_name_to_key("DISTANCE") == DISTANCE + + # But we also allow case-insensitive lookup + assert property_name_to_key("distance") == DISTANCE + + +def test_generic_function(): + """Test generic_function wrapping for int and float return types.""" + from qsharp.qre._qre import _IntFunction, _FloatFunction + + def time(x: int) -> int: + return x * x + + time_fn = generic_function(time) + assert isinstance(time_fn, _IntFunction) + + def error_rate(x: int) -> float: + return x / 2.0 + + error_rate_fn = generic_function(error_rate) + assert isinstance(error_rate_fn, _FloatFunction) + + # Without annotations, defaults to FloatFunction + space_fn = generic_function(lambda x: 12) + assert isinstance(space_fn, _FloatFunction) + + i = _make_instruction(42, 0, None, time_fn, 12, None, error_rate_fn, {}) + assert i.space(5) == 12 + assert i.time(5) == 25 + assert i.error_rate(5) == 2.5 + + +def test_isa_from_architecture(): + """Test generating logical ISAs from an architecture and QEC code.""" + arch = GateBased(gate_time=50, measurement_time=100) + code = SurfaceCode() + ctx = arch.context() + + # Verify that the architecture satisfies the code requirements + assert ctx.isa.satisfies(SurfaceCode.required_isa()) + + # Generate logical ISAs + isas = list(code.provided_isa(ctx.isa, ctx)) + + # There is one ISA with one instructions + assert len(isas) == 1 + assert len(isas[0]) == 1 diff --git a/source/pip/tests/qre/test_models.py b/source/pip/tests/qre/test_models.py new file mode 100644 index 0000000000..46b236afb0 --- /dev/null +++ b/source/pip/tests/qre/test_models.py @@ 
-0,0 +1,1069 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from qsharp.qre import LOGICAL, PHYSICAL +from qsharp.qre.instruction_ids import ( + T, + CCZ, + CCX, + CCY, + CNOT, + CZ, + H, + MEAS_Z, + MEAS_X, + MEAS_XX, + MEAS_ZZ, + PAULI_I, + PREP_X, + PREP_Z, + LATTICE_SURGERY, + MEMORY, + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + SQRT_SQRT_Z, + SQRT_SQRT_Z_DAG, +) +from qsharp.qre.models import ( + GateBased, + Majorana, + RoundBasedFactory, + MagicUpToClifford, + Litinski19Factory, + SurfaceCode, + ThreeAux, + TwoDimensionalYokedSurfaceCode, +) +from qsharp.qre.property_keys import DISTANCE + + +# --------------------------------------------------------------------------- +# GateBased architecture tests +# --------------------------------------------------------------------------- + + +class TestGateBased: + def test_default_error_rate(self): + """Test that default error rate is 1e-4.""" + arch = GateBased(gate_time=50, measurement_time=100) + assert arch.error_rate == 1e-4 + + def test_custom_error_rate(self): + """Test that a custom error rate is accepted.""" + arch = GateBased(error_rate=1e-3, gate_time=50, measurement_time=100) + assert arch.error_rate == 1e-3 + + def test_provided_isa_contains_expected_instructions(self): + """Test that GateBased ISA contains all expected physical instructions.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + isa = ctx.isa + + for instr_id in [PAULI_I, CNOT, CZ, H, MEAS_Z, T]: + assert instr_id in isa + + def test_instruction_encodings_are_physical(self): + """Test that all GateBased ISA instructions have PHYSICAL encoding.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + isa = ctx.isa + + for instr_id in [PAULI_I, CNOT, CZ, H, MEAS_Z, T]: + assert isa[instr_id].encoding == PHYSICAL + + def test_instruction_error_rates_match(self): + """Test that all instruction error rates match the architecture 
error rate.""" + rate = 1e-3 + arch = GateBased(error_rate=rate, gate_time=50, measurement_time=100) + ctx = arch.context() + isa = ctx.isa + + for instr_id in [PAULI_I, CNOT, CZ, H, MEAS_Z, T]: + assert isa[instr_id].expect_error_rate() == rate + + def test_gate_times(self): + """Test that gate times match the configured gate and measurement times.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + isa = ctx.isa + + # Single-qubit gates: 50ns + for instr_id in [PAULI_I, H, T]: + assert isa[instr_id].expect_time() == 50 + + # Two-qubit gates: 50ns + for instr_id in [CNOT, CZ]: + assert isa[instr_id].expect_time() == 50 + + # Measurement: 100ns + assert isa[MEAS_Z].expect_time() == 100 + + def test_arities(self): + """Test that instruction arities match expected values.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + isa = ctx.isa + + assert isa[PAULI_I].arity == 1 + assert isa[CNOT].arity == 2 + assert isa[CZ].arity == 2 + assert isa[H].arity == 1 + assert isa[MEAS_Z].arity == 1 + + def test_context_creation(self): + """Test that context creation succeeds.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + assert ctx is not None + + +# --------------------------------------------------------------------------- +# Majorana architecture tests +# --------------------------------------------------------------------------- + + +class TestMajorana: + def test_default_error_rate(self): + """Test that default Majorana error rate is 1e-5.""" + arch = Majorana() + assert arch.error_rate == 1e-5 + + def test_provided_isa_contains_expected_instructions(self): + """Test that Majorana ISA contains all expected instructions.""" + arch = Majorana() + ctx = arch.context() + isa = ctx.isa + + for instr_id in [PREP_X, PREP_Z, MEAS_XX, MEAS_ZZ, MEAS_X, MEAS_Z, T]: + assert instr_id in isa + + def test_all_times_are_1us(self): + """Test that all Majorana instruction times are 1000 ns.""" 
+ arch = Majorana() + ctx = arch.context() + isa = ctx.isa + + for instr_id in [PREP_X, PREP_Z, MEAS_XX, MEAS_ZZ, MEAS_X, MEAS_Z, T]: + assert isa[instr_id].expect_time() == 1000 + + def test_clifford_error_rates_match_qubit_error(self): + """Test that Clifford error rates match the qubit error rate.""" + for rate in [1e-4, 1e-5, 1e-6]: + arch = Majorana(error_rate=rate) + ctx = arch.context() + isa = ctx.isa + + for instr_id in [PREP_X, PREP_Z, MEAS_XX, MEAS_ZZ, MEAS_X, MEAS_Z]: + assert isa[instr_id].expect_error_rate() == rate + + def test_t_error_rate_mapping(self): + """T error rate maps: 1e-4 -> 5%, 1e-5 -> 1.5%, 1e-6 -> 1%.""" + expected = {1e-4: 0.05, 1e-5: 0.015, 1e-6: 0.01} + + for qubit_rate, t_rate in expected.items(): + arch = Majorana(error_rate=qubit_rate) + ctx = arch.context() + isa = ctx.isa + assert isa[T].expect_error_rate() == t_rate + + def test_two_qubit_measurement_arities(self): + """Test that two-qubit measurement instructions have arity 2.""" + arch = Majorana() + ctx = arch.context() + isa = ctx.isa + + assert isa[MEAS_XX].arity == 2 + assert isa[MEAS_ZZ].arity == 2 + + +# --------------------------------------------------------------------------- +# SurfaceCode QEC tests +# --------------------------------------------------------------------------- + + +class TestSurfaceCode: + def test_required_isa(self): + """Test that SurfaceCode has non-None required ISA.""" + reqs = SurfaceCode.required_isa() + assert reqs is not None + + def test_default_distance(self): + """Test SurfaceCode with explicit distance parameter.""" + sc = SurfaceCode(distance=3) + assert sc.distance == 3 + + def test_provides_lattice_surgery(self): + """Test that SurfaceCode provides a logical LATTICE_SURGERY instruction.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + sc = SurfaceCode(distance=3) + + isas = list(sc.provided_isa(ctx.isa, ctx)) + assert len(isas) == 1 + + isa = isas[0] + assert LATTICE_SURGERY in isa + + ls = 
isa[LATTICE_SURGERY] + assert ls.encoding == LOGICAL + + def test_space_scales_with_distance(self): + """Space = 2*d^2 - 1 physical qubits per logical qubit.""" + arch = GateBased(gate_time=50, measurement_time=100) + + for d in [3, 5, 7, 9]: + ctx = arch.context() + sc = SurfaceCode(distance=d) + isas = list(sc.provided_isa(ctx.isa, ctx)) + ls = isas[0][LATTICE_SURGERY] + expected_space = 2 * d**2 - 1 + assert ls.expect_space(1) == expected_space + + def test_time_scales_with_distance(self): + """Time = (h_time + 4*cnot_time + meas_time) * d.""" + arch = GateBased(gate_time=50, measurement_time=100) + # h=50, cnot=50, meas=100 for GateBased + syndrome_time = 50 + 4 * 50 + 100 # = 350 + + for d in [3, 5, 7]: + ctx = arch.context() + sc = SurfaceCode(distance=d) + isas = list(sc.provided_isa(ctx.isa, ctx)) + ls = isas[0][LATTICE_SURGERY] + assert ls.expect_time(1) == syndrome_time * d + + def test_error_rate_decreases_with_distance(self): + """Test that logical error rate decreases as code distance increases.""" + arch = GateBased(gate_time=50, measurement_time=100) + + errors = [] + for d in [3, 5, 7, 9, 11]: + ctx = arch.context() + sc = SurfaceCode(distance=d) + isas = list(sc.provided_isa(ctx.isa, ctx)) + errors.append(isas[0][LATTICE_SURGERY].expect_error_rate(1)) + + # Each successive distance should have a lower error rate + for i in range(len(errors) - 1): + assert errors[i] > errors[i + 1] + + def test_enumeration_via_query(self): + """Enumerating SurfaceCode.q() should yield multiple distances.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + + count = 0 + for isa in SurfaceCode.q().enumerate(ctx): + assert LATTICE_SURGERY in isa + count += 1 + + # domain is range(3, 26, 2) = 12 distances + assert count == 12 + + def test_custom_crossing_prefactor(self): + """Test that doubling the crossing prefactor doubles the error rate.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + + sc_default = 
SurfaceCode(distance=5) + sc_custom = SurfaceCode(crossing_prefactor=0.06, distance=5) + + default_error = list(sc_default.provided_isa(ctx.isa, ctx))[0][ + LATTICE_SURGERY + ].expect_error_rate(1) + + ctx2 = arch.context() + custom_error = list(sc_custom.provided_isa(ctx2.isa, ctx2))[0][ + LATTICE_SURGERY + ].expect_error_rate(1) + + # Doubling prefactor should double the error rate + assert abs(custom_error - 2 * default_error) < 1e-20 + + def test_custom_error_correction_threshold(self): + """Test that a lower error correction threshold yields a higher logical error.""" + arch = GateBased(gate_time=50, measurement_time=100) + + ctx1 = arch.context() + sc_low_threshold = SurfaceCode(error_correction_threshold=0.005, distance=5) + error_low = list(sc_low_threshold.provided_isa(ctx1.isa, ctx1))[0][ + LATTICE_SURGERY + ].expect_error_rate(1) + + ctx2 = arch.context() + sc_high_threshold = SurfaceCode(error_correction_threshold=0.02, distance=5) + error_high = list(sc_high_threshold.provided_isa(ctx2.isa, ctx2))[0][ + LATTICE_SURGERY + ].expect_error_rate(1) + + # Lower threshold means worse ratio => higher logical error + assert error_low > error_high + + +# --------------------------------------------------------------------------- +# ThreeAux QEC tests +# --------------------------------------------------------------------------- + + +class TestThreeAux: + def test_required_isa(self): + """Test that ThreeAux has non-None required ISA.""" + reqs = ThreeAux.required_isa() + assert reqs is not None + + def test_provides_lattice_surgery(self): + """Test that ThreeAux provides a LATTICE_SURGERY instruction.""" + arch = Majorana() + ctx = arch.context() + ta = ThreeAux(distance=3) + + isas = list(ta.provided_isa(ctx.isa, ctx)) + assert len(isas) == 1 + assert LATTICE_SURGERY in isas[0] + + def test_space_formula(self): + """Space = 4*d^2 - 3 per logical qubit.""" + arch = Majorana() + + for d in [3, 5, 7]: + ctx = arch.context() + ta = ThreeAux(distance=d) + isas = 
list(ta.provided_isa(ctx.isa, ctx)) + ls = isas[0][LATTICE_SURGERY] + expected = 4 * d**2 - 3 + assert ls.expect_space(1) == expected + + def test_time_formula_double_rail(self): + """Time = gate_time * (4*d + 4) for double-rail (default).""" + arch = Majorana() + + for d in [3, 5, 7]: + ctx = arch.context() + ta = ThreeAux(distance=d, single_rail=False) + isas = list(ta.provided_isa(ctx.isa, ctx)) + ls = isas[0][LATTICE_SURGERY] + # MEAS_XX and MEAS_ZZ have time=1000 each; max is 1000 + expected_time = 1000 * (4 * d + 4) + assert ls.expect_time(1) == expected_time + + def test_time_formula_single_rail(self): + """Time = gate_time * (5*d + 4) for single-rail.""" + arch = Majorana() + + for d in [3, 5, 7]: + ctx = arch.context() + ta = ThreeAux(distance=d, single_rail=True) + isas = list(ta.provided_isa(ctx.isa, ctx)) + ls = isas[0][LATTICE_SURGERY] + expected_time = 1000 * (5 * d + 4) + assert ls.expect_time(1) == expected_time + + def test_error_rate_decreases_with_distance(self): + """Test that ThreeAux error rate decreases with increasing distance.""" + arch = Majorana() + + errors = [] + for d in [3, 5, 7, 9]: + ctx = arch.context() + ta = ThreeAux(distance=d) + isas = list(ta.provided_isa(ctx.isa, ctx)) + errors.append(isas[0][LATTICE_SURGERY].expect_error_rate(1)) + + for i in range(len(errors) - 1): + assert errors[i] > errors[i + 1] + + def test_single_rail_has_different_error_threshold(self): + """Single-rail has threshold 0.0051, double-rail 0.0066.""" + arch = Majorana() + + ctx1 = arch.context() + double = ThreeAux(distance=5, single_rail=False) + error_double = list(double.provided_isa(ctx1.isa, ctx1))[0][ + LATTICE_SURGERY + ].expect_error_rate(1) + + ctx2 = arch.context() + single = ThreeAux(distance=5, single_rail=True) + error_single = list(single.provided_isa(ctx2.isa, ctx2))[0][ + LATTICE_SURGERY + ].expect_error_rate(1) + + # Both should be positive but differ + assert error_double > 0 + assert error_single > 0 + assert error_double != 
error_single + + def test_enumeration_via_query(self): + """Test that ThreeAux.q() enumerates all distance and rail combinations.""" + arch = Majorana() + ctx = arch.context() + + count = 0 + for isa in ThreeAux.q().enumerate(ctx): + assert LATTICE_SURGERY in isa + count += 1 + + # domain: range(3, 26, 2) × {True, False} for single_rail + # = 12 distances × 2 = 24 + assert count == 24 + + +# --------------------------------------------------------------------------- +# YokedSurfaceCode tests +# --------------------------------------------------------------------------- + + +class TestYokedSurfaceCode: + def _get_lattice_surgery_isa(self, distance=5): + """Helper to get a lattice surgery ISA from SurfaceCode.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + sc = SurfaceCode(distance=distance) + isas = list(sc.provided_isa(ctx.isa, ctx)) + return isas[0], ctx + + def test_provides_memory_instruction(self): + """Test that YokedSurfaceCode provides a MEMORY instruction.""" + ls_isa, ctx = self._get_lattice_surgery_isa() + ysc = TwoDimensionalYokedSurfaceCode() + + isas = list(ysc.provided_isa(ls_isa, ctx)) + assert len(isas) == 1 + assert MEMORY in isas[0] + + def test_memory_is_logical(self): + """Test that the MEMORY instruction has LOGICAL encoding.""" + ls_isa, ctx = self._get_lattice_surgery_isa() + ysc = TwoDimensionalYokedSurfaceCode() + + isas = list(ysc.provided_isa(ls_isa, ctx)) + mem = isas[0][MEMORY] + assert mem.encoding == LOGICAL + + def test_memory_arity_is_variable(self): + """Test that MEMORY instruction has variable arity (None).""" + ls_isa, ctx = self._get_lattice_surgery_isa() + ysc = TwoDimensionalYokedSurfaceCode() + + isas = list(ysc.provided_isa(ls_isa, ctx)) + mem = isas[0][MEMORY] + # arity=None means variable arity + assert mem.arity is None + + def test_space_increases_with_arity(self): + """Test that MEMORY space increases with the number of qubits.""" + ls_isa, ctx = self._get_lattice_surgery_isa() + ysc 
= TwoDimensionalYokedSurfaceCode() + + isas = list(ysc.provided_isa(ls_isa, ctx)) + mem = isas[0][MEMORY] + + spaces = [mem.expect_space(n) for n in [4, 16, 64]] + for i in range(len(spaces) - 1): + assert spaces[i] < spaces[i + 1] + + def test_time_increases_with_arity(self): + """Test that MEMORY time increases with the number of qubits.""" + ls_isa, ctx = self._get_lattice_surgery_isa() + ysc = TwoDimensionalYokedSurfaceCode() + + isas = list(ysc.provided_isa(ls_isa, ctx)) + mem = isas[0][MEMORY] + + times = [mem.expect_time(n) for n in [4, 16, 64]] + for i in range(len(times) - 1): + assert times[i] < times[i + 1] + + def test_error_rate_increases_with_arity(self): + """Test that MEMORY error rate increases with the number of qubits.""" + ls_isa, ctx = self._get_lattice_surgery_isa() + ysc = TwoDimensionalYokedSurfaceCode() + + isas = list(ysc.provided_isa(ls_isa, ctx)) + mem = isas[0][MEMORY] + + errors = [mem.expect_error_rate(n) for n in [4, 16, 64]] + for i in range(len(errors) - 1): + assert errors[i] < errors[i + 1] + + def test_distance_property_propagated(self): + """Test that the distance property is propagated to the MEMORY instruction.""" + d = 7 + ls_isa, ctx = self._get_lattice_surgery_isa(distance=d) + ysc = TwoDimensionalYokedSurfaceCode() + + isas = list(ysc.provided_isa(ls_isa, ctx)) + mem = isas[0][MEMORY] + assert mem.get_property(DISTANCE) == d + + +# --------------------------------------------------------------------------- +# Litinski19Factory tests +# --------------------------------------------------------------------------- + + +class TestLitinski19Factory: + def test_required_isa(self): + """Test that Litinski19Factory has non-None required ISA.""" + reqs = Litinski19Factory.required_isa() + assert reqs is not None + + def test_table1_yields_t_and_ccz(self): + """GateBased (error 1e-4) matches Table 1 scenario: T & CCZ.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + factory = Litinski19Factory() + + 
isas = list(factory.provided_isa(ctx.isa, ctx)) + + # 6 T entries × 1 CCZ entry = 6 combinations + assert len(isas) == 6 + + for isa in isas: + assert T in isa + assert CCZ in isa + assert len(isa) == 2 + + def test_table1_instruction_properties(self): + """Test that Table 1 T and CCZ instructions have valid properties.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + factory = Litinski19Factory() + + for isa in factory.provided_isa(ctx.isa, ctx): + t_instr = isa[T] + ccz_instr = isa[CCZ] + + assert t_instr.arity == 1 + assert t_instr.encoding == LOGICAL + assert t_instr.expect_error_rate() > 0 + assert t_instr.expect_time() > 0 + assert t_instr.expect_space() > 0 + + assert ccz_instr.arity == 3 + assert ccz_instr.encoding == LOGICAL + assert ccz_instr.expect_error_rate() > 0 + + def test_table1_t_error_rates_are_diverse(self): + """T entries in Table 1 should span a range of error rates.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + factory = Litinski19Factory() + + isas = list(factory.provided_isa(ctx.isa, ctx)) + t_errors = [isa[T].expect_error_rate() for isa in isas] + + # Should have multiple distinct T error rates + unique_errors = set(t_errors) + assert len(unique_errors) > 1 + + # All error rates should be positive and very small + for err in t_errors: + assert 0 < err < 1e-5 + + def test_table1_1e3_clifford_yields_6_isas(self): + """GateBased with 1e-3 error matches Table 1 at 1e-3 Clifford.""" + arch = GateBased(error_rate=1e-3, gate_time=50, measurement_time=100) + ctx = arch.context() + factory = Litinski19Factory() + + isas = list(factory.provided_isa(ctx.isa, ctx)) + + # 6 T entries × 1 CCZ entry = 6 combinations + assert len(isas) == 6 + + for isa in isas: + assert T in isa + assert CCZ in isa + + def test_table2_scenario_no_ccz(self): + """Table 2 scenario: T error ~10x higher than Clifford, no CCZ.""" + from qsharp.qre._qre import _ProvenanceGraph + + arch = 
GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + + # Manually create ISA with T error rate 10x Clifford + graph = _ProvenanceGraph() + isa_input = graph.make_isa( + [ + graph.add_instruction(CNOT, arity=2, time=50, error_rate=1e-4), + graph.add_instruction(H, time=50, error_rate=1e-4), + graph.add_instruction(MEAS_Z, time=100, error_rate=1e-4), + graph.add_instruction(T, time=50, error_rate=1e-3), + ] + ) + + factory = Litinski19Factory() + isas = list(factory.provided_isa(isa_input, ctx)) + + # Table 2 at 1e-4 Clifford: 4 T entries, no CCZ + assert len(isas) == 4 + + for isa in isas: + assert T in isa + assert CCZ not in isa + + def test_no_yield_when_error_too_high(self): + """If T error > 10x Clifford, no entries match.""" + from qsharp.qre._qre import _ProvenanceGraph + + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + + graph = _ProvenanceGraph() + isa_input = graph.make_isa( + [ + graph.add_instruction(CNOT, arity=2, time=50, error_rate=1e-4), + graph.add_instruction(H, time=50, error_rate=1e-4), + graph.add_instruction(MEAS_Z, time=100, error_rate=1e-4), + graph.add_instruction(T, time=50, error_rate=0.05), + ] + ) + + factory = Litinski19Factory() + isas = list(factory.provided_isa(isa_input, ctx)) + assert len(isas) == 0 + + def test_time_based_on_syndrome_extraction(self): + """Time should be based on syndrome extraction time × cycles.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + factory = Litinski19Factory() + + # For GateBased: syndrome_extraction_time = 4*50 + 50 + 100 = 350 + syndrome_time = 4 * 50 + 50 + 100 # 350 ns + + isas = list(factory.provided_isa(ctx.isa, ctx)) + for isa in isas: + t_time = isa[T].expect_time() + assert t_time > 0 + # Time should be ceil(syndrome_time * cycles), so it must be at + # least syndrome_time (cycles >= 1) + assert t_time >= syndrome_time + + +# --------------------------------------------------------------------------- +# 
MagicUpToClifford tests +# --------------------------------------------------------------------------- + + +class TestMagicUpToClifford: + def test_required_isa_is_empty(self): + """Test that MagicUpToClifford has non-None required ISA.""" + reqs = MagicUpToClifford.required_isa() + assert reqs is not None + + def test_adds_clifford_equivalent_t_gates(self): + """Given T gate, should add SQRT_SQRT_X/Y/Z and dagger variants.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + factory = Litinski19Factory() + modifier = MagicUpToClifford() + + for isa in factory.provided_isa(ctx.isa, ctx): + modified_isas = list(modifier.provided_isa(isa, ctx)) + assert len(modified_isas) == 1 + modified_isa = modified_isas[0] + + # T family equivalents + for equiv_id in [ + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + SQRT_SQRT_Z, + SQRT_SQRT_Z_DAG, + ]: + assert equiv_id in modified_isa + + break # Just test the first one + + def test_adds_clifford_equivalent_ccz(self): + """Given CCZ, should add CCX and CCY.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + factory = Litinski19Factory() + modifier = MagicUpToClifford() + + for isa in factory.provided_isa(ctx.isa, ctx): + modified_isas = list(modifier.provided_isa(isa, ctx)) + modified_isa = modified_isas[0] + + assert CCX in modified_isa + assert CCY in modified_isa + assert CCZ in modified_isa + break + + def test_full_count_of_instructions(self): + """T gate (1) + 5 equivalents (SQRT_SQRT_*) + CCZ (1) + 2 equivalents (CCX, CCY) = 9.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + factory = Litinski19Factory() + modifier = MagicUpToClifford() + + for isa in factory.provided_isa(ctx.isa, ctx): + modified_isas = list(modifier.provided_isa(isa, ctx)) + assert len(modified_isas[0]) == 9 + break + + def test_equivalent_instructions_share_properties(self): + """Clifford equivalents should have same time, space, error 
rate.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + factory = Litinski19Factory() + modifier = MagicUpToClifford() + + for isa in factory.provided_isa(ctx.isa, ctx): + modified_isas = list(modifier.provided_isa(isa, ctx)) + modified_isa = modified_isas[0] + + t_instr = modified_isa[T] + for equiv_id in [ + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + SQRT_SQRT_Z_DAG, + ]: + equiv = modified_isa[equiv_id] + assert equiv.expect_error_rate() == t_instr.expect_error_rate() + assert equiv.expect_time() == t_instr.expect_time() + assert equiv.expect_space() == t_instr.expect_space() + + ccz_instr = modified_isa[CCZ] + for equiv_id in [CCX, CCY]: + equiv = modified_isa[equiv_id] + assert equiv.expect_error_rate() == ccz_instr.expect_error_rate() + + break + + def test_modification_count_matches_factory_output(self): + """MagicUpToClifford should produce one modified ISA per input ISA.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + factory = Litinski19Factory() + modifier = MagicUpToClifford() + + modified_count = 0 + for isa in factory.provided_isa(ctx.isa, ctx): + for _ in modifier.provided_isa(isa, ctx): + modified_count += 1 + + assert modified_count == 6 + + def test_no_family_present_passes_through(self): + """If no family member is present, ISA passes through unchanged.""" + from qsharp.qre._qre import _ProvenanceGraph + + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + modifier = MagicUpToClifford() + + # ISA with only a LATTICE_SURGERY instruction (no T or CCZ family) + from qsharp.qre import linear_function + + graph = _ProvenanceGraph() + isa_input = graph.make_isa( + [ + graph.add_instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + time=1000, + space=linear_function(17), + error_rate=linear_function(1e-10), + ) + ] + ) + + results = list(modifier.provided_isa(isa_input, ctx)) + assert len(results) == 1 + # Should only 
contain the original instruction + assert len(results[0]) == 1 + + +# --------------------------------------------------------------------------- +# Litinski19Factory + MagicUpToClifford integration (from original test) +# --------------------------------------------------------------------------- + + +def test_isa_manipulation(): + """Test Litinski19Factory and MagicUpToClifford ISA integration.""" + arch = GateBased(gate_time=50, measurement_time=100) + factory = Litinski19Factory() + modifier = MagicUpToClifford() + + ctx = arch.context() + + # Table 1 scenario: should yield ISAs with both T and CCZ instructions + isas = list(factory.provided_isa(ctx.isa, ctx)) + + # 6 T entries × 1 CCZ entry = 6 combinations + assert len(isas) == 6 + + for isa in isas: + # Each ISA should contain both T and CCZ instructions + assert T in isa + assert CCZ in isa + assert len(isa) == 2 + + t_instr = isa[T] + ccz_instr = isa[CCZ] + + # Verify instruction properties + assert t_instr.arity == 1 + assert t_instr.encoding == LOGICAL + assert t_instr.expect_error_rate() > 0 + + assert ccz_instr.arity == 3 + assert ccz_instr.encoding == LOGICAL + assert ccz_instr.expect_error_rate() > 0 + + # After MagicUpToClifford modifier + modified_count = 0 + for isa in factory.provided_isa(ctx.isa, ctx): + for modified_isa in modifier.provided_isa(isa, ctx): + modified_count += 1 + # MagicUpToClifford should add derived instructions + assert T in modified_isa + assert CCZ in modified_isa + assert CCX in modified_isa + assert len(modified_isa) == 9 + + assert modified_count == 6 + + +# --------------------------------------------------------------------------- +# RoundBasedFactory tests +# --------------------------------------------------------------------------- + + +class TestRoundBasedFactory: + def test_required_isa(self): + """Test that RoundBasedFactory has non-None required ISA.""" + reqs = RoundBasedFactory.required_isa() + assert reqs is not None + + def 
test_produces_logical_t_gates(self): + """Test that RoundBasedFactory produces logical T gates with valid properties.""" + arch = GateBased(gate_time=50, measurement_time=100) + + for isa in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()): + t = isa[T] + assert t.encoding == LOGICAL + assert t.arity == 1 + assert t.expect_error_rate() > 0 + assert t.expect_time() > 0 + assert t.expect_space() > 0 + break # Just check the first + + def test_error_rates_are_bounded(self): + """Distilled T error rates should be bounded and mostly small.""" + arch = GateBased(gate_time=50, measurement_time=100) # T error rate is 1e-4 + + errors = [] + for isa in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()): + errors.append(isa[T].expect_error_rate()) + + # All should be positive + assert all(e > 0 for e in errors) + # Most distilled error rates should be much lower than 1 + assert min(errors) < 1e-4 + # Median should be well below raw physical error + sorted_errors = sorted(errors) + median = sorted_errors[len(sorted_errors) // 2] + assert median < 1e-3 + + def test_max_produces_fewer_or_equal_results_than_sum(self): + """Using max for physical_qubit_calculation may filter differently.""" + arch = GateBased(gate_time=50, measurement_time=100) + + sum_count = sum( + 1 for _ in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()) + ) + max_count = sum( + 1 + for _ in RoundBasedFactory.q( + use_cache=False, physical_qubit_calculation=max + ).enumerate(arch.context()) + ) + + assert max_count <= sum_count + + def test_max_space_less_than_or_equal_sum_space(self): + """max-aggregated space should be <= sum-aggregated space for each.""" + arch = GateBased(gate_time=50, measurement_time=100) + + sum_spaces = sorted( + isa[T].expect_space() + for isa in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()) + ) + + max_spaces = sorted( + isa[T].expect_space() + for isa in RoundBasedFactory.q( + use_cache=False, physical_qubit_calculation=max + 
).enumerate(arch.context()) + ) + + # The minimum space with max should be <= minimum space with sum + assert max_spaces[0] <= sum_spaces[0] + + def test_with_three_aux_code_query(self): + """RoundBasedFactory with ThreeAux code query should produce results.""" + arch = Majorana() + + count = 0 + for isa in RoundBasedFactory.q( + use_cache=False, code_query=ThreeAux.q() + ).enumerate(arch.context()): + assert T in isa + assert isa[T].encoding == LOGICAL + count += 1 + + assert count > 0 + + def test_round_based_gate_based_sum(self): + """Test RoundBasedFactory aggregated totals with GateBased sum mode.""" + arch = GateBased(gate_time=50, measurement_time=100) + + total_space = 0 + total_time = 0 + total_error = 0.0 + count = 0 + + for isa in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()): + count += 1 + total_space += isa[T].expect_space() + total_time += isa[T].expect_time() + total_error += isa[T].expect_error_rate() + + assert total_space == 12_946_488 + assert total_time == 12_032_250 + assert abs(total_error - 0.001_463_030_863_973_197_8) < 1e-8 + assert count == 107 + + def test_round_based_gate_based_max(self): + """Test RoundBasedFactory aggregated totals with GateBased max mode.""" + arch = GateBased(gate_time=50, measurement_time=100) + + total_space = 0 + total_time = 0 + total_error = 0.0 + count = 0 + + for isa in RoundBasedFactory.q( + use_cache=False, physical_qubit_calculation=max + ).enumerate(arch.context()): + count += 1 + total_space += isa[T].expect_space() + total_time += isa[T].expect_time() + total_error += isa[T].expect_error_rate() + + assert total_space == 4_651_617 + assert total_time == 7_785_000 + assert abs(total_error - 0.001_463_030_863_973_197_8) < 1e-8 + assert count == 77 + + def test_round_based_msft_sum(self): + """Test RoundBasedFactory aggregated totals with Majorana sum mode.""" + arch = Majorana() + + total_space = 0 + total_time = 0 + total_error = 0.0 + count = 0 + + for isa in RoundBasedFactory.q( + 
use_cache=False, code_query=ThreeAux.q() + ).enumerate(arch.context()): + count += 1 + total_space += isa[T].expect_space() + total_time += isa[T].expect_time() + total_error += isa[T].expect_error_rate() + + assert total_space == 255_952_723 + assert total_time == 478_235_000 + assert abs(total_error - 0.000_880_967_766_732_897_4) < 1e-8 + assert count == 301 + + +# --------------------------------------------------------------------------- +# Cross-model integration tests +# --------------------------------------------------------------------------- + + +class TestCrossModelIntegration: + def test_surface_code_feeds_into_litinski(self): + """SurfaceCode -> Litinski19Factory pipeline works end to end.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + + # SurfaceCode takes gate-based physical ISA -> LATTICE_SURGERY + sc = SurfaceCode(distance=5) + sc_isas = list(sc.provided_isa(ctx.isa, ctx)) + assert len(sc_isas) == 1 + + # Litinski takes H, CNOT, MEAS_Z, T from the physical ISA + factory = Litinski19Factory() + factory_isas = list(factory.provided_isa(ctx.isa, ctx)) + assert len(factory_isas) > 0 + + def test_three_aux_feeds_into_round_based(self): + """ThreeAux -> RoundBasedFactory pipeline works.""" + arch = Majorana() + ctx = arch.context() + + count = 0 + for isa in RoundBasedFactory.q( + use_cache=False, code_query=ThreeAux.q() + ).enumerate(ctx): + assert T in isa + count += 1 + + assert count > 0 + + def test_litinski_with_magic_up_to_clifford_query(self): + """Full query chain: Litinski19Factory -> MagicUpToClifford.""" + arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + + count = 0 + for isa in MagicUpToClifford.q(source=Litinski19Factory.q()).enumerate(ctx): + assert T in isa + assert CCX in isa + assert CCY in isa + assert CCZ in isa + count += 1 + + assert count == 6 + + def test_surface_code_with_yoked_surface_code(self): + """SurfaceCode -> YokedSurfaceCode pipeline provides MEMORY.""" + 
arch = GateBased(gate_time=50, measurement_time=100) + ctx = arch.context() + + count = 0 + for isa in TwoDimensionalYokedSurfaceCode.q(source=SurfaceCode.q()).enumerate( + ctx + ): + assert MEMORY in isa + count += 1 + + # 12 distances × 1 shape heuristic = 12 + assert count == 12 + + def test_majorana_three_aux_yoked(self): + """Majorana -> ThreeAux -> YokedSurfaceCode pipeline.""" + arch = Majorana() + ctx = arch.context() + + count = 0 + for isa in TwoDimensionalYokedSurfaceCode.q(source=ThreeAux.q()).enumerate(ctx): + assert MEMORY in isa + count += 1 + + assert count > 0 diff --git a/source/qre/Cargo.toml b/source/qre/Cargo.toml new file mode 100644 index 0000000000..37f28e2ffc --- /dev/null +++ b/source/qre/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "qre" + +version.workspace = true +authors.workspace = true +homepage.workspace = true +repository.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +num-traits = { workspace = true } +rustc-hash = { workspace = true } +probability = { workspace = true } +serde = { workspace = true } +thiserror = { workspace = true } + + +[dev-dependencies] + +[lints] +workspace = true diff --git a/source/qre/src/isa.rs b/source/qre/src/isa.rs new file mode 100644 index 0000000000..cdeac80524 --- /dev/null +++ b/source/qre/src/isa.rs @@ -0,0 +1,708 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use std::{ + collections::hash_map::Entry, + fmt::Display, + ops::Add, + sync::{Arc, RwLock, RwLockReadGuard}, +}; + +use num_traits::FromPrimitive; +use rustc_hash::{FxHashMap, FxHashSet}; +use serde::{Deserialize, Serialize}; + +use crate::trace::instruction_ids::instruction_name; + +pub mod property_keys; + +mod provenance; +pub use provenance::ProvenanceGraph; + +#[cfg(test)] +mod tests; + +#[derive(Clone)] +pub struct ISA { + graph: Arc>, + nodes: FxHashMap, +} + +impl Default for ISA { + fn default() -> Self { + ISA { + graph: Arc::new(RwLock::new(ProvenanceGraph::new())), + nodes: FxHashMap::default(), + } + } +} + +impl ISA { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Creates an ISA backed by the given shared provenance graph. + #[must_use] + pub fn with_graph(graph: Arc>) -> Self { + ISA { + graph, + nodes: FxHashMap::default(), + } + } + + /// Returns a reference to the shared provenance graph. + #[must_use] + pub fn graph(&self) -> &Arc> { + &self.graph + } + + /// Adds an instruction to the provenance graph and records its node index. + /// Returns the node index in the graph. + pub fn add_instruction(&mut self, instruction: Instruction) -> usize { + let id = instruction.id; + let mut graph = self.graph.write().expect("provenance graph lock poisoned"); + let node_idx = graph.add_node(instruction, 0, &[]); + self.nodes.insert(id, node_idx); + node_idx + } + + /// Records an existing provenance graph node in this ISA. + pub fn add_node(&mut self, instruction_id: u64, node_index: usize) { + self.nodes.insert(instruction_id, node_index); + } + + /// Returns the node index for an instruction ID, if present. + #[must_use] + pub fn node_index(&self, id: &u64) -> Option { + self.nodes.get(id).copied() + } + + /// Returns a clone of the instruction with the given ID, if present. 
+ #[must_use] + pub fn get(&self, id: &u64) -> Option { + let &node_idx = self.nodes.get(id)?; + let graph = self.read_graph(); + Some(graph.instruction(node_idx).clone()) + } + + #[must_use] + pub fn contains(&self, id: &u64) -> bool { + self.nodes.contains_key(id) + } + + #[must_use] + pub fn len(&self) -> usize { + self.nodes.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.nodes.is_empty() + } + + /// Returns a read-locked view of this ISA, enabling zero-clone + /// instruction access for the lifetime of the returned guard. + #[must_use] + pub fn lock(&self) -> LockedISA<'_> { + LockedISA { + graph: self.read_graph(), + nodes: &self.nodes, + } + } + + fn read_graph(&self) -> RwLockReadGuard<'_, ProvenanceGraph> { + self.graph.read().expect("provenance graph lock poisoned") + } + + /// Returns an iterator over pairs of instruction IDs and node IDs. + pub fn node_entries(&self) -> impl Iterator { + self.nodes.iter() + } + + /// Returns all instructions as owned clones. 
+ #[must_use] + pub fn instructions(&self) -> Vec { + let graph = self.read_graph(); + self.nodes + .values() + .map(|&idx| graph.instruction(idx).clone()) + .collect() + } + + #[must_use] + pub fn satisfies(&self, requirements: &ISARequirements) -> bool { + let graph = self.read_graph(); + for constraint in requirements.constraints.values() { + let Some(&node_idx) = self.nodes.get(&constraint.id) else { + return false; + }; + + let instruction = graph.instruction(node_idx); + + if !constraint.is_satisfied_by(instruction) { + return false; + } + } + true + } +} + +impl FromIterator for ISA { + fn from_iter>(iter: T) -> Self { + let mut isa = ISA::new(); + for instruction in iter { + isa.add_instruction(instruction); + } + isa + } +} + +impl Display for ISA { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let graph = self.read_graph(); + for &node_idx in self.nodes.values() { + let instruction = graph.instruction(node_idx); + writeln!(f, "{instruction}")?; + } + Ok(()) + } +} + +impl Add for ISA { + type Output = ISA; + + fn add(self, other: ISA) -> ISA { + let mut combined = self; + if Arc::ptr_eq(&combined.graph, &other.graph) { + // Same graph: just merge node maps + for (id, node_idx) in other.nodes { + combined.nodes.insert(id, node_idx); + } + } else { + // Different graphs: copy instructions into combined's graph + let other_graph = other.read_graph(); + let mut self_graph = combined + .graph + .write() + .expect("provenance graph lock poisoned"); + for (id, node_idx) in &other.nodes { + let instruction = other_graph.instruction(*node_idx).clone(); + let new_idx = self_graph.add_node(instruction, 0, &[]); + combined.nodes.insert(*id, new_idx); + } + } + combined + } +} + +/// A read-locked view of an ISA. Holds the graph read lock for the +/// lifetime of this struct, enabling zero-clone instruction access. 
+pub struct LockedISA<'a> { + graph: RwLockReadGuard<'a, ProvenanceGraph>, + nodes: &'a FxHashMap, +} + +impl LockedISA<'_> { + /// Returns a reference to the instruction with the given ID, if present. + #[must_use] + pub fn get(&self, id: &u64) -> Option<&Instruction> { + let &node_idx = self.nodes.get(id)?; + Some(self.graph.instruction(node_idx)) + } +} + +#[derive(Default)] +pub struct ISARequirements { + constraints: FxHashMap, +} + +impl ISARequirements { + #[must_use] + pub fn new() -> Self { + ISARequirements { + constraints: FxHashMap::default(), + } + } + + pub fn add_constraint(&mut self, constraint: InstructionConstraint) { + self.constraints.insert(constraint.id, constraint); + } + + #[must_use] + pub fn len(&self) -> usize { + self.constraints.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.constraints.is_empty() + } + + pub fn entry(&mut self, id: u64) -> Entry<'_, u64, InstructionConstraint> { + self.constraints.entry(id) + } + + /// Returns all constraints as owned clones.
+ #[must_use] + pub fn constraints(&self) -> Vec { + self.constraints.values().cloned().collect() + } +} + +impl FromIterator for ISARequirements { + fn from_iter>(iter: T) -> Self { + let mut reqs = ISARequirements::new(); + for constraint in iter { + reqs.add_constraint(constraint); + } + reqs + } +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct Instruction { + id: u64, + encoding: Encoding, + metrics: Metrics, + source: usize, + properties: Option>, +} + +impl Instruction { + #[must_use] + pub fn fixed_arity( + id: u64, + encoding: Encoding, + arity: u64, + time: u64, + space: Option, + length: Option, + error_rate: f64, + ) -> Self { + let length = length.unwrap_or(arity); + let space = space.unwrap_or(length); + + Instruction { + id, + encoding, + metrics: Metrics::FixedArity { + arity, + length, + space, + time, + error_rate, + }, + source: 0, + properties: None, + } + } + + #[must_use] + pub fn variable_arity( + id: u64, + encoding: Encoding, + time_fn: VariableArityFunction, + space_fn: VariableArityFunction, + length_fn: Option>, + error_rate_fn: VariableArityFunction, + ) -> Self { + let length_fn = length_fn.unwrap_or_else(|| space_fn.clone()); + + Instruction { + id, + encoding, + metrics: Metrics::VariableArity { + length_fn, + space_fn, + time_fn, + error_rate_fn, + }, + source: 0, + properties: None, + } + } + + #[must_use] + pub fn with_id(&self, id: u64) -> Self { + let mut new_instruction = self.clone(); + // reset source for new instruction + new_instruction.source = 0; + new_instruction.id = id; + new_instruction + } + + #[must_use] + pub fn id(&self) -> u64 { + self.id + } + + #[must_use] + pub fn encoding(&self) -> Encoding { + self.encoding + } + + #[must_use] + pub fn arity(&self) -> Option { + match &self.metrics { + Metrics::FixedArity { arity, .. } => Some(*arity), + Metrics::VariableArity { .. 
} => None, + } + } + + #[must_use] + pub fn space(&self, arity: Option) -> Option { + match &self.metrics { + Metrics::FixedArity { space, .. } => Some(*space), + Metrics::VariableArity { space_fn, .. } => arity.map(|a| space_fn.evaluate(a)), + } + } + + #[must_use] + pub fn length(&self, arity: Option) -> Option { + match &self.metrics { + Metrics::FixedArity { length, .. } => Some(*length), + Metrics::VariableArity { length_fn, .. } => arity.map(|a| length_fn.evaluate(a)), + } + } + + #[must_use] + pub fn time(&self, arity: Option) -> Option { + match &self.metrics { + Metrics::FixedArity { time, .. } => Some(*time), + Metrics::VariableArity { time_fn, .. } => arity.map(|a| time_fn.evaluate(a)), + } + } + + #[must_use] + pub fn error_rate(&self, arity: Option) -> Option { + match &self.metrics { + Metrics::FixedArity { error_rate, .. } => Some(*error_rate), + Metrics::VariableArity { error_rate_fn, .. } => { + arity.map(|a| error_rate_fn.evaluate(a)) + } + } + } + + #[must_use] + pub fn expect_space(&self, arity: Option) -> u64 { + self.space(arity) + .expect("Instruction does not support variable arity") + } + + #[must_use] + pub fn expect_length(&self, arity: Option) -> u64 { + self.length(arity) + .expect("Instruction does not support variable arity") + } + + #[must_use] + pub fn expect_time(&self, arity: Option) -> u64 { + self.time(arity) + .expect("Instruction does not support variable arity") + } + + #[must_use] + pub fn expect_error_rate(&self, arity: Option) -> f64 { + self.error_rate(arity) + .expect("Instruction does not support variable arity") + } + + pub fn set_source(&mut self, provenance: usize) { + self.source = provenance; + } + + #[must_use] + pub fn source(&self) -> usize { + self.source + } + + pub fn set_property(&mut self, key: u64, value: u64) { + if let Some(ref mut properties) = self.properties { + properties.insert(key, value); + } else { + let mut properties = FxHashMap::default(); + properties.insert(key, value); + self.properties = 
Some(properties); + } + } + + #[must_use] + pub fn get_property(&self, key: &u64) -> Option { + self.properties.as_ref()?.get(key).copied() + } + + #[must_use] + pub fn has_property(&self, key: &u64) -> bool { + self.properties + .as_ref() + .is_some_and(|props| props.contains_key(key)) + } + + #[must_use] + pub fn get_property_or(&self, key: &u64, default: u64) -> u64 { + self.get_property(key).unwrap_or(default) + } +} + +impl Display for Instruction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let name = instruction_name(self.id).unwrap_or("??"); + match self.metrics { + Metrics::FixedArity { arity, .. } => { + write!(f, "{name} |{:?}| arity: {arity}", self.encoding) + } + Metrics::VariableArity { .. } => write!(f, "{name} |{:?}|", self.encoding), + } + } +} + +#[derive(Clone)] +pub struct InstructionConstraint { + id: u64, + encoding: Encoding, + arity: Option, + error_rate_fn: Option>, + properties: FxHashSet, +} + +impl InstructionConstraint { + #[must_use] + pub fn new( + id: u64, + encoding: Encoding, + arity: Option, + error_rate_fn: Option>, + ) -> Self { + InstructionConstraint { + id, + encoding, + arity, + error_rate_fn, + properties: FxHashSet::default(), + } + } + + /// Adds a property requirement to the constraint. + pub fn add_property(&mut self, property: u64) { + self.properties.insert(property); + } + + /// Checks if the constraint requires a specific property. + #[must_use] + pub fn has_property(&self, property: &u64) -> bool { + self.properties.contains(property) + } + + /// Returns the set of required properties. + #[must_use] + pub fn properties(&self) -> &FxHashSet { + &self.properties + } + + /// Returns the instruction ID this constraint applies to. + #[must_use] + pub fn id(&self) -> u64 { + self.id + } + + /// Returns the required encoding for this constraint. 
+ #[must_use] + pub fn encoding(&self) -> Encoding { + self.encoding + } + + #[must_use] + pub fn arity(&self) -> Option { + self.arity + } + + pub fn set_arity(&mut self, arity: Option) { + self.arity = arity; + } + + #[must_use] + pub fn error_rate(&self) -> Option<&ConstraintBound> { + self.error_rate_fn.as_ref() + } + + pub fn set_error_rate(&mut self, error_rate_fn: Option>) { + self.error_rate_fn = error_rate_fn; + } + + /// Checks whether a given instruction satisfies this constraint. + #[must_use] + pub fn is_satisfied_by(&self, instruction: &Instruction) -> bool { + if instruction.encoding != self.encoding { + return false; + } + + match &instruction.metrics { + Metrics::FixedArity { + arity, error_rate, .. + } => { + // Constraint requires variable arity but instruction is fixed + let Some(constraint_arity) = self.arity else { + return false; + }; + if *arity != constraint_arity { + return false; + } + if let Some(ref bound) = self.error_rate_fn + && !bound.evaluate(error_rate) + { + return false; + } + } + Metrics::VariableArity { error_rate_fn, .. 
} => { + if let (Some(constraint_arity), Some(bound)) = (self.arity, &self.error_rate_fn) + && !bound.evaluate(&error_rate_fn.evaluate(constraint_arity)) + { + return false; + } + } + } + + for prop in &self.properties { + if !instruction.has_property(prop) { + return false; + } + } + + true + } +} + +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum Encoding { + #[default] + Physical, + Logical, +} + +#[derive(Clone, Serialize, Deserialize)] +pub enum Metrics { + FixedArity { + arity: u64, + length: u64, + space: u64, + time: u64, + error_rate: f64, + }, + VariableArity { + length_fn: VariableArityFunction, + space_fn: VariableArityFunction, + time_fn: VariableArityFunction, + error_rate_fn: VariableArityFunction, + }, +} + +#[derive(Clone, Serialize, Deserialize)] +pub enum VariableArityFunction { + Constant { + value: T, + }, + Linear { + slope: T, + }, + BlockLinear { + block_size: u64, + slope: T, + offset: T, + }, + #[serde(skip)] + Generic { + func: Arc T + Send + Sync>, + }, +} + +impl + std::ops::Mul + Copy + FromPrimitive> + VariableArityFunction +{ + pub fn constant(value: T) -> Self { + VariableArityFunction::Constant { value } + } + + pub fn linear(slope: T) -> Self { + VariableArityFunction::Linear { slope } + } + + pub fn block_linear(block_size: u64, slope: T, offset: T) -> Self { + VariableArityFunction::BlockLinear { + block_size, + slope, + offset, + } + } + + pub fn generic(func: impl Fn(u64) -> T + Send + Sync + 'static) -> Self { + VariableArityFunction::Generic { + func: Arc::new(func), + } + } + + pub fn generic_from_arc(func: Arc T + Send + Sync>) -> Self { + VariableArityFunction::Generic { func } + } + + pub fn evaluate(&self, arity: u64) -> T { + match self { + VariableArityFunction::Constant { value } => *value, + VariableArityFunction::Linear { slope } => { + *slope * T::from_u64(arity).expect("Failed to convert u64 to target type") + } + VariableArityFunction::BlockLinear { + block_size, + 
slope, + offset, + } => { + let blocks = arity.div_ceil(*block_size); + *slope * T::from_u64(blocks).expect("Failed to convert u64 to target type") + + *offset + } + VariableArityFunction::Generic { func } => func(arity), + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum ConstraintBound { + LessThan(T), + LessEqual(T), + Equal(T), + GreaterThan(T), + GreaterEqual(T), +} + +impl ConstraintBound { + pub fn less_than(value: T) -> Self { + ConstraintBound::LessThan(value) + } + + pub fn less_equal(value: T) -> Self { + ConstraintBound::LessEqual(value) + } + + pub fn equal(value: T) -> Self { + ConstraintBound::Equal(value) + } + + pub fn greater_than(value: T) -> Self { + ConstraintBound::GreaterThan(value) + } + + pub fn greater_equal(value: T) -> Self { + ConstraintBound::GreaterEqual(value) + } + + pub fn evaluate(&self, other: &T) -> bool { + match self { + ConstraintBound::LessThan(v) => other < v, + ConstraintBound::LessEqual(v) => other <= v, + ConstraintBound::Equal(v) => other == v, + ConstraintBound::GreaterThan(v) => other > v, + ConstraintBound::GreaterEqual(v) => other >= v, + } + } +} diff --git a/source/qre/src/isa/property_keys.rs b/source/qre/src/isa/property_keys.rs new file mode 100644 index 0000000000..16158e31db --- /dev/null +++ b/source/qre/src/isa/property_keys.rs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +// NOTE: To add a new property key: +// 1. Add the name to the `define_properties!` macro below (values are auto-assigned) +// 2. Add it to `add_property_keys` in qre.rs +// 3. Add it to property_keys.pyi +// +// The `property_name_to_key` function is auto-generated from the entries. + +/// Property keys for instruction properties. These are used to query properties of instructions in the ISA. +macro_rules! define_properties { + // Internal rule: accumulator-based counting to auto-assign incrementing u64 values. 
+ (@step $counter:expr, $name:ident, $($rest:ident),* $(,)?) => { + pub const $name: u64 = $counter; + define_properties!(@step $counter + 1, $($rest),*); + }; + (@step $counter:expr, $name:ident $(,)?) => { + pub const $name: u64 = $counter; + }; + // Entry point + ( $($name:ident),* $(,)? ) => { + define_properties!(@step 0, $($name),*); + + /// Property name → integer key mapping + pub fn property_name_to_key(name: &str) -> Option { + match name { + $( + stringify!($name) => Some($name), + )* + _ => None + } + } + + /// Integer key → property name mapping + #[must_use] + pub fn property_name(id: u64) -> Option<&'static str> { + match id { + $( + $name => Some(stringify!($name)), + )* + _ => None, + } + } + }; +} + +define_properties! { + DISTANCE, + SURFACE_CODE_ONE_QUBIT_TIME_FACTOR, + SURFACE_CODE_TWO_QUBIT_TIME_FACTOR, + ACCELERATION, + NUM_TS_PER_ROTATION, + RUNTIME_SINGLE_SHOT, + EXPECTED_SHOTS, + EVALUATION_TIME, + PHYSICAL_COMPUTE_QUBITS, + PHYSICAL_FACTORY_QUBITS, + PHYSICAL_MEMORY_QUBITS, + MOLECULE, + LOGICAL_COMPUTE_QUBITS, + LOGICAL_MEMORY_QUBITS, + ALGORITHM_COMPUTE_QUBITS, + ALGORITHM_MEMORY_QUBITS, + NAME, +} diff --git a/source/qre/src/isa/provenance.rs b/source/qre/src/isa/provenance.rs new file mode 100644 index 0000000000..5f68ca180e --- /dev/null +++ b/source/qre/src/isa/provenance.rs @@ -0,0 +1,346 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::sync::{Arc, RwLock}; + +use rustc_hash::{FxHashMap, FxHashSet}; + +use crate::{Encoding, ISA, ISARequirements, Instruction, ParetoFrontier3D}; + +pub struct ProvenanceGraph { + nodes: Vec, + // A consecutive list of child node indices for each node, where the + // children of node i are located at children[offset..offset+num_children] + // in the children vector. + children: Vec, + // Per-instruction-ID index of Pareto-optimal node indices. + // Built by `build_pareto_index()` after all nodes have been added. 
+ pareto_index: FxHashMap>, +} + +impl Default for ProvenanceGraph { + fn default() -> Self { + // Initialize with a dummy node at index 0 to simplify indexing logic + // (so that 0 can be used as a "null" provenance) + let empty = ProvenanceNode::default(); + ProvenanceGraph { + nodes: vec![empty], + children: Vec::new(), + pareto_index: FxHashMap::default(), + } + } +} + +/// Thin wrapper for 3D Pareto comparison of instructions at arity 1. +struct InstructionParetoItem { + node_index: usize, + space: u64, + time: u64, + error: f64, +} + +impl crate::ParetoItem3D for InstructionParetoItem { + type Objective1 = u64; + type Objective2 = u64; + type Objective3 = f64; + + fn objective1(&self) -> u64 { + self.space + } + fn objective2(&self) -> u64 { + self.time + } + fn objective3(&self) -> f64 { + self.error + } +} + +impl ProvenanceGraph { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn add_node( + &mut self, + mut instruction: Instruction, + transform_id: u64, + children: &[usize], + ) -> usize { + let node_index = self.nodes.len(); + instruction.source = node_index; + let offset = self.children.len(); + let num_children = children.len(); + self.children.extend_from_slice(children); + self.nodes.push(ProvenanceNode { + instruction, + transform_id, + offset, + num_children, + }); + node_index + } + + #[must_use] + pub fn instruction(&self, node_index: usize) -> &Instruction { + &self.nodes[node_index].instruction + } + + #[must_use] + pub fn transform_id(&self, node_index: usize) -> u64 { + self.nodes[node_index].transform_id + } + + #[must_use] + pub fn children(&self, node_index: usize) -> &[usize] { + let node = &self.nodes[node_index]; + &self.children[node.offset..node.offset + node.num_children] + } + + #[must_use] + pub fn num_nodes(&self) -> usize { + self.nodes.len() - 1 + } + + #[must_use] + pub fn num_edges(&self) -> usize { + self.children.len() + } + + /// Builds the per-instruction-ID Pareto index. 
+    ///
+    /// For each instruction ID in the graph, collects all nodes and retains
+    /// only the Pareto-optimal subset with respect to (space, time, `error_rate`)
+    /// evaluated at arity 1. Instructions with different encodings or
+    /// properties are never in competition.
+    ///
+    /// Must be called after all nodes have been added.
+    pub fn build_pareto_index(&mut self) {
+        // Group node indices by (instruction_id, encoding, properties)
+        let mut groups: FxHashMap<u64, Vec<usize>> = FxHashMap::default();
+        for idx in 1..self.nodes.len() {
+            let instr = &self.nodes[idx].instruction;
+            groups.entry(instr.id).or_default().push(idx);
+        }
+
+        let mut pareto_index = FxHashMap::default();
+        for (id, node_indices) in groups {
+            // Sub-partition by encoding and property keys to avoid comparing
+            // incompatible instructions (Risk R2 mitigation)
+            #[allow(clippy::type_complexity)]
+            let mut sub_groups: FxHashMap<(Encoding, Vec<(u64, u64)>), Vec<usize>> =
+                FxHashMap::default();
+            for &idx in &node_indices {
+                let instr = &self.nodes[idx].instruction;
+                let mut prop_vec: Vec<(u64, u64)> = instr
+                    .properties
+                    .as_ref()
+                    .map(|p| {
+                        let mut v: Vec<_> = p.iter().map(|(&k, &v)| (k, v)).collect();
+                        v.sort_unstable();
+                        v
+                    })
+                    .unwrap_or_default();
+                prop_vec.sort_unstable();
+                sub_groups
+                    .entry((instr.encoding, prop_vec))
+                    .or_default()
+                    .push(idx);
+            }
+
+            let mut pareto_nodes = Vec::new();
+            for (_key, indices) in sub_groups {
+                let items: Vec<InstructionParetoItem> = indices
+                    .iter()
+                    .filter_map(|&idx| {
+                        let instr = &self.nodes[idx].instruction;
+                        let space = instr.space(Some(1))?;
+                        let time = instr.time(Some(1))?;
+                        let error = instr.error_rate(Some(1))?;
+                        Some(InstructionParetoItem {
+                            node_index: idx,
+                            space,
+                            time,
+                            error,
+                        })
+                    })
+                    .collect();
+
+                let frontier: ParetoFrontier3D<InstructionParetoItem> =
+                    items.into_iter().collect();
+                pareto_nodes.extend(frontier.into_iter().map(|item| item.node_index));
+            }
+
+            pareto_index.insert(id, pareto_nodes);
+        }
+
+        self.pareto_index = pareto_index;
+    }
+
+    /// Returns the
Pareto-optimal node indices for a given instruction ID.
+    #[must_use]
+    pub fn pareto_nodes(&self, instruction_id: u64) -> Option<&[usize]> {
+        self.pareto_index.get(&instruction_id).map(Vec::as_slice)
+    }
+
+    /// Returns all instruction IDs that have Pareto-optimal entries.
+    #[must_use]
+    pub fn pareto_instruction_ids(&self) -> Vec<u64> {
+        self.pareto_index.keys().copied().collect()
+    }
+
+    /// Returns the raw node count (including the sentinel at index 0).
+    #[must_use]
+    pub fn raw_node_count(&self) -> usize {
+        self.nodes.len()
+    }
+
+    /// Returns the total number of ISAs that can be formed from Pareto-optimal
+    /// nodes.
+    ///
+    /// Requires [`build_pareto_index`](Self::build_pareto_index) to have
+    /// been called.
+    #[must_use]
+    pub fn total_isa_count(&self) -> usize {
+        self.pareto_index.values().map(Vec::len).product()
+    }
+
+    /// Returns ISAs formed from Pareto-optimal nodes that satisfy the given
+    /// requirements.
+    ///
+    /// For each constraint, selects matching Pareto-optimal nodes. Produces
+    /// the Cartesian product of per-constraint match sets, each augmented
+    /// with one representative node per unconstrained instruction ID (so
+    /// that returned ISAs contain entries for all instruction types in the
+    /// graph).
+    ///
+    /// When `min_node_idx` is `Some(n)`, only Pareto nodes with index ≥ n
+    /// are considered for constrained groups. Unconstrained "extra" nodes
+    /// are not filtered since they serve only as default placeholders.
+    ///
+    /// Requires [`build_pareto_index`](Self::build_pareto_index) to have
+    /// been called.
+    #[must_use]
+    pub fn query_satisfying(
+        &self,
+        graph_arc: &Arc<RwLock<ProvenanceGraph>>,
+        requirements: &ISARequirements,
+        min_node_idx: Option<usize>,
+    ) -> Vec<ISA> {
+        let min_idx = min_node_idx.unwrap_or(0);
+
+        let mut constrained_groups: Vec<Vec<(u64, usize)>> = Vec::new();
+        let mut constrained_ids: FxHashSet<u64> = FxHashSet::default();
+
+        for constraint in requirements.constraints.values() {
+            constrained_ids.insert(constraint.id());
+
+            // When a node range is specified, scan ALL nodes in the range
+            // instead of using the global Pareto index. The global index
+            // may have pruned nodes from this range as duplicates of
+            // earlier, equivalent nodes outside the range.
+            let matching: Vec<(u64, usize)> = if min_idx > 0 {
+                let mut m: Vec<(u64, usize)> = (min_idx..self.nodes.len())
+                    .filter(|&node_idx| {
+                        let instr = &self.nodes[node_idx].instruction;
+                        instr.id == constraint.id() && constraint.is_satisfied_by(instr)
+                    })
+                    .map(|node_idx| (constraint.id(), node_idx))
+                    .collect();
+
+                // Fall back to the full graph for passthrough instructions
+                // that the source does not modify (e.g. architecture base
+                // gates that a wrapper leaves unchanged).
+                if m.is_empty() {
+                    m = (1..min_idx)
+                        .filter(|&node_idx| {
+                            let instr = &self.nodes[node_idx].instruction;
+                            instr.id == constraint.id() && constraint.is_satisfied_by(instr)
+                        })
+                        .map(|node_idx| (constraint.id(), node_idx))
+                        .collect();
+                }
+                m
+            } else {
+                let Some(pareto) = self.pareto_index.get(&constraint.id()) else {
+                    return Vec::new();
+                };
+                pareto
+                    .iter()
+                    .filter(|&&node_idx| constraint.is_satisfied_by(self.instruction(node_idx)))
+                    .map(|&node_idx| (constraint.id(), node_idx))
+                    .collect()
+            };
+
+            if matching.is_empty() {
+                return Vec::new();
+            }
+            constrained_groups.push(matching);
+        }
+
+        // One representative node per unconstrained instruction ID.
+        // When a Pareto index is available, use it; otherwise scan all
+        // nodes (this path is used during populate() before the index
+        // is built).
+        let extra_nodes: Vec<(u64, usize)> = if self.pareto_index.is_empty() {
+            let mut seen: FxHashMap<u64, usize> = FxHashMap::default();
+            for idx in 1..self.nodes.len() {
+                let id = self.nodes[idx].instruction.id;
+                if !constrained_ids.contains(&id) {
+                    seen.entry(id).or_insert(idx);
+                }
+            }
+            seen.into_iter().collect()
+        } else {
+            self.pareto_index
+                .iter()
+                .filter(|(id, _)| !constrained_ids.contains(id))
+                .filter_map(|(&id, nodes)| nodes.first().map(|&n| (id, n)))
+                .collect()
+        };
+
+        // Cartesian product of constrained groups
+        let mut combinations: Vec<Vec<(u64, usize)>> = vec![Vec::new()];
+        for group in &constrained_groups {
+            let mut next = Vec::with_capacity(combinations.len() * group.len());
+            for combo in &combinations {
+                for &item in group {
+                    let mut extended = combo.clone();
+                    extended.push(item);
+                    next.push(extended);
+                }
+            }
+            combinations = next;
+        }
+
+        // Build ISAs from selections
+        combinations
+            .into_iter()
+            .map(|mut combo| {
+                combo.extend(extra_nodes.iter().copied());
+                let mut isa = ISA::with_graph(Arc::clone(graph_arc));
+                for (id, node_idx) in combo {
+                    isa.add_node(id, node_idx);
+                }
+                isa
+            })
+            .collect()
+    }
+}
+
+struct ProvenanceNode {
+    instruction: Instruction,
+    transform_id: u64,
+    offset: usize,
+    num_children: usize,
+}
+
+impl Default for ProvenanceNode {
+    fn default() -> Self {
+        ProvenanceNode {
+            instruction: Instruction::fixed_arity(0, Encoding::Physical, 0, 0, None, None, 0.0),
+            transform_id: 0,
+            offset: 0,
+            num_children: 0,
+        }
+    }
+}
diff --git a/source/qre/src/isa/tests.rs b/source/qre/src/isa/tests.rs
new file mode 100644
index 0000000000..b847a6b049
--- /dev/null
+++ b/source/qre/src/isa/tests.rs
@@ -0,0 +1,207 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+ +use super::*; + +#[test] +fn test_fixed_arity_instruction() { + let instr = Instruction::fixed_arity(1, Encoding::Physical, 2, 100, Some(10), Some(5), 0.01); + + assert_eq!(instr.id(), 1); + assert_eq!(instr.encoding(), Encoding::Physical); + assert_eq!(instr.arity(), Some(2)); + assert_eq!(instr.time(None), Some(100)); + assert_eq!(instr.space(None), Some(10)); + assert_eq!(instr.length(None), Some(5)); + assert_eq!(instr.error_rate(None), Some(0.01)); +} + +#[test] +fn test_variable_arity_instruction() { + let time_fn = VariableArityFunction::linear(10); + let space_fn = VariableArityFunction::constant(5); + let error_rate_fn = VariableArityFunction::constant(0.001); + + let instr = + Instruction::variable_arity(2, Encoding::Logical, time_fn, space_fn, None, error_rate_fn); + + assert_eq!(instr.id(), 2); + assert_eq!(instr.encoding(), Encoding::Logical); + assert_eq!(instr.arity(), None); + + // Check evaluation at specific arity + assert_eq!(instr.time(Some(3)), Some(30)); // 3 * 10 + assert_eq!(instr.space(Some(3)), Some(5)); + assert_eq!(instr.length(Some(3)), Some(5)); // Defaulted to space_fn + assert_eq!(instr.error_rate(Some(3)), Some(0.001)); + + // Check None arity returns None for variable metrics + assert_eq!(instr.time(None), None); +} + +#[test] +fn test_isa_satisfies() { + let mut isa = ISA::new(); + let instr1 = Instruction::fixed_arity(1, Encoding::Physical, 2, 100, None, None, 0.01); + isa.add_instruction(instr1); + + let mut reqs = ISARequirements::new(); + + // Test exact match + reqs.add_constraint(InstructionConstraint::new( + 1, + Encoding::Physical, + Some(2), + Some(ConstraintBound::less_than(0.02)), + )); + assert!(isa.satisfies(&reqs)); + + // Test failing error rate + let mut reqs_fail = ISARequirements::new(); + reqs_fail.add_constraint(InstructionConstraint::new( + 1, + Encoding::Physical, + Some(2), + Some(ConstraintBound::less_than(0.005)), + )); + assert!(!isa.satisfies(&reqs_fail)); + + // Test failing arity + let mut 
reqs_fail_arity = ISARequirements::new(); + reqs_fail_arity.add_constraint(InstructionConstraint::new( + 1, + Encoding::Physical, + Some(3), + Some(ConstraintBound::less_than(0.02)), + )); + assert!(!isa.satisfies(&reqs_fail_arity)); + + // Test failing encoding + let mut reqs_fail_enc = ISARequirements::new(); + reqs_fail_enc.add_constraint(InstructionConstraint::new( + 1, + Encoding::Logical, + Some(2), + None, + )); + assert!(!isa.satisfies(&reqs_fail_enc)); + + // Test missing instruction + let mut reqs_missing = ISARequirements::new(); + reqs_missing.add_constraint(InstructionConstraint::new( + 99, + Encoding::Physical, + None, + None, + )); + assert!(!isa.satisfies(&reqs_missing)); +} + +#[test] +fn test_variable_arity_satisfies() { + let mut isa = ISA::new(); + let time_fn = VariableArityFunction::linear(10); + let space_fn = VariableArityFunction::constant(5); + let error_rate_fn = VariableArityFunction::linear(0.001); // 0.001 * arity + + let instr = Instruction::variable_arity( + 10, + Encoding::Logical, + time_fn, + space_fn, + None, + error_rate_fn, + ); + isa.add_instruction(instr); + + let mut reqs = ISARequirements::new(); + // Check for arity 5, error rate should be 0.005 + reqs.add_constraint(InstructionConstraint::new( + 10, + Encoding::Logical, + Some(5), + Some(ConstraintBound::less_than(0.01)), + )); + assert!(isa.satisfies(&reqs)); // 0.005 < 0.01 + + let mut reqs_fail = ISARequirements::new(); + // Check for arity 20, error rate should be 0.02 + reqs_fail.add_constraint(InstructionConstraint::new( + 10, + Encoding::Logical, + Some(20), + Some(ConstraintBound::less_than(0.01)), + )); + assert!(!isa.satisfies(&reqs_fail)); // 0.02 not < 0.01 +} + +#[test] +fn test_variable_arity_function() { + let linear_fn = VariableArityFunction::linear(10); + assert_eq!(linear_fn.evaluate(3), 30); + assert_eq!(linear_fn.evaluate(0), 0); + + let constant_fn = VariableArityFunction::constant(5); + assert_eq!(constant_fn.evaluate(3), 5); + 
assert_eq!(constant_fn.evaluate(0), 5); + + // Test with a custom function + let custom_fn = VariableArityFunction::generic(|arity| arity * arity); // Quadratic + assert_eq!(custom_fn.evaluate(3), 9); + assert_eq!(custom_fn.evaluate(4), 16); +} + +#[test] +fn test_instruction_display_known_id() { + use crate::trace::instruction_ids::H; + + let instr = Instruction::fixed_arity(H, Encoding::Physical, 1, 100, None, None, 0.01); + let display = format!("{instr}"); + + assert!(display.contains('H'), "Expected 'H' in '{display}'"); + assert!( + display.contains("arity: 1"), + "Expected 'arity: 1' in '{display}'" + ); +} + +#[test] +fn test_instruction_display_unknown_id() { + let unknown_id = 0x9999; + let instr = Instruction::fixed_arity(unknown_id, Encoding::Logical, 2, 50, None, None, 0.001); + let display = format!("{instr}"); + + assert!( + display.contains("??"), + "Expected '??' for unknown ID in '{display}'" + ); +} + +#[test] +fn test_instruction_display_variable_arity() { + use crate::trace::instruction_ids::MULTI_PAULI_MEAS; + + let time_fn = VariableArityFunction::linear(10); + let space_fn = VariableArityFunction::constant(5); + let error_rate_fn = VariableArityFunction::constant(0.001); + + let instr = Instruction::variable_arity( + MULTI_PAULI_MEAS, + Encoding::Logical, + time_fn, + space_fn, + None, + error_rate_fn, + ); + let display = format!("{instr}"); + + assert!( + display.contains("MULTI_PAULI_MEAS"), + "Expected 'MULTI_PAULI_MEAS' in '{display}'" + ); + // Variable arity instructions don't show arity + assert!( + !display.contains("arity:"), + "Variable arity should not show arity in '{display}'" + ); +} diff --git a/source/qre/src/lib.rs b/source/qre/src/lib.rs new file mode 100644 index 0000000000..42db079461 --- /dev/null +++ b/source/qre/src/lib.rs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+
+use thiserror::Error;
+
+mod isa;
+mod pareto;
+pub use pareto::{
+    ParetoFrontier as ParetoFrontier2D, ParetoFrontier3D, ParetoItem2D, ParetoItem3D,
+};
+mod result;
+pub use isa::property_keys;
+pub use isa::property_keys::{property_name, property_name_to_key};
+pub use isa::{
+    ConstraintBound, Encoding, ISA, ISARequirements, Instruction, InstructionConstraint, LockedISA,
+    ProvenanceGraph, VariableArityFunction,
+};
+pub use result::{EstimationCollection, EstimationResult, FactoryResult, ResultSummary};
+mod trace;
+pub use trace::instruction_ids;
+pub use trace::instruction_ids::instruction_name;
+pub use trace::{
+    Block, LatticeSurgery, PSSPC, Property, Trace, TraceTransform, estimate_parallel,
+    estimate_with_graph,
+};
+mod utils;
+pub use utils::{binom_ppf, float_from_bits, float_to_bits};
+
+/// A resource estimation error.
+#[derive(Clone, Debug, Error, PartialEq)]
+pub enum Error {
+    /// The resource estimation exceeded the maximum allowed error.
+    #[error("resource estimation exceeded the maximum allowed error: {actual_error} > {max_error}")]
+    MaximumErrorExceeded { actual_error: f64, max_error: f64 },
+    /// Missing instruction in the ISA.
+    #[error("requested instruction {0} not present in ISA")]
+    InstructionNotFound(u64),
+    /// Cannot extract space from instruction.
+    #[error("cannot extract space from instruction {0} for fixed arity")]
+    CannotExtractSpace(u64),
+    /// Cannot extract time from instruction.
+    #[error("cannot extract time from instruction {0} for fixed arity")]
+    CannotExtractTime(u64),
+    /// Cannot extract error rate from instruction.
+ #[error("cannot extract error rate from instruction {0} for fixed arity")] + CannotExtractErrorRate(u64), + /// Factory time exceeds algorithm runtime + #[error( + "factory instruction {id} time {factory_time} exceeds algorithm runtime {algorithm_runtime}" + )] + FactoryTimeExceedsAlgorithmRuntime { + id: u64, + factory_time: u64, + algorithm_runtime: u64, + }, + /// Unsupported instruction in trace transformation + #[error("unsupported instruction {id} in trace transformation '{name}'")] + UnsupportedInstruction { id: u64, name: &'static str }, +} diff --git a/source/qre/src/pareto.rs b/source/qre/src/pareto.rs new file mode 100644 index 0000000000..414faaad39 --- /dev/null +++ b/source/qre/src/pareto.rs @@ -0,0 +1,272 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use serde::{Deserialize, Serialize}; + +#[cfg(test)] +mod tests; + +pub trait ParetoItem2D { + type Objective1: PartialOrd + Copy; + type Objective2: PartialOrd + Copy; + + fn objective1(&self) -> Self::Objective1; + fn objective2(&self) -> Self::Objective2; +} + +pub trait ParetoItem3D { + type Objective1: PartialOrd + Copy; + type Objective2: PartialOrd + Copy; + type Objective3: PartialOrd + Copy; + + fn objective1(&self) -> Self::Objective1; + fn objective2(&self) -> Self::Objective2; + fn objective3(&self) -> Self::Objective3; +} + +/// A Pareto frontier for 2-dimensional objectives. +/// +/// The implementation maintains the frontier sorted by the first objective. +/// This allows for efficient updates based on the geometric property that +/// a point is dominated if and only if it is dominated by its immediate +/// predecessor in the sorted list (when sorted by the first objective). +/// +/// This approach is related to the algorithms described in: +/// H. T. Kung, F. Luccio, and F. P. Preparata, "On Finding the Maxima of a Set of Vectors," +/// Journal of the ACM, vol. 22, no. 4, pp. 469-476, 1975. 
+#[derive(Default, Debug, Clone, Serialize, Deserialize)]
+pub struct ParetoFrontier<I>(pub Vec<I>);
+
+impl<I: ParetoItem2D> ParetoFrontier<I> {
+    #[must_use]
+    pub fn new() -> Self {
+        Self(Vec::new())
+    }
+
+    pub fn insert(&mut self, p: I) {
+        // If any objective is incomparable (e.g. NaN), we silently ignore the item
+        // to maintain the frontier's sorting invariant.
+        if p.objective1().partial_cmp(&p.objective1()).is_none()
+            || p.objective2().partial_cmp(&p.objective2()).is_none()
+        {
+            return;
+        }
+
+        let frontier = &mut self.0;
+        let search = frontier.binary_search_by(|q| {
+            q.objective1()
+                .partial_cmp(&p.objective1())
+                .expect("objectives must be comparable")
+        });
+
+        let pos = match search {
+            Ok(i) => {
+                if frontier[i].objective2() <= p.objective2() {
+                    return;
+                }
+                i
+            }
+            Err(i) => {
+                if i > 0 {
+                    let left = &frontier[i - 1];
+                    if left.objective2() <= p.objective2() {
+                        return;
+                    }
+                }
+                i
+            }
+        };
+        let i = pos;
+        while i < frontier.len() && frontier[i].objective2() >= p.objective2() {
+            frontier.remove(i);
+        }
+        frontier.insert(pos, p);
+    }
+
+    pub fn iter(&self) -> std::slice::Iter<'_, I> {
+        self.0.iter()
+    }
+
+    pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, I> {
+        self.0.iter_mut()
+    }
+
+    #[must_use]
+    pub fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    #[must_use]
+    pub fn is_empty(&self) -> bool {
+        self.0.is_empty()
+    }
+}
+
+impl<I: ParetoItem2D> Extend<I> for ParetoFrontier<I> {
+    fn extend<T: IntoIterator<Item = I>>(&mut self, iter: T) {
+        for p in iter {
+            self.insert(p);
+        }
+    }
+}
+
+impl<I: ParetoItem2D> FromIterator<I> for ParetoFrontier<I> {
+    fn from_iter<T: IntoIterator<Item = I>>(iter: T) -> Self {
+        let mut frontier = Self::new();
+        frontier.extend(iter);
+        frontier
+    }
+}
+
+impl<I: ParetoItem2D> IntoIterator for ParetoFrontier<I> {
+    type Item = I;
+    type IntoIter = std::vec::IntoIter<I>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.into_iter()
+    }
+}
+
+impl<'a, I: ParetoItem2D> IntoIterator for &'a ParetoFrontier<I> {
+    type Item = &'a I;
+    type IntoIter = std::slice::Iter<'a, I>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.iter()
+    }
+}
+
+impl<'a, I: ParetoItem2D> IntoIterator for &'a mut ParetoFrontier<I> {
+    type Item = &'a mut I;
+    type IntoIter = std::slice::IterMut<'a, I>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.iter_mut()
+    }
+}
+
+/// A Pareto frontier for 3-dimensional objectives.
+///
+/// The implementation maintains the frontier sorted lexicographically.
+/// Unlike the 2D case where dominance checks are O(1) given the sorted order,
+/// the 3D case requires checking the prefix or suffix to establish dominance,
+/// though maintaining sorted order significantly reduces the search space.
+///
+/// The theoretical O(N log N) bound for constructing the 3D frontier is established in:
+/// H. T. Kung, F. Luccio, and F. P. Preparata, "On Finding the Maxima of a Set of Vectors,"
+/// Journal of the ACM, vol. 22, no. 4, pp. 469-476, 1975.
+#[derive(Default, Debug, Clone, Serialize, Deserialize)]
+#[serde(transparent)]
+pub struct ParetoFrontier3D<I>(pub Vec<I>);
+
+impl<I: ParetoItem3D> ParetoFrontier3D<I> {
+    #[must_use]
+    pub fn new() -> Self {
+        Self(Vec::new())
+    }
+
+    pub fn insert(&mut self, p: I) {
+        // If any objective is incomparable (e.g. NaN), we silently ignore the item.
+        if p.objective1().partial_cmp(&p.objective1()).is_none()
+            || p.objective2().partial_cmp(&p.objective2()).is_none()
+            || p.objective3().partial_cmp(&p.objective3()).is_none()
+        {
+            return;
+        }
+
+        let frontier = &mut self.0;
+
+        // Use lexicographical sort covering all objectives.
+        // This makes the binary search deterministic and ensures specific properties for prefix/suffix.
+        let Err(pos) = frontier.binary_search_by(|q| {
+            q.objective1()
+                .partial_cmp(&p.objective1())
+                .expect("objectives must be comparable")
+                .then_with(|| {
+                    q.objective2()
+                        .partial_cmp(&p.objective2())
+                        .expect("objectives must be comparable")
+                })
+                .then_with(|| {
+                    q.objective3()
+                        .partial_cmp(&p.objective3())
+                        .expect("objectives must be comparable")
+                })
+        }) else {
+            return;
+        };
+
+        // 1.
Check if dominated by any existing point in the prefix [0..pos].
+        // Because the list is sorted lexicographically, any point `q` before `pos`
+        // satisfies `q.obj1 <= p.obj1` (often strictly less).
+        // Therefore, we only need to check if `q` is also better in obj2 and obj3.
+        for q in &frontier[..pos] {
+            if q.objective2() <= p.objective2() && q.objective3() <= p.objective3() {
+                return;
+            }
+        }
+
+        // 2. Remove points dominated by the new point in the suffix [pos..].
+        // Any point `q` at or after `pos` satisfies `p.obj1 <= q.obj1`.
+        // So `p` can only dominate `q` if `p` is better in obj2 and obj3.
+        let mut i = pos;
+        while i < frontier.len() {
+            let q = &frontier[i];
+            if p.objective2() <= q.objective2() && p.objective3() <= q.objective3() {
+                frontier.remove(i);
+            } else {
+                i += 1;
+            }
+        }
+
+        frontier.insert(pos, p);
+    }
+
+    pub fn iter(&self) -> std::slice::Iter<'_, I> {
+        self.0.iter()
+    }
+
+    #[must_use]
+    pub fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    #[must_use]
+    pub fn is_empty(&self) -> bool {
+        self.0.is_empty()
+    }
+}
+
+impl<I: ParetoItem3D> Extend<I> for ParetoFrontier3D<I> {
+    fn extend<T: IntoIterator<Item = I>>(&mut self, iter: T) {
+        for p in iter {
+            self.insert(p);
+        }
+    }
+}
+
+impl<I: ParetoItem3D> FromIterator<I> for ParetoFrontier3D<I> {
+    fn from_iter<T: IntoIterator<Item = I>>(iter: T) -> Self {
+        let mut frontier = Self::new();
+        frontier.extend(iter);
+        frontier
+    }
+}
+
+impl<I: ParetoItem3D> IntoIterator for ParetoFrontier3D<I> {
+    type Item = I;
+    type IntoIter = std::vec::IntoIter<I>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.into_iter()
+    }
+}
+
+impl<'a, I: ParetoItem3D> IntoIterator for &'a ParetoFrontier3D<I> {
+    type Item = &'a I;
+    type IntoIter = std::slice::Iter<'a, I>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.iter()
+    }
+}
diff --git a/source/qre/src/pareto/tests.rs b/source/qre/src/pareto/tests.rs
new file mode 100644
index 0000000000..eaf2539c33
--- /dev/null
+++ b/source/qre/src/pareto/tests.rs
@@ -0,0 +1,243 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+use crate::{
+    EstimationCollection, EstimationResult,
+    pareto::{ParetoFrontier, ParetoFrontier3D, ParetoItem2D, ParetoItem3D},
+};
+
+struct Point2D {
+    x: f64,
+    y: f64,
+}
+
+impl ParetoItem2D for Point2D {
+    type Objective1 = f64;
+    type Objective2 = f64;
+
+    fn objective1(&self) -> Self::Objective1 {
+        self.x
+    }
+
+    fn objective2(&self) -> Self::Objective2 {
+        self.y
+    }
+}
+
+#[test]
+fn test_update_frontier() {
+    let mut frontier: ParetoFrontier<Point2D> = ParetoFrontier::new();
+    let p1 = Point2D { x: 1.0, y: 5.0 };
+    frontier.insert(p1);
+    assert_eq!(frontier.0.len(), 1);
+    let p2 = Point2D { x: 2.0, y: 4.0 };
+    frontier.insert(p2);
+    assert_eq!(frontier.0.len(), 2);
+    let p3 = Point2D { x: 1.5, y: 6.0 };
+    frontier.insert(p3);
+    assert_eq!(frontier.0.len(), 2);
+    let p4 = Point2D { x: 3.0, y: 3.0 };
+    frontier.insert(p4);
+    assert_eq!(frontier.0.len(), 3);
+    let p5 = Point2D { x: 2.5, y: 2.0 };
+    frontier.insert(p5);
+    assert_eq!(frontier.0.len(), 3);
+}
+
+#[test]
+fn test_iter_frontier() {
+    let mut frontier: ParetoFrontier<Point2D> = ParetoFrontier::new();
+    frontier.insert(Point2D { x: 1.0, y: 5.0 });
+    frontier.insert(Point2D { x: 2.0, y: 4.0 });
+
+    let mut iter = frontier.iter();
+    let p = iter.next().expect("Has element");
+    assert!((p.x - 1.0).abs() <= f64::EPSILON);
+    assert!((p.y - 5.0).abs() <= f64::EPSILON);
+
+    let p = iter.next().expect("Has element");
+    assert!((p.x - 2.0).abs() <= f64::EPSILON);
+    assert!((p.y - 4.0).abs() <= f64::EPSILON);
+
+    assert!(iter.next().is_none());
+
+    // Test IntoIterator for &ParetoFrontier
+    for p in &frontier {
+        assert!(p.x > 0.0);
+    }
+}
+
+#[derive(Clone, Copy, Debug)]
+struct Point3D {
+    x: f64,
+    y: f64,
+    z: f64,
+}
+
+impl ParetoItem3D for Point3D {
+    type Objective1 = f64;
+    type Objective2 = f64;
+    type Objective3 = f64;
+
+    fn objective1(&self) -> Self::Objective1 {
+        self.x
+    }
+
+    fn objective2(&self) -> Self::Objective2 {
+        self.y
+    }
+
+    fn objective3(&self) -> Self::Objective3 {
+        self.z
+    }
+}
+
+#[test]
+fn test_update_frontier_3d() {
+    let mut frontier: ParetoFrontier3D<Point3D> = ParetoFrontier3D::new();
+
+    // p1: 1, 5, 5
+    let p1 = Point3D {
+        x: 1.0,
+        y: 5.0,
+        z: 5.0,
+    };
+    frontier.insert(p1);
+    assert_eq!(frontier.0.len(), 1);
+
+    // p2: 2, 6, 6 (dominated by p1)
+    let p2 = Point3D {
+        x: 2.0,
+        y: 6.0,
+        z: 6.0,
+    };
+    frontier.insert(p2);
+    assert_eq!(frontier.0.len(), 1);
+
+    // p3: 0.5, 6, 6 (not dominated, x makes it unique)
+    let p3 = Point3D {
+        x: 0.5,
+        y: 6.0,
+        z: 6.0,
+    };
+    frontier.insert(p3);
+    assert_eq!(frontier.0.len(), 2);
+
+    // p4: 1, 4, 4 (dominates p1, should remove p1 and add p4)
+    // p1 (1,5,5). p4 (1,4,4). p4 <= p1? 1<=1, 4<=5, 4<=5. Yes.
+    // p3 (0.5,6,6). p4 (1,4,4). p4 <= p3? 1<=0.5 False.
+    // Result: p1 removed, p4 added. p3 remains.
+    let p4 = Point3D {
+        x: 1.0,
+        y: 4.0,
+        z: 4.0,
+    };
+    frontier.insert(p4);
+    assert_eq!(frontier.0.len(), 2);
+
+    // Verify content (generic check, not order specific)
+    let points: Vec<(f64, f64, f64)> = frontier.iter().map(|p| (p.x, p.y, p.z)).collect();
+
+    // Should contain p3 and p4
+    assert!(
+        points
+            .iter()
+            .any(|p| (p.0 - 0.5).abs() < 1e-9 && (p.1 - 6.0).abs() < 1e-9)
+    );
+    assert!(
+        points
+            .iter()
+            .any(|p| (p.0 - 1.0).abs() < 1e-9 && (p.1 - 4.0).abs() < 1e-9)
+    );
+}
+
+#[test]
+fn test_estimation_results() {
+    let mut result_worst = EstimationResult::new();
+    result_worst.add_qubits(994_570);
+    result_worst.add_runtime(346_196_523_750);
+
+    let mut result_mid = EstimationResult::new();
+    result_mid.add_qubits(994_570);
+    result_mid.add_runtime(346_191_476_400);
+
+    let mut result_best = EstimationResult::new();
+    result_best.add_qubits(994_570);
+    result_best.add_runtime(346_181_381_700);
+
+    let results = [result_worst, result_mid, result_best];
+    let permutations = [
+        [0, 1, 2],
+        [0, 2, 1],
+        [1, 0, 2],
+        [1, 2, 0],
+        [2, 0, 1],
+        [2, 1, 0],
+    ];
+
+    for p in permutations {
+        let mut frontier = EstimationCollection::new();
+        frontier.insert(results[p[0]].clone());
+
frontier.insert(results[p[1]].clone()); + frontier.insert(results[p[2]].clone()); + assert_eq!(frontier.len(), 1, "Failed for permutation {p:?}"); + + // Verify the retained item is the best one (index 2) + let item = frontier.iter().next().expect("has item"); + assert_eq!( + item.runtime(), + 346_181_381_700, + "Wrong item retained for permutation {p:?}", + ); + } +} + +#[test] +fn test_estimation_results_3d_permutations() { + // Check that 3D frontier handles strictly dominating points correctly + // even when first dimension is equal. + + // p_worst: (10, 100, 1000) + let p_worst = Point3D { + x: 10.0, + y: 100.0, + z: 1000.0, + }; + // p_mid: (10, 90, 1000) -> Dominates p_worst + let p_mid = Point3D { + x: 10.0, + y: 90.0, + z: 1000.0, + }; + // p_best: (10, 80, 1000) -> Dominates p_mid and p_worst + let p_best = Point3D { + x: 10.0, + y: 80.0, + z: 1000.0, + }; + + let results = [p_worst, p_mid, p_best]; + + let permutations = [ + [0, 1, 2], + [0, 2, 1], + [1, 0, 2], + [1, 2, 0], + [2, 0, 1], + [2, 1, 0], + ]; + + for p in permutations { + let mut frontier = ParetoFrontier3D::new(); + frontier.insert(results[p[0]]); + frontier.insert(results[p[1]]); + frontier.insert(results[p[2]]); + assert_eq!(frontier.len(), 1, "Failed for 3D permutation {p:?}"); + + let item = frontier.iter().next().expect("has item"); + assert!( + (item.y - 80.0).abs() < f64::EPSILON, + "Wrong item retained for 3D permutation {p:?}", + ); + } +} diff --git a/source/qre/src/result.rs b/source/qre/src/result.rs new file mode 100644 index 0000000000..208195531b --- /dev/null +++ b/source/qre/src/result.rs @@ -0,0 +1,288 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+
+use std::{
+    fmt::Display,
+    ops::{Deref, DerefMut},
+};
+
+use rustc_hash::FxHashMap;
+
+use crate::{ISA, ParetoFrontier2D, ParetoItem2D, Property};
+
+#[derive(Clone, Default)]
+pub struct EstimationResult {
+    qubits: u64,
+    runtime: u64,
+    error: f64,
+    factories: FxHashMap<u64, FactoryResult>,
+    isa: ISA,
+    isa_index: Option<usize>,
+    trace_index: Option<usize>,
+    properties: FxHashMap<u64, Property>,
+}
+
+impl EstimationResult {
+    #[must_use]
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    #[must_use]
+    pub fn qubits(&self) -> u64 {
+        self.qubits
+    }
+
+    #[must_use]
+    pub fn runtime(&self) -> u64 {
+        self.runtime
+    }
+
+    #[must_use]
+    pub fn error(&self) -> f64 {
+        self.error
+    }
+
+    #[must_use]
+    pub fn factories(&self) -> &FxHashMap<u64, FactoryResult> {
+        &self.factories
+    }
+
+    pub fn set_qubits(&mut self, qubits: u64) {
+        self.qubits = qubits;
+    }
+
+    pub fn set_runtime(&mut self, runtime: u64) {
+        self.runtime = runtime;
+    }
+
+    pub fn set_error(&mut self, error: f64) {
+        self.error = error;
+    }
+
+    /// Adds to the current qubit count and returns the new value.
+    pub fn add_qubits(&mut self, qubits: u64) -> u64 {
+        self.qubits += qubits;
+        self.qubits
+    }
+
+    /// Adds to the current runtime and returns the new value.
+    pub fn add_runtime(&mut self, runtime: u64) -> u64 {
+        self.runtime += runtime;
+        self.runtime
+    }
+
+    /// Adds to the current error and returns the new value.
+    pub fn add_error(&mut self, error: f64) -> f64 {
+        self.error += error;
+        self.error
+    }
+
+    pub fn add_factory_result(&mut self, id: u64, result: FactoryResult) {
+        self.factories.insert(id, result);
+    }
+
+    pub fn set_isa(&mut self, isa: ISA) {
+        self.isa = isa;
+    }
+
+    #[must_use]
+    pub fn isa(&self) -> &ISA {
+        &self.isa
+    }
+
+    pub fn set_isa_index(&mut self, index: usize) {
+        self.isa_index = Some(index);
+    }
+
+    #[must_use]
+    pub fn isa_index(&self) -> Option<usize> {
+        self.isa_index
+    }
+
+    pub fn set_trace_index(&mut self, index: usize) {
+        self.trace_index = Some(index);
+    }
+
+    #[must_use]
+    pub fn trace_index(&self) -> Option<usize> {
+        self.trace_index
+    }
+
+    pub fn set_property(&mut self, key: u64, value: Property) {
+        self.properties.insert(key, value);
+    }
+
+    #[must_use]
+    pub fn get_property(&self, key: u64) -> Option<&Property> {
+        self.properties.get(&key)
+    }
+
+    #[must_use]
+    pub fn has_property(&self, key: u64) -> bool {
+        self.properties.contains_key(&key)
+    }
+
+    #[must_use]
+    pub fn properties(&self) -> &FxHashMap<u64, Property> {
+        &self.properties
+    }
+}
+
+impl Display for EstimationResult {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "Qubits: {}, Runtime: {}, Error: {}",
+            self.qubits, self.runtime, self.error
+        )?;
+
+        if !self.factories.is_empty() {
+            for (id, factory) in &self.factories {
+                write!(
+                    f,
+                    ", {id}: {} runs x {} copies",
+                    factory.runs(),
+                    factory.copies()
+                )?;
+            }
+        }
+
+        Ok(())
+    }
+}
+
+impl ParetoItem2D for EstimationResult {
+    type Objective1 = u64; // qubits
+    type Objective2 = u64; // runtime
+
+    fn objective1(&self) -> Self::Objective1 {
+        self.qubits
+    }
+
+    fn objective2(&self) -> Self::Objective2 {
+        self.runtime
+    }
+}
+
+/// Lightweight summary of a successful estimation, used to identify
+/// post-processing candidates without storing full results.
+#[derive(Clone, Copy)]
+pub struct ResultSummary {
+    pub trace_index: usize,
+    pub isa_index: usize,
+    pub qubits: u64,
+    pub runtime: u64,
+}
+
+#[derive(Default)]
+pub struct EstimationCollection {
+    frontier: ParetoFrontier2D<EstimationResult>,
+    /// Lightweight summaries of ALL successful estimates (not just Pareto).
+    all_summaries: Vec<ResultSummary>,
+    total_jobs: usize,
+    successful_estimates: usize,
+    isas: Vec<ISA>,
+}
+
+impl EstimationCollection {
+    #[must_use]
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    #[must_use]
+    pub fn total_jobs(&self) -> usize {
+        self.total_jobs
+    }
+
+    pub fn set_total_jobs(&mut self, total_jobs: usize) {
+        self.total_jobs = total_jobs;
+    }
+
+    #[must_use]
+    pub fn successful_estimates(&self) -> usize {
+        self.successful_estimates
+    }
+
+    pub fn set_successful_estimates(&mut self, successful_estimates: usize) {
+        self.successful_estimates = successful_estimates;
+    }
+
+    pub fn push_summary(&mut self, summary: ResultSummary) {
+        self.all_summaries.push(summary);
+    }
+
+    #[must_use]
+    pub fn all_summaries(&self) -> &[ResultSummary] {
+        &self.all_summaries
+    }
+
+    pub fn push_isa(&mut self, isa: ISA) -> usize {
+        self.isas.push(isa);
+        self.isas.len() - 1
+    }
+
+    pub fn set_isas(&mut self, isas: Vec<ISA>) {
+        self.isas = isas;
+    }
+
+    #[must_use]
+    pub fn isas(&self) -> &[ISA] {
+        &self.isas
+    }
+}
+
+impl Deref for EstimationCollection {
+    type Target = ParetoFrontier2D<EstimationResult>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.frontier
+    }
+}
+
+impl DerefMut for EstimationCollection {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.frontier
+    }
+}
+
+#[derive(Clone)]
+pub struct FactoryResult {
+    copies: u64,
+    runs: u64,
+    states: u64,
+    error_rate: f64,
+}
+
+impl FactoryResult {
+    #[must_use]
+    pub fn new(copies: u64, runs: u64, states: u64, error_rate: f64) -> Self {
+        Self {
+            copies,
+            runs,
+            states,
+            error_rate,
+        }
+    }
+
+    #[must_use]
+    pub fn copies(&self) -> u64 {
+        self.copies
+    }
+
+    #[must_use]
+    pub fn runs(&self) -> u64 {
+        self.runs
+    }
+
+    
#[must_use] + pub fn states(&self) -> u64 { + self.states + } + + #[must_use] + pub fn error_rate(&self) -> f64 { + self.error_rate + } +} diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs new file mode 100644 index 0000000000..0e2d1ef106 --- /dev/null +++ b/source/qre/src/trace.rs @@ -0,0 +1,753 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::{ + fmt::{Display, Formatter}, + vec, +}; + +use rustc_hash::{FxHashMap, FxHashSet}; +use serde::{Deserialize, Serialize}; + +use crate::{ + ConstraintBound, Encoding, Error, EstimationResult, FactoryResult, ISA, ISARequirements, + Instruction, InstructionConstraint, LockedISA, + property_keys::{ + LOGICAL_COMPUTE_QUBITS, LOGICAL_MEMORY_QUBITS, PHYSICAL_COMPUTE_QUBITS, + PHYSICAL_FACTORY_QUBITS, PHYSICAL_MEMORY_QUBITS, + }, +}; + +mod estimation; +pub use estimation::{estimate_parallel, estimate_with_graph}; + +pub mod instruction_ids; +use instruction_ids::instruction_name; +#[cfg(test)] +mod tests; + +mod transforms; +pub use transforms::{LatticeSurgery, PSSPC, TraceTransform}; + +#[derive(Clone, Default, Serialize, Deserialize)] +pub struct Trace { + block: Block, + base_error: f64, + compute_qubits: u64, + memory_qubits: Option, + resource_states: Option>, + properties: FxHashMap, +} + +impl Trace { + #[must_use] + pub fn new(compute_qubits: u64) -> Self { + Self { + compute_qubits, + ..Default::default() + } + } + + #[must_use] + pub fn clone_empty(&self, compute_qubits: Option) -> Self { + Self { + block: Block::default(), + base_error: self.base_error, + compute_qubits: compute_qubits.unwrap_or(self.compute_qubits), + memory_qubits: self.memory_qubits, + resource_states: self.resource_states.clone(), + properties: self.properties.clone(), + } + } + + #[must_use] + pub fn compute_qubits(&self) -> u64 { + self.compute_qubits + } + + pub fn set_compute_qubits(&mut self, compute_qubits: u64) { + self.compute_qubits = compute_qubits; + } + + pub fn add_operation(&mut 
self, id: u64, qubits: Vec, params: Vec) { + self.block.add_operation(id, qubits, params); + } + + pub fn add_block(&mut self, repetitions: u64) -> &mut Block { + self.block.add_block(repetitions) + } + + pub fn root_block_mut(&mut self) -> &mut Block { + &mut self.block + } + + #[must_use] + pub fn base_error(&self) -> f64 { + self.base_error + } + + pub fn increment_base_error(&mut self, amount: f64) { + self.base_error += amount; + } + + #[must_use] + pub fn memory_qubits(&self) -> Option { + self.memory_qubits + } + + #[must_use] + pub fn has_memory_qubits(&self) -> bool { + self.memory_qubits.is_some() + } + + pub fn set_memory_qubits(&mut self, qubits: u64) { + self.memory_qubits = Some(qubits); + } + + pub fn increment_memory_qubits(&mut self, amount: u64) { + if amount == 0 { + return; + } + let current = self.memory_qubits.get_or_insert(0); + *current += amount; + } + + #[must_use] + pub fn total_qubits(&self) -> u64 { + self.compute_qubits + self.memory_qubits.unwrap_or(0) + } + + pub fn increment_resource_state(&mut self, resource_id: u64, amount: u64) { + if amount == 0 { + return; + } + let states = self.resource_states.get_or_insert_with(FxHashMap::default); + *states.entry(resource_id).or_default() += amount; + } + + #[must_use] + pub fn get_resource_states(&self) -> Option<&FxHashMap> { + self.resource_states.as_ref() + } + + #[must_use] + pub fn get_resource_state_count(&self, resource_id: u64) -> u64 { + if let Some(states) = &self.resource_states + && let Some(count) = states.get(&resource_id) + { + return *count; + } + 0 + } + + pub fn set_property(&mut self, key: u64, value: Property) { + self.properties.insert(key, value); + } + + #[must_use] + pub fn get_property(&self, key: u64) -> Option<&Property> { + self.properties.get(&key) + } + + #[must_use] + pub fn has_property(&self, key: u64) -> bool { + self.properties.contains_key(&key) + } + + #[must_use] + pub fn deep_iter(&self) -> TraceIterator<'_> { + TraceIterator::new(&self.block) + } + + 
/// Returns the set of instruction IDs required by this trace, along with + /// their arity constraints if available. We take the actual arity from the + /// instruction, and if we see instructions with the same ID but different + /// arities, we mark them as variable arity in the returned requirements. + /// If `max_error` is provided, also adds error rate constraints based on + /// the instruction usage volume and the maximum allowed error. These error + /// rate constraints can be used for instruction pruning during estimation. + #[allow(clippy::cast_precision_loss)] + #[must_use] + pub fn required_instruction_ids(&self, max_error: Option) -> ISARequirements { + let mut constraints = FxHashMap::::default(); + + let mut update_constraints = |id: u64, arity: u64, added_volume: u64| { + constraints + .entry(id) + .and_modify(|(constraint, volume)| { + if let Some(prev_arity) = constraint.arity() + && prev_arity != arity + { + constraint.set_arity(None); + } + *volume += added_volume; + }) + .or_insert({ + let constraint = + InstructionConstraint::new(id, Encoding::Logical, Some(arity), None); + (constraint, added_volume) + }); + }; + + for (gate, mult) in self.deep_iter() { + let arity = gate.qubits.len() as u64; + update_constraints(gate.id, arity, mult * arity); + } + if let Some(ref rs) = self.resource_states { + for (res_id, count) in rs { + update_constraints(*res_id, 1, *count); + } + } + if let Some(memory_qubits) = self.memory_qubits { + update_constraints(instruction_ids::MEMORY, memory_qubits, memory_qubits); + } + + if let Some(max_error) = max_error { + constraints + .into_values() + .map(|(mut c, volume)| { + c.set_error_rate(Some(ConstraintBound::less_equal( + max_error / (volume as f64), + ))); + c + }) + .collect() + } else { + constraints.into_values().map(|(c, _)| c).collect() + } + } + + #[must_use] + pub fn depth(&self) -> u64 { + self.block.depth() + } + + #[must_use] + pub fn num_gates(&self) -> u64 { + self.deep_iter().map(|(_, m)| m).sum() + 
} + + #[allow( + clippy::cast_precision_loss, + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + clippy::too_many_lines + )] + pub fn estimate(&self, isa: &ISA, max_error: Option) -> Result { + let locked = isa.lock(); + let max_error = max_error.unwrap_or(1.0); + + if self.base_error > max_error { + return Err(Error::MaximumErrorExceeded { + actual_error: self.base_error, + max_error, + }); + } + + let mut result = EstimationResult::new(); + + // base error starts with the error already present in the trace + result.add_error(self.base_error); + + // Counts how many magic state factories are needed per resource state ID + let mut factories: FxHashMap = FxHashMap::default(); + + // This will track the number of physical qubits per logical qubit while + // processing all the instructions. Normally, we assume that the number + // is always the same. + let mut qubit_counts: Vec = vec![]; + + // ------------------------------------------------------------------ + // Add errors from resource states. Allow callable error rates. + // ------------------------------------------------------------------ + if let Some(resource_states) = &self.resource_states { + for (state_id, count) in resource_states { + let rate = get_error_rate_by_id(&locked, *state_id)?; + let actual_error = result.add_error(rate * (*count as f64)); + if actual_error > max_error { + return Err(Error::MaximumErrorExceeded { + actual_error, + max_error, + }); + } + factories.insert(*state_id, *count); + } + } + + // ------------------------------------------------------------------ + // Gate error accumulation using recursion over block structure. + // Each block contributes repetitions * internal_gate_errors. + // Missing instructions raise an error. Callable rates use arity. 
+ // ------------------------------------------------------------------ + for (gate, mult) in self.deep_iter() { + let instr = get_instruction(&locked, gate.id)?; + + let arity = gate.qubits.len() as u64; + + let rate = instr.expect_error_rate(Some(arity)); + + let qubit_count = instr.expect_space(Some(arity)) as f64 / arity as f64; + + if let Err(i) = qubit_counts.binary_search_by(|qc| qc.total_cmp(&qubit_count)) { + qubit_counts.insert(i, qubit_count); + } + + let actual_error = result.add_error(rate * (mult as f64)); + if actual_error > max_error { + return Err(Error::MaximumErrorExceeded { + actual_error, + max_error, + }); + } + } + + let total_compute_qubits = (self.compute_qubits() as f64 + * qubit_counts.last().copied().unwrap_or(1.0)) + .ceil() as u64; + result.add_qubits(total_compute_qubits); + result.set_property( + PHYSICAL_COMPUTE_QUBITS, + Property::Int(total_compute_qubits.cast_signed()), + ); + + result.add_runtime( + self.block + .depth_and_used(Some(&|op: &Gate| { + let instr = get_instruction(&locked, op.id)?; + Ok(instr.expect_time(Some(op.qubits.len() as u64))) + }))? + .0, + ); + + // ------------------------------------------------------------------ + // Factory overhead estimation. Each factory produces states at + // a certain rate, so we need enough copies to meet the demand. 
+ // ------------------------------------------------------------------ + let mut total_factory_qubits = 0; + for (factory, count) in &factories { + let instr = get_instruction(&locked, *factory)?; + let factory_time = get_time(instr)?; + let factory_space = get_space(instr)?; + let factory_error_rate = get_error_rate(instr)?; + let runs = result.runtime() / factory_time; + + if runs == 0 { + return Err(Error::FactoryTimeExceedsAlgorithmRuntime { + id: *factory, + factory_time, + algorithm_runtime: result.runtime(), + }); + } + + let copies = count.div_ceil(runs); + + total_factory_qubits += copies * factory_space; + result.add_factory_result( + *factory, + FactoryResult::new(copies, runs, *count, factory_error_rate), + ); + } + result.add_qubits(total_factory_qubits); + result.set_property( + PHYSICAL_FACTORY_QUBITS, + Property::Int(total_factory_qubits.cast_signed()), + ); + + // Memory qubits + if let Some(memory_qubits) = self.memory_qubits { + // We need a MEMORY instruction in our ISA + let memory = locked + .get(&instruction_ids::MEMORY) + .ok_or(Error::InstructionNotFound(instruction_ids::MEMORY))?; + + let memory_space = memory.expect_space(Some(memory_qubits)); + result.add_qubits(memory_space); + result.set_property( + PHYSICAL_MEMORY_QUBITS, + Property::Int(memory_space.cast_signed()), + ); + + // The number of rounds for the memory qubits to stay alive with + // respect to the total runtime of the algorithm. 
+ let rounds = result + .runtime() + .div_ceil(memory.expect_time(Some(memory_qubits))); + + let actual_error = + result.add_error(rounds as f64 * memory.expect_error_rate(Some(memory_qubits))); + if actual_error > max_error { + return Err(Error::MaximumErrorExceeded { + actual_error, + max_error, + }); + } + } + + // Make main trace metrics properties to access them from the result + result.set_property( + LOGICAL_COMPUTE_QUBITS, + Property::Int(self.compute_qubits.cast_signed()), + ); + result.set_property( + LOGICAL_MEMORY_QUBITS, + Property::Int(self.memory_qubits.unwrap_or(0).cast_signed()), + ); + + // Copy properties from the trace to the result + for (key, value) in &self.properties { + result.set_property(*key, value.clone()); + } + + Ok(result) + } +} + +impl Display for Trace { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + writeln!(f, "@compute_qubits({})", self.compute_qubits())?; + + if let Some(memory_qubits) = self.memory_qubits { + writeln!(f, "@memory_qubits({memory_qubits})")?; + } + if self.base_error > 0.0 { + writeln!(f, "@base_error({})", self.base_error)?; + } + if let Some(resource_states) = &self.resource_states { + for (res_id, amount) in resource_states { + writeln!( + f, + "@resource_state({}, {amount})", + instruction_name(*res_id).unwrap_or("??") + )?; + } + } + write!(f, "{}", self.block) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum Operation { + GateOperation(Gate), + BlockOperation(Block), +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Gate { + id: u64, + qubits: Vec, + params: Vec, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Block { + operations: Vec, + repetitions: u64, +} + +impl Default for Block { + fn default() -> Self { + Self { + operations: Vec::new(), + repetitions: 1, + } + } +} + +impl Block { + pub fn add_operation(&mut self, id: u64, qubits: Vec, params: Vec) { + self.operations + .push(Operation::gate_operation(id, qubits, params)); + } + 
+ pub fn add_block(&mut self, repetitions: u64) -> &mut Block { + self.operations + .push(Operation::block_operation(repetitions)); + + match self.operations.last_mut() { + Some(Operation::BlockOperation(b)) => b, + _ => unreachable!("Last operation must be a block operation"), + } + } + + pub fn write(&self, f: &mut Formatter<'_>, indent: usize) -> std::fmt::Result { + let indent_str = " ".repeat(indent); + if self.repetitions == 1 { + writeln!(f, "{indent_str}{{")?; + } else { + writeln!(f, "{indent_str}repeat {} {{", self.repetitions)?; + } + + for op in &self.operations { + match op { + Operation::GateOperation(Gate { id, qubits, params }) => { + let name = instruction_name(*id).unwrap_or("??"); + write!(f, "{indent_str} {name}")?; + if !params.is_empty() { + write!( + f, + "({})", + params + .iter() + .map(f64::to_string) + .collect::>() + .join(", ") + )?; + } + writeln!( + f, + "({})", + qubits + .iter() + .map(u64::to_string) + .collect::>() + .join(", ") + )?; + } + Operation::BlockOperation(b) => { + b.write(f, indent + 2)?; + } + } + } + writeln!(f, "{indent_str}}}") + } + + fn depth_and_used Result>( + &self, + duration_fn: Option<&FnDuration>, + ) -> Result<(u64, FxHashSet), Error> { + let mut qubit_depths: FxHashMap = FxHashMap::default(); + let mut all_used = FxHashSet::default(); + + for op in &self.operations { + match op { + Operation::GateOperation(gate) => { + let start_time = gate + .qubits + .iter() + .filter_map(|q| qubit_depths.get(q)) + .max() + .copied() + .unwrap_or(0); + + let duration = match duration_fn { + Some(f) => f(gate)?, + _ => 1, + }; + + let end_time = start_time + duration; + for q in &gate.qubits { + qubit_depths.insert(*q, end_time); + all_used.insert(*q); + } + } + Operation::BlockOperation(block) => { + let (duration, used) = block.depth_and_used(duration_fn)?; + if used.is_empty() { + continue; + } + + let start_time = used + .iter() + .filter_map(|q| qubit_depths.get(q)) + .max() + .copied() + .unwrap_or(0); + + let 
end_time = start_time + duration; + for q in &used { + qubit_depths.insert(*q, end_time); + } + all_used.extend(used); + } + } + } + + let max_depth = qubit_depths.values().max().copied().unwrap_or(0); + Ok((max_depth * self.repetitions, all_used)) + } + + #[must_use] + pub fn depth(&self) -> u64 { + self.depth_and_used:: Result>(None) + .expect("Duration function is None") + .0 + } +} + +impl Display for Block { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.write(f, 0) + } +} + +impl Operation { + fn gate_operation(id: u64, qubits: Vec, params: Vec) -> Self { + Operation::GateOperation(Gate { id, qubits, params }) + } + + fn block_operation(repetitions: u64) -> Self { + Operation::BlockOperation(Block { + operations: Vec::new(), + repetitions, + }) + } +} + +pub struct TraceIterator<'a> { + stack: Vec<(std::slice::Iter<'a, Operation>, u64)>, +} + +impl<'a> TraceIterator<'a> { + fn new(block: &'a Block) -> Self { + Self { + stack: vec![(block.operations.iter(), 1)], + } + } +} + +impl<'a> Iterator for TraceIterator<'a> { + type Item = (&'a Gate, u64); + + fn next(&mut self) -> Option { + loop { + let (iter, multiplier) = self.stack.last_mut()?; + match iter.next() { + Some(op) => match op { + Operation::GateOperation(g) => return Some((g, *multiplier)), + Operation::BlockOperation(block) => { + let new_multiplier = *multiplier * block.repetitions; + self.stack.push((block.operations.iter(), new_multiplier)); + } + }, + _ => { + self.stack.pop(); + } + } + } + } +} + +#[derive(Clone, Serialize, Deserialize)] +pub enum Property { + Bool(bool), + Int(i64), + Float(f64), + Str(String), +} + +impl Property { + #[must_use] + pub fn new_bool(b: bool) -> Self { + Property::Bool(b) + } + + #[must_use] + pub fn new_int(i: i64) -> Self { + Property::Int(i) + } + + #[must_use] + pub fn new_float(f: f64) -> Self { + Property::Float(f) + } + + #[must_use] + pub fn new_str(s: String) -> Self { + Property::Str(s) + } + + #[must_use] + pub fn as_bool(&self) -> 
Option { + match self { + Property::Bool(b) => Some(*b), + _ => None, + } + } + + #[must_use] + pub fn as_int(&self) -> Option { + match self { + Property::Int(i) => Some(*i), + _ => None, + } + } + + #[must_use] + pub fn as_float(&self) -> Option { + match self { + Property::Float(f) => Some(*f), + _ => None, + } + } + + #[must_use] + pub fn as_str(&self) -> Option<&str> { + match self { + Property::Str(s) => Some(s), + _ => None, + } + } + + #[must_use] + pub fn is_bool(&self) -> bool { + matches!(self, Property::Bool(_)) + } + + #[must_use] + pub fn is_int(&self) -> bool { + matches!(self, Property::Int(_)) + } + + #[must_use] + pub fn is_float(&self) -> bool { + matches!(self, Property::Float(_)) + } + + #[must_use] + pub fn is_str(&self) -> bool { + matches!(self, Property::Str(_)) + } +} + +impl Display for Property { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Property::Bool(b) => write!(f, "{b}"), + Property::Int(i) => write!(f, "{i}"), + Property::Float(fl) => write!(f, "{fl}"), + Property::Str(s) => write!(f, "{s}"), + } + } +} + +// Some helper functions to extract instructions and their metrics together with +// error handling + +fn get_instruction<'a>(isa: &'a LockedISA<'_>, id: u64) -> Result<&'a Instruction, Error> { + isa.get(&id).ok_or(Error::InstructionNotFound(id)) +} + +fn get_space(instruction: &Instruction) -> Result { + instruction + .space(None) + .ok_or(Error::CannotExtractSpace(instruction.id())) +} + +fn get_time(instruction: &Instruction) -> Result { + instruction + .time(None) + .ok_or(Error::CannotExtractTime(instruction.id())) +} + +fn get_error_rate(instruction: &Instruction) -> Result { + instruction + .error_rate(None) + .ok_or(Error::CannotExtractErrorRate(instruction.id())) +} + +fn get_error_rate_by_id(isa: &LockedISA<'_>, id: u64) -> Result { + let instr = get_instruction(isa, id)?; + instr + .error_rate(None) + .ok_or(Error::CannotExtractErrorRate(id)) +} diff --git 
a/source/qre/src/trace/estimation.rs b/source/qre/src/trace/estimation.rs
new file mode 100644
index 0000000000..3bbc7dd25d
--- /dev/null
+++ b/source/qre/src/trace/estimation.rs
@@ -0,0 +1,504 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

use std::{
    collections::hash_map::DefaultHasher,
    hash::{Hash, Hasher},
    iter::repeat_with,
    sync::{Arc, RwLock, atomic::AtomicUsize},
};

use rustc_hash::{FxHashMap, FxHashSet};

use crate::{EstimationCollection, ISA, ProvenanceGraph, ResultSummary, Trace};

/// Estimates all (trace, ISA) combinations in parallel, returning only the
/// successful results collected into an [`EstimationCollection`].
///
/// This uses a shared atomic counter as a lock-free work queue. Each worker
/// thread atomically claims the next job index, maps it to a `(trace, isa)`
/// pair, and runs the estimation. This keeps all available cores busy until
/// the last job completes.
///
/// # Work distribution
///
/// Jobs are numbered `0 .. traces.len() * isas.len()`. For job index `j`:
/// - `trace_idx = j / isas.len()`
/// - `isa_idx = j % isas.len()`
///
/// Each worker accumulates results locally and sends them back over a bounded
/// channel once it runs out of work, avoiding contention on the shared
/// collection.
#[must_use]
pub fn estimate_parallel<'a>(
    traces: &[&'a Trace],
    isas: &[&'a ISA],
    max_error: Option<f64>,
    post_process: bool,
) -> EstimationCollection {
    let total_jobs = traces.len() * isas.len();
    let num_isas = isas.len();

    // Shared atomic counter acts as a lock-free work queue. Workers call
    // fetch_add to claim the next job index.
    let next_job = AtomicUsize::new(0);

    let mut collection = EstimationCollection::new();
    collection.set_total_jobs(total_jobs);

    std::thread::scope(|scope| {
        let num_threads = std::thread::available_parallelism()
            .map(std::num::NonZero::get)
            .unwrap_or(1);

        // Bounded channel so each worker can send its batch of results back
        // to the main thread without unbounded buffering.
        let (tx, rx) = std::sync::mpsc::sync_channel(num_threads);

        for _ in 0..num_threads {
            let tx = tx.clone();
            let next_job = &next_job;
            scope.spawn(move || {
                let mut local_results = Vec::new();
                loop {
                    // Atomically claim the next job. Relaxed ordering is
                    // sufficient because there is no dependent data between
                    // jobs — each (trace, isa) pair is independent.
                    let job = next_job.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    if job >= total_jobs {
                        break;
                    }

                    // Map the flat job index to a (trace, ISA) pair.
                    let trace_idx = job / num_isas;
                    let isa_idx = job % num_isas;

                    if let Ok(mut estimation) = traces[trace_idx].estimate(isas[isa_idx], max_error)
                    {
                        estimation.set_isa_index(isa_idx);
                        estimation.set_trace_index(trace_idx);

                        local_results.push(estimation);
                    }
                }
                // Send all results from this worker in one batch.
                let _ = tx.send(local_results);
            });
        }
        // Drop the cloned sender so the receiver iterator terminates once all
        // workers have finished.
        drop(tx);

        // Collect results from all workers into the shared collection.
        let mut successful = 0;
        for local_results in rx {
            if post_process {
                for result in &local_results {
                    collection.push_summary(ResultSummary {
                        trace_index: result.trace_index().unwrap_or(0),
                        isa_index: result.isa_index().unwrap_or(0),
                        qubits: result.qubits(),
                        runtime: result.runtime(),
                    });
                }
            }
            successful += local_results.len();
            collection.extend(local_results);
        }
        collection.set_successful_estimates(successful);
    });

    // Attach ISAs only to Pareto-surviving results, avoiding O(M) HashMap
    // clones for discarded results.
    for result in collection.iter_mut() {
        if let Some(idx) = result.isa_index() {
            result.set_isa(isas[idx].clone());
        }
    }

    collection
}

/// A node in the provenance graph along with pre-computed (space, time) values
/// for pruning.
#[derive(Clone, Copy, Hash, PartialEq, Eq)]
struct NodeProfile {
    node_index: usize,
    space: u64,
    time: u64,
}

/// A single entry in a combination of instruction choices for estimation.
#[derive(Clone, Copy, Hash, Eq, PartialEq)]
struct CombinationEntry {
    instruction_id: u64,
    node: NodeProfile,
}

/// Per-slot pruning witnesses: maps a context hash to the `(space, time)`
/// pairs observed in successful estimations.
type SlotWitnesses = RwLock<FxHashMap<u64, Vec<(u64, u64)>>>;

/// Computes a hash of the combination context (all slots except the excluded
/// one). Two combinations that agree on every slot except `exclude_idx`
/// produce the same context hash.
fn combination_context_hash(combination: &[CombinationEntry], exclude_idx: usize) -> u64 {
    let mut hasher = DefaultHasher::new();
    for (i, entry) in combination.iter().enumerate() {
        if i != exclude_idx {
            entry.instruction_id.hash(&mut hasher);
            entry.node.node_index.hash(&mut hasher);
        }
    }
    hasher.finish()
}

/// Checks whether a combination is dominated by a previously successful one.
///
/// A combination is prunable if, for any instruction slot, there exists a
/// successful combination with the same instructions in all other slots and
/// an instruction at that slot with `space <=` and `time <=`.
///
/// However, for instruction slots that affect the algorithm runtime (gate
/// instructions) when the trace has factory instructions (resource states),
/// time-based dominance is **unsound**. A faster gate instruction reduces
/// the algorithm runtime, which reduces the number of factory runs per copy,
/// requiring more factory copies and potentially more total qubits. The
/// "dominated" combination (with a slower gate) can therefore produce a
/// Pareto-optimal result with fewer qubits but more runtime.
///
/// When `runtime_affecting_ids` is non-empty, slots whose instruction ID
/// appears in that set are skipped entirely during the dominance check.
fn is_dominated(
    combination: &[CombinationEntry],
    trace_pruning: &[SlotWitnesses],
    runtime_affecting_ids: &FxHashSet<u64>,
) -> bool {
    for (slot_idx, entry) in combination.iter().enumerate() {
        // Skip dominance check for runtime-affecting slots when factories
        // exist, because shorter gate time can increase factory overhead.
        if runtime_affecting_ids.contains(&entry.instruction_id) {
            continue;
        }
        let ctx_hash = combination_context_hash(combination, slot_idx);
        let map = trace_pruning[slot_idx]
            .read()
            .expect("Pruning lock poisoned");
        if map.get(&ctx_hash).is_some_and(|w| {
            w.iter()
                .any(|&(ws, wt)| ws <= entry.node.space && wt <= entry.node.time)
        }) {
            return true;
        }
    }
    false
}

/// Records a successful estimation as a pruning witness for future
/// combinations.
fn record_success(combination: &[CombinationEntry], trace_pruning: &[SlotWitnesses]) {
    for (slot_idx, entry) in combination.iter().enumerate() {
        let ctx_hash = combination_context_hash(combination, slot_idx);
        let mut map = trace_pruning[slot_idx]
            .write()
            .expect("Pruning lock poisoned");
        map.entry(ctx_hash)
            .or_default()
            .push((entry.node.space, entry.node.time));
    }
}

/// Deduplicating store of ISAs keyed by the exact node combination that
/// produced them.
#[derive(Default)]
struct ISAIndex {
    index: FxHashMap<Vec<CombinationEntry>, usize>,
    isas: Vec<ISA>,
}

impl From<ISAIndex> for Vec<ISA> {
    fn from(value: ISAIndex) -> Self {
        value.isas
    }
}

impl ISAIndex {
    /// Returns the index for `isa`, inserting it only if this exact node
    /// combination has not been seen before.
    pub fn push(&mut self, combination: &[CombinationEntry], isa: &ISA) -> usize {
        if let Some(&idx) = self.index.get(combination) {
            idx
        } else {
            let idx = self.isas.len();
            self.isas.push(isa.clone());
            self.index.insert(combination.to_vec(), idx);
            idx
        }
    }
}

/// Generates the cartesian product of `id_and_nodes` and pushes each
/// combination directly into `jobs`, avoiding intermediate allocations.
///
/// The cartesian product is enumerated using mixed-radix indexing. Given
/// dimensions with sizes `[n0, n1, n2, …]`, the total number of combinations
/// is `n0 * n1 * n2 * …`. Each combination index `i` in `0..total` uniquely
/// identifies one element from every dimension: the index into dimension `d` is
/// `(i / (n0 * n1 * … * n(d-1))) % nd`, which we compute incrementally by
/// repeatedly taking `i % nd` and then dividing `i` by `nd`. This is
/// analogous to extracting digits from a number in a mixed-radix system.
fn push_cartesian_product(
    id_and_nodes: &[(u64, Vec<NodeProfile>)],
    trace_idx: usize,
    jobs: &mut Vec<(usize, Vec<CombinationEntry>)>,
    max_slots: &mut usize,
) {
    // The product of all dimension sizes gives the total number of
    // combinations. If any dimension is empty the product is zero and there
    // are no valid combinations to generate.
    let total: usize = id_and_nodes.iter().map(|(_, nodes)| nodes.len()).product();
    if total == 0 {
        return;
    }

    *max_slots = (*max_slots).max(id_and_nodes.len());
    jobs.reserve(total);

    // Enumerate every combination by treating the combination index `i` as a
    // mixed-radix number. The inner loop "peels off" one digit per dimension:
    //   node_idx = i % nodes.len()  — selects this dimension's element
    //   i /= nodes.len()            — shifts to the next dimension's digit
    // After processing all dimensions, `i` is exhausted (becomes 0), and
    // `combo` contains exactly one entry per instruction id.
    for mut i in 0..total {
        let mut combo = Vec::with_capacity(id_and_nodes.len());
        for (id, nodes) in id_and_nodes {
            let node_idx = i % nodes.len();
            i /= nodes.len();
            let profile = nodes[node_idx];
            combo.push(CombinationEntry {
                instruction_id: *id,
                node: profile,
            });
        }
        jobs.push((trace_idx, combo));
    }
}

#[must_use]
#[allow(clippy::cast_precision_loss, clippy::too_many_lines)]
pub fn estimate_with_graph(
    traces: &[&Trace],
    graph: &Arc<RwLock<ProvenanceGraph>>,
    max_error: Option<f64>,
    post_process: bool,
) -> EstimationCollection {
    let max_error = max_error.unwrap_or(1.0);

    // Phase 1: Pre-compute all (trace_index, combination) jobs sequentially.
    // This reads the provenance graph once per trace and generates the
    // cartesian product of Pareto-filtered nodes. Each node carries
    // pre-computed (space, time) values for dominance pruning in Phase 2.
    let mut jobs: Vec<(usize, Vec<CombinationEntry>)> = Vec::new();

    // Use the maximum number of instruction slots across all combinations to
    // size the pruning witness structure. This will be updated while we
    // generate jobs.
    let mut max_slots = 0;

    // For each trace, collect the set of instruction IDs that affect the
    // algorithm runtime (gate instructions from the trace block structure).
    // When a trace also has resource states (factories), dominance pruning
    // on these slots is unsound because shorter gate time can increase
    // factory overhead (see `is_dominated` documentation).
    let runtime_affecting_ids: Vec<FxHashSet<u64>> = traces
        .iter()
        .map(|trace| {
            let has_factories = trace.get_resource_states().is_some_and(|rs| !rs.is_empty());
            if has_factories {
                trace.deep_iter().map(|(gate, _)| gate.id).collect()
            } else {
                FxHashSet::default()
            }
        })
        .collect();

    for (trace_idx, trace) in traces.iter().enumerate() {
        if trace.base_error() > max_error {
            continue;
        }

        let required = trace.required_instruction_ids(Some(max_error));

        let graph_lock = graph.read().expect("Graph lock poisoned");
        let id_and_nodes: Vec<_> = required
            .constraints()
            .iter()
            .filter_map(|constraint| {
                graph_lock.pareto_nodes(constraint.id()).map(|nodes| {
                    (
                        constraint.id(),
                        nodes
                            .iter()
                            .filter(|&&node| {
                                // Filter out nodes that don't meet the constraint bounds.
                                let instruction = graph_lock.instruction(node);
                                constraint.error_rate().is_none_or(|c| {
                                    c.evaluate(&instruction.error_rate(Some(1)).unwrap_or(0.0))
                                })
                            })
                            .map(|&node| {
                                let instruction = graph_lock.instruction(node);
                                let space = instruction.space(Some(1)).unwrap_or(0);
                                let time = instruction.time(Some(1)).unwrap_or(0);
                                NodeProfile {
                                    node_index: node,
                                    space,
                                    time,
                                }
                            })
                            .collect::<Vec<_>>(),
                    )
                })
            })
            .collect();
        drop(graph_lock);

        if id_and_nodes.len() != required.len() {
            // If any required instruction is missing from the graph, we can't
            // run any estimation for this trace.
            continue;
        }

        push_cartesian_product(&id_and_nodes, trace_idx, &mut jobs, &mut max_slots);
    }

    // Sort jobs so that combinations with smaller total (space + time) are
    // processed first. This maximises the effectiveness of dominance pruning
    // because successful "cheap" combinations establish witnesses that let us
    // skip more expensive ones.
    jobs.sort_by_key(|(_, combo)| {
        combo
            .iter()
            .map(|entry| entry.node.space + entry.node.time)
            .sum::<u64>()
    });

    let total_jobs = jobs.len();

    // Phase 2: Run estimations in parallel with dominance-based pruning.
    //
    // For each instruction slot in a combination, we track (space, time)
    // witnesses from successful estimations keyed by the "context", which is a
    // hash of the node indices in all *other* slots. Before running an
    // estimation, we check every slot: if a witness with space ≤ and time ≤
    // exists for that context, the combination is dominated and skipped.
    let next_job = AtomicUsize::new(0);

    let pruning_witnesses: Vec<Vec<SlotWitnesses>> = repeat_with(|| {
        repeat_with(|| RwLock::new(FxHashMap::default()))
            .take(max_slots)
            .collect()
    })
    .take(traces.len())
    .collect();

    // There are no explicit ISAs in this estimation function, as we create them
    // on the fly from the graph nodes. For successful jobs, we will attach the
    // ISAs to the results collection in a vector with the ISA index addressing
    // that vector. In order to avoid storing duplicate ISAs we hash the ISA
    // index.
    let isa_index = Arc::new(RwLock::new(ISAIndex::default()));

    let mut collection = EstimationCollection::new();
    collection.set_total_jobs(total_jobs);

    std::thread::scope(|scope| {
        let num_threads = std::thread::available_parallelism()
            .map(std::num::NonZero::get)
            .unwrap_or(1);

        let (tx, rx) = std::sync::mpsc::sync_channel(num_threads);

        for _ in 0..num_threads {
            let tx = tx.clone();
            let next_job = &next_job;
            let jobs = &jobs;
            let pruning_witnesses = &pruning_witnesses;
            let runtime_affecting_ids = &runtime_affecting_ids;
            let isa_index = Arc::clone(&isa_index);
            scope.spawn(move || {
                let mut local_results = Vec::new();
                loop {
                    let job_idx = next_job.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    if job_idx >= total_jobs {
                        break;
                    }

                    let (trace_idx, combination) = &jobs[job_idx];

                    // Dominance pruning: skip if a cheaper instruction at any
                    // slot already succeeded with the same surrounding context.
                    if is_dominated(
                        combination,
                        &pruning_witnesses[*trace_idx],
                        &runtime_affecting_ids[*trace_idx],
                    ) {
                        continue;
                    }

                    let mut isa = ISA::with_graph(graph.clone());
                    for entry in combination {
                        isa.add_node(entry.instruction_id, entry.node.node_index);
                    }

                    if let Ok(mut result) = traces[*trace_idx].estimate(&isa, Some(max_error)) {
                        let isa_idx = isa_index
                            .write()
                            .expect("RwLock should not be poisoned")
                            .push(combination, &isa);
                        result.set_isa_index(isa_idx);

                        result.set_trace_index(*trace_idx);

                        local_results.push(result);
                        record_success(combination, &pruning_witnesses[*trace_idx]);
                    }
                }
                let _ = tx.send(local_results);
            });
        }
        drop(tx);

        let mut successful = 0;
        for local_results in rx {
            if post_process {
                for result in &local_results {
                    collection.push_summary(ResultSummary {
                        trace_index: result.trace_index().unwrap_or(0),
                        isa_index: result.isa_index().unwrap_or(0),
                        qubits: result.qubits(),
                        runtime: result.runtime(),
                    });
                }
            }
            successful += local_results.len();
            collection.extend(local_results);
        }
        collection.set_successful_estimates(successful);
    });

    let isa_index = Arc::try_unwrap(isa_index)
        .ok()
        .expect("all threads joined; Arc refcount should be 1")
        .into_inner()
        .expect("RwLock should not be poisoned");

    // Attach ISAs only to Pareto-surviving results, avoiding O(M) HashMap
    // clones for discarded results.
    for result in collection.iter_mut() {
        if let Some(idx) = result.isa_index() {
            result.set_isa(isa_index.isas[idx].clone());
        }
    }

    collection.set_isas(isa_index.into());

    collection
}
diff --git a/source/qre/src/trace/instruction_ids.rs b/source/qre/src/trace/instruction_ids.rs
new file mode 100644
index 0000000000..f47259ca4a
--- /dev/null
+++ b/source/qre/src/trace/instruction_ids.rs
@@ -0,0 +1,218 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

// NOTE: To add a new instruction ID:
// 1. Add it to the `define_instructions!` macro below (primary or alias section)
// 2. Add it to `add_instruction_ids` in qre.rs
// 3. Add it to instruction_ids.pyi
//
// The `instruction_name` function is auto-generated from the primary entries.

#[cfg(test)]
mod tests;

/// Macro that defines instruction ID constants and generates the `instruction_name` function.
/// Primary entries are the canonical names returned by `instruction_name`.
/// Aliases are alternative names for the same value.
macro_rules! define_instructions {
    (
        primary: [ $( ($name:ident, $value:expr) ),* $(,)? ],
        aliases: [ $( ($alias:ident, $avalue:expr) ),* $(,)? ]
    ) => {
        // Define primary constants
        $(
            pub const $name: u64 = $value;
        )*

        // Define alias constants
        $(
            pub const $alias: u64 = $avalue;
        )*

        /// Returns the canonical name for an instruction ID.
        /// For IDs with aliases, returns the primary name.
+ #[must_use] + pub fn instruction_name(id: u64) -> Option<&'static str> { + match id { + $( + $name => Some(stringify!($name)), + )* + _ => None, + } + } + }; +} + +define_instructions! { + primary: [ + // Paulis + (PAULI_I, 0x0), + (PAULI_X, 0x1), + (PAULI_Y, 0x2), + (PAULI_Z, 0x3), + + // Clifford gates + (H, 0x10), + (H_XY, 0x11), + (H_YZ, 0x12), + (SQRT_X, 0x13), + (SQRT_X_DAG, 0x14), + (SQRT_Y, 0x15), + (SQRT_Y_DAG, 0x16), + (S, 0x17), + (S_DAG, 0x18), + (CNOT, 0x19), + (CY, 0x1A), + (CZ, 0x1B), + (SWAP, 0x1C), + + // State preparation + (PREP_X, 0x30), + (PREP_Y, 0x31), + (PREP_Z, 0x32), + + // Generic Cliffords + (ONE_QUBIT_CLIFFORD, 0x50), + (TWO_QUBIT_CLIFFORD, 0x51), + (N_QUBIT_CLIFFORD, 0x52), + + // Measurements + (MEAS_X, 0x100), + (MEAS_Y, 0x101), + (MEAS_Z, 0x102), + (MEAS_RESET_X, 0x103), + (MEAS_RESET_Y, 0x104), + (MEAS_RESET_Z, 0x105), + (MEAS_XX, 0x106), + (MEAS_YY, 0x107), + (MEAS_ZZ, 0x108), + (MEAS_XZ, 0x109), + (MEAS_XY, 0x10A), + (MEAS_YZ, 0x10B), + + // Non-Clifford gates + (SQRT_SQRT_X, 0x400), + (SQRT_SQRT_X_DAG, 0x401), + (SQRT_SQRT_Y, 0x402), + (SQRT_SQRT_Y_DAG, 0x403), + (T, 0x404), + (T_DAG, 0x405), + (CCX, 0x406), + (CCY, 0x407), + (CCZ, 0x408), + (CSWAP, 0x409), + (AND, 0x40A), + (AND_DAG, 0x40B), + (RX, 0x40C), + (RY, 0x40D), + (RZ, 0x40E), + (CRX, 0x40F), + (CRY, 0x410), + (CRZ, 0x411), + (RXX, 0x412), + (RYY, 0x413), + (RZZ, 0x414), + + // Generic unitaries + (ONE_QUBIT_UNITARY, 0x500), + (TWO_QUBIT_UNITARY, 0x501), + + // Multi-qubit Pauli measurement + (MULTI_PAULI_MEAS, 0x1000), + + // Some generic logical instructions + (LATTICE_SURGERY, 0x1100), + + // Memory/compute operations (used in compute parts of memory-compute layouts) + (READ_FROM_MEMORY, 0x1200), + (WRITE_TO_MEMORY, 0x1201), + (MEMORY, 0x1210), + + // Some special hardware physical instructions + (CYCLIC_SHIFT, 0x1300), + + // Generic operation (for unified RE) + (GENERIC, 0xFFFF), + ], + aliases: [ + // Clifford gate aliases + (H_XZ, 0x10), // alias for H + 
(SQRT_Z, 0x17), // alias for S + (SQRT_Z_DAG, 0x18), // alias for S_DAG + (CX, 0x19), // alias for CNOT + + // Non-Clifford aliases + (SQRT_SQRT_Z, 0x404), // alias for T + (SQRT_SQRT_Z_DAG, 0x405), // alias for T_DAG + ] +} + +#[must_use] +pub fn is_pauli_measurement(id: u64) -> bool { + matches!( + id, + MEAS_X + | MEAS_Y + | MEAS_Z + | MEAS_XX + | MEAS_YY + | MEAS_ZZ + | MEAS_XZ + | MEAS_XY + | MEAS_YZ + | MULTI_PAULI_MEAS + ) +} + +#[must_use] +pub fn is_t_like(id: u64) -> bool { + matches!( + id, + SQRT_SQRT_X + | SQRT_SQRT_X_DAG + | SQRT_SQRT_Y + | SQRT_SQRT_Y_DAG + | SQRT_SQRT_Z + | SQRT_SQRT_Z_DAG + ) +} + +#[must_use] +pub fn is_ccx_like(id: u64) -> bool { + matches!(id, CCX | CCY | CCZ | CSWAP | AND | AND_DAG) +} + +#[must_use] +pub fn is_rotation_like(id: u64) -> bool { + matches!(id, RX | RY | RZ | RXX | RYY | RZZ) +} + +#[must_use] +pub fn is_clifford(id: u64) -> bool { + matches!( + id, + PAULI_I + | PAULI_X + | PAULI_Y + | PAULI_Z + | H_XZ + | H_XY + | H_YZ + | SQRT_X + | SQRT_X_DAG + | SQRT_Y + | SQRT_Y_DAG + | SQRT_Z + | SQRT_Z_DAG + | CX + | CY + | CZ + | SWAP + | PREP_X + | PREP_Y + | PREP_Z + | ONE_QUBIT_CLIFFORD + | TWO_QUBIT_CLIFFORD + | N_QUBIT_CLIFFORD + ) +} diff --git a/source/qre/src/trace/instruction_ids/tests.rs b/source/qre/src/trace/instruction_ids/tests.rs new file mode 100644 index 0000000000..c9d745ad9a --- /dev/null +++ b/source/qre/src/trace/instruction_ids/tests.rs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use super::*; + +#[test] +fn test_instruction_name_primary() { + assert_eq!(instruction_name(H), Some("H")); + assert_eq!(instruction_name(CNOT), Some("CNOT")); + assert_eq!(instruction_name(T), Some("T")); + assert_eq!(instruction_name(MEAS_Z), Some("MEAS_Z")); +} + +#[test] +fn test_instruction_name_aliases_return_primary() { + // Aliases should return the primary name + assert_eq!(instruction_name(H_XZ), Some("H")); + assert_eq!(instruction_name(CX), Some("CNOT")); + assert_eq!(instruction_name(SQRT_Z), Some("S")); + assert_eq!(instruction_name(SQRT_SQRT_Z), Some("T")); +} + +#[test] +fn test_instruction_name_unknown() { + assert_eq!(instruction_name(0x9999), None); +} diff --git a/source/qre/src/trace/tests.rs b/source/qre/src/trace/tests.rs new file mode 100644 index 0000000000..8b31717969 --- /dev/null +++ b/source/qre/src/trace/tests.rs @@ -0,0 +1,299 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#[test] +fn test_trace_iteration() { + use crate::trace::Trace; + + let mut trace = Trace::new(2); + trace.add_operation(1, vec![0], vec![]); + trace.add_operation(2, vec![1], vec![]); + + assert_eq!(trace.deep_iter().count(), 2); +} + +#[test] +fn test_nested_blocks() { + use crate::trace::Trace; + + let mut trace = Trace::new(3); + trace.add_operation(1, vec![0], vec![]); + let block = trace.add_block(2); + block.add_operation(2, vec![1], vec![]); + let block = block.add_block(3); + block.add_operation(3, vec![2], vec![]); + trace.add_operation(1, vec![0], vec![]); + + let repetitions = trace.deep_iter().map(|(_, rep)| rep).collect::>(); + assert_eq!(repetitions.len(), 4); + assert_eq!(repetitions, vec![1, 2, 6, 1]); +} + +#[test] +fn test_depth_simple() { + use crate::trace::Trace; + + let mut trace = Trace::new(2); + trace.add_operation(1, vec![0], vec![]); + trace.add_operation(2, vec![1], vec![]); + + // Operations are parallel + assert_eq!(trace.depth(), 1); + + trace.add_operation(3, vec![0], vec![]); + // Operation on 
qubit 0 is sequential to first one + assert_eq!(trace.depth(), 2); +} + +#[test] +fn test_depth_with_blocks() { + use crate::trace::Trace; + + let mut trace = Trace::new(2); + trace.add_operation(1, vec![0], vec![]); // Depth 1 on q0 + + let block = trace.add_block(2); + block.add_operation(2, vec![1], vec![]); // Depth 1 on q1 * 2 reps = 2 + + // Block acts as barrier *only on qubits it touches*. + // q1 is touched. q0 is not. + // q0 stays at depth 1. + // q1 ends at depth 2. + + trace.add_operation(3, vec![0], vec![]); + // Next op starts at depth 1 (after op 1). Ends at 2. + + assert_eq!(trace.depth(), 2); +} + +#[test] +fn test_depth_parallel_blocks() { + use crate::trace::Trace; + + let mut trace = Trace::new(4); + + let block1 = trace.add_block(1); + block1.add_operation(1, vec![0], vec![]); // q0: 1 + + let block2 = trace.add_block(1); + block2.add_operation(2, vec![1], vec![]); // q1: 1 + + // Blocks are parallel + assert_eq!(trace.depth(), 1); + + trace.add_operation(3, vec![0, 1], vec![]); + // Dependent on q0 (1) and q1 (1). Start at 1. End at 2. 
+ + assert_eq!(trace.depth(), 2); +} + +#[test] +fn test_depth_entangled() { + use crate::trace::Trace; + + let mut trace = Trace::new(2); + trace.add_operation(1, vec![0], vec![]); // q0: 1 + trace.add_operation(2, vec![1], vec![]); // q1: 1 + + trace.add_operation(3, vec![0, 1], vec![]); // q0, q1 synced at 1 -> end at 2 + + assert_eq!(trace.depth(), 2); + + trace.add_operation(4, vec![0], vec![]); // q0: 3 + assert_eq!(trace.depth(), 3); +} + +#[test] +fn test_psspc_transform() { + use crate::trace::{PSSPC, Trace, TraceTransform, instruction_ids::*}; + + let mut trace = Trace::new(3); + + trace.add_operation(T, vec![0], vec![]); + trace.add_operation(CCX, vec![0, 1, 2], vec![]); + trace.add_operation(RZ, vec![0], vec![0.1]); + trace.add_operation(CX, vec![0, 1], vec![]); + trace.add_operation(RZ, vec![1], vec![0.2]); + trace.add_operation(MEAS_Z, vec![0], vec![]); + + // Configure PSSPC with 20 T states per rotation, include CCX magic states + let psspc = PSSPC::new(20, true); + + let transformed = psspc.transform(&trace).expect("Transformation failed"); + + assert_eq!(transformed.compute_qubits(), 12); + assert_eq!(transformed.depth(), 47); + + assert_eq!(transformed.get_resource_state_count(T), 41); + assert_eq!(transformed.get_resource_state_count(CCX), 1); + + assert!(transformed.base_error() > 0.0); + // Error is roughly 5e-9 for 20 Ts + assert!(transformed.base_error() < 1e-8); +} + +#[test] +fn test_lattice_surgery_transform() { + use crate::trace::{LatticeSurgery, Trace, TraceTransform, instruction_ids::*}; + + let mut trace = Trace::new(3); + + trace.add_operation(T, vec![0], vec![]); + trace.add_operation(CX, vec![1, 2], vec![]); + trace.add_operation(T, vec![0], vec![]); + + assert_eq!(trace.depth(), 2); + + let ls = LatticeSurgery::default(); + let transformed = ls.transform(&trace).expect("Transformation failed"); + + assert_eq!(transformed.compute_qubits(), 3); + assert_eq!(transformed.depth(), 2); + + // Check that we have a LATTICE_SURGERY 
operation + // TraceIterator visits the operation definition once, but with a multiplier. + let ls_ops: Vec<_> = transformed + .deep_iter() + .filter(|(gate, _)| gate.id == LATTICE_SURGERY) + .collect(); + + assert_eq!(ls_ops.len(), 1); + + let (gate, mult) = ls_ops[0]; + assert_eq!(gate.id, LATTICE_SURGERY); + assert_eq!(mult, 2); // Multiplier should carry the repetition count (depth) +} + +#[test] +fn test_estimate_simple() { + use crate::isa::{Encoding, ISA, Instruction}; + use crate::trace::{Trace, instruction_ids::*}; + + let mut trace = Trace::new(1); + trace.add_operation(T, vec![0], vec![]); + + // Create ISA + let mut isa = ISA::new(); + isa.add_instruction(Instruction::fixed_arity( + T, + Encoding::Logical, + 1, // arity + 100, // time + Some(50), // space + None, // length (defaults to arity) + 0.001, // error_rate + )); + + let result = trace.estimate(&isa, None).expect("Estimation failed"); + + assert!((result.error() - 0.001).abs() <= f64::EPSILON); + assert_eq!(result.runtime(), 100); + assert_eq!(result.qubits(), 50); +} + +#[test] +fn test_estimate_with_factory() { + use crate::isa::{Encoding, ISA, Instruction}; + use crate::trace::{Trace, instruction_ids::*}; + + let mut trace = Trace::new(1); + // Algorithm needs 1000 T states + trace.increment_resource_state(T, 1000); + + // Some compute runtime to allow factories to run + trace.add_operation(GENERIC, vec![0], vec![]); + + let mut isa = ISA::new(); + + // T factory instruction + // Produces 1 T state + isa.add_instruction(Instruction::fixed_arity( + T, + Encoding::Logical, + 1, // arity + 10, // time to produce 1 state + Some(50), // space for factory + None, + 0.0001, // error rate of produced state + )); + + isa.add_instruction(Instruction::fixed_arity( + GENERIC, + Encoding::Logical, + 1, + 1000, // runtime 1000 + Some(200), + None, + 0.0, + )); + + let result = trace.estimate(&isa, None).expect("Estimation failed"); + + assert_eq!(result.runtime(), 1000); + assert_eq!(result.qubits(), 700); 
+ + // Check factory result + let factory_res = result.factories().get(&T).expect("Factory missing"); + assert_eq!(factory_res.copies(), 10); + assert_eq!(factory_res.runs(), 100); + assert_eq!(result.factories().len(), 1); +} + +#[test] +fn test_trace_display_uses_instruction_names() { + use crate::trace::Trace; + use crate::trace::instruction_ids::{CNOT, H, MEAS_Z}; + + let mut trace = Trace::new(2); + trace.add_operation(H, vec![0], vec![]); + trace.add_operation(CNOT, vec![0, 1], vec![]); + trace.add_operation(MEAS_Z, vec![0], vec![]); + + let display = format!("{trace}"); + + assert!( + display.contains('H'), + "Expected 'H' in trace output: {display}" + ); + assert!( + display.contains("CNOT"), + "Expected 'CNOT' in trace output: {display}" + ); + assert!( + display.contains("MEAS_Z"), + "Expected 'MEAS_Z' in trace output: {display}" + ); +} + +#[test] +fn test_trace_display_unknown_instruction() { + use crate::trace::Trace; + + let mut trace = Trace::new(1); + trace.add_operation(0x9999, vec![0], vec![]); + + let display = format!("{trace}"); + + assert!( + display.contains("??"), + "Expected '??' for unknown instruction in: {display}" + ); +} + +#[test] +fn test_block_display_with_repetitions() { + use crate::trace::Trace; + use crate::trace::instruction_ids::H; + + let mut trace = Trace::new(1); + let block = trace.add_block(10); + block.add_operation(H, vec![0], vec![]); + + let display = format!("{trace}"); + + assert!( + display.contains("repeat 10"), + "Expected 'repeat 10' in: {display}" + ); + assert!(display.contains('H'), "Expected 'H' in block: {display}"); +} diff --git a/source/qre/src/trace/transforms.rs b/source/qre/src/trace/transforms.rs new file mode 100644 index 0000000000..d232ba5b9f --- /dev/null +++ b/source/qre/src/trace/transforms.rs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +mod lattice_surgery; +mod psspc; + +pub use lattice_surgery::LatticeSurgery; +pub use psspc::PSSPC; + +use crate::{Error, Trace}; + +pub trait TraceTransform { + fn transform(&self, trace: &Trace) -> Result; +} diff --git a/source/qre/src/trace/transforms/lattice_surgery.rs b/source/qre/src/trace/transforms/lattice_surgery.rs new file mode 100644 index 0000000000..fd3ff45f72 --- /dev/null +++ b/source/qre/src/trace/transforms/lattice_surgery.rs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use crate::trace::TraceTransform; +use crate::{Error, Trace, instruction_ids}; + +pub struct LatticeSurgery { + slow_down_factor: f64, +} + +impl Default for LatticeSurgery { + fn default() -> Self { + Self { + slow_down_factor: 1.0, + } + } +} + +impl LatticeSurgery { + #[must_use] + pub fn new(slow_down_factor: f64) -> Self { + Self { slow_down_factor } + } +} + +impl TraceTransform for LatticeSurgery { + #[allow( + clippy::cast_possible_truncation, + clippy::cast_precision_loss, + clippy::cast_sign_loss + )] + fn transform(&self, trace: &Trace) -> Result { + let mut transformed = trace.clone_empty(None); + + let block = + transformed.add_block((trace.depth() as f64 * self.slow_down_factor).ceil() as u64); + block.add_operation( + instruction_ids::LATTICE_SURGERY, + (0..trace.compute_qubits()).collect(), + vec![], + ); + + Ok(transformed) + } +} diff --git a/source/qre/src/trace/transforms/psspc.rs b/source/qre/src/trace/transforms/psspc.rs new file mode 100644 index 0000000000..80ec36bd99 --- /dev/null +++ b/source/qre/src/trace/transforms/psspc.rs @@ -0,0 +1,225 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+
+use crate::property_keys::NUM_TS_PER_ROTATION;
+use crate::trace::{Gate, TraceTransform};
+use crate::{Error, Property, Trace, instruction_ids};
+
+/// Implements the Parallel Synthesis Sequential Pauli Computation (PSSPC)
+/// layout algorithm described in Appendix D in
+/// [arXiv:2211.07629](https://arxiv.org/pdf/2211.07629). This scheme combines
+/// sequential Pauli-based computation (SPC) as described in
+/// [arXiv:1808.02892](https://arxiv.org/pdf/1808.02892) and
+/// [arXiv:2109.02746](https://arxiv.org/pdf/2109.02746) with an approach to
+/// synthesize sets of diagonal non-Clifford unitaries in parallel as done in
+/// [arXiv:2110.11493](https://arxiv.org/pdf/2110.11493).
+///
+/// References:
+/// - Michael E. Beverland, Prakash Murali, Matthias Troyer, Krysta M. Svore,
+///   Torsten Hoefler, Vadym Kliuchnikov, Guang Hao Low, Mathias Soeken, Aarthi
+///   Sundaram, Alexander Vaschillo: Assessing requirements to scale to
+///   practical quantum advantage,
+///   [arXiv:2211.07629](https://arxiv.org/pdf/2211.07629)
+/// - Daniel Litinski: A Game of Surface Codes: Large-Scale Quantum Computing
+///   with Lattice Surgery, [arXiv:1808.02892](https://arxiv.org/pdf/1808.02892)
+/// - Christopher Chamberland, Earl T. Campbell: Universal quantum computing
+///   with twist-free and temporally encoded lattice surgery,
+///   [arXiv:2109.02746](https://arxiv.org/pdf/2109.02746)
+/// - Michael Beverland, Vadym Kliuchnikov, Eddie Schoute: Surface code
+///   compilation via edge-disjoint paths,
+///   [arXiv:2110.11493](https://arxiv.org/pdf/2110.11493).
+#[derive(Clone)] +pub struct PSSPC { + /// Number of multi-qubit Pauli measurements to inject a synthesized + /// rotation, defaults to 1, see [arXiv:2211.07629, (D3)] + num_measurements_per_r: u64, + /// Number of multi-qubit Pauli measurements to apply a Toffoli gate, + /// defaults to 3, see [arXiv:2211.07629, (D3)] + num_measurements_per_ccx: u64, + /// Number of Pauli measurements to write to memory, defaults to 2, see + /// [arXiv:2109.02746, Fig. 16a] + num_measurements_per_wtm: u64, + /// Number of Pauli measurements to read from memory, defaults to 1, see + /// [arXiv:2109.02746, Fig. 16b] + num_measurements_per_rfm: u64, + + /// Number of Ts per rotation synthesis + num_ts_per_rotation: u64, + /// Perform Toffoli gates using CCX magic states, if false, T gates are used + ccx_magic_states: bool, +} + +impl PSSPC { + #[must_use] + pub fn new(num_ts_per_rotation: u64, ccx_magic_states: bool) -> Self { + Self { + num_measurements_per_r: 1, + num_measurements_per_ccx: 3, + num_measurements_per_wtm: 2, + num_measurements_per_rfm: 1, + num_ts_per_rotation, + ccx_magic_states, + } + } +} + +impl PSSPC { + #[allow(clippy::cast_possible_truncation)] + fn psspc_counts(trace: &Trace) -> Result { + let mut counter = PSSPCCounts::default(); + + let mut max_rotation_depth = vec![0; trace.total_qubits() as usize]; + + for (Gate { id, qubits, .. 
}, mult) in trace.deep_iter() { + if instruction_ids::is_pauli_measurement(*id) { + counter.measurements += mult; + } else if instruction_ids::is_t_like(*id) { + counter.t_like += mult; + } else if instruction_ids::is_ccx_like(*id) { + counter.ccx_like += mult; + } else if instruction_ids::is_rotation_like(*id) { + counter.rotation_like += mult; + + // Track rotation depth + let mut current_depth = 0; + for q in qubits { + if max_rotation_depth[*q as usize] > current_depth { + current_depth = max_rotation_depth[*q as usize]; + } + } + let new_depth = current_depth + mult; + for q in qubits { + max_rotation_depth[*q as usize] = new_depth; + } + if new_depth > counter.rotation_depth { + counter.rotation_depth = new_depth; + } + } else if *id == instruction_ids::READ_FROM_MEMORY { + counter.read_from_memory += mult; + } else if *id == instruction_ids::WRITE_TO_MEMORY { + counter.write_to_memory += mult; + } else if !instruction_ids::is_clifford(*id) { + // Unsupported non-Clifford gate + return Err(Error::UnsupportedInstruction { + id: *id, + name: "PSSPC", + }); + } else { + // For Clifford gates, synchronize depths across qubits + if !qubits.is_empty() { + let mut max_depth = 0; + for q in qubits { + if max_rotation_depth[*q as usize] > max_depth { + max_depth = max_rotation_depth[*q as usize]; + } + } + for q in qubits { + max_rotation_depth[*q as usize] = max_depth; + } + } + } + } + + Ok(counter) + } + + #[allow(clippy::cast_precision_loss, clippy::cast_possible_wrap)] + fn get_trace(&self, trace: &Trace, counts: &PSSPCCounts) -> Trace { + let num_qubits = trace.compute_qubits(); + let logical_qubits = Self::logical_qubit_overhead(num_qubits); + + let mut transformed = trace.clone_empty(Some(logical_qubits)); + + let logical_depth = self.logical_depth_overhead(counts); + let (t_states, ccx_states) = self.num_magic_states(counts); + + transformed.increment_resource_state(instruction_ids::T, t_states); + transformed.increment_resource_state(instruction_ids::CCX, 
ccx_states); + + let block = transformed.add_block(logical_depth); + block.add_operation( + instruction_ids::MULTI_PAULI_MEAS, + (0..logical_qubits).collect(), + vec![], + ); + + // Add error due to rotation synthesis + transformed.increment_base_error(counts.rotation_like as f64 * self.synthesis_error()); + + // Track some properties + transformed.set_property( + NUM_TS_PER_ROTATION, + Property::Int(self.num_ts_per_rotation as i64), + ); + + transformed + } + + /// Calculates the number of logical qubits required for the PSSPC layout + /// according to Eq. (D1) in + /// [arXiv:2211.07629](https://arxiv.org/pdf/2211.07629) + #[allow( + clippy::cast_precision_loss, + clippy::cast_possible_truncation, + clippy::cast_sign_loss + )] + fn logical_qubit_overhead(algorithm_qubits: u64) -> u64 { + let qubit_padding = ((8 * algorithm_qubits) as f64).sqrt().ceil() as u64 + 1; + 2 * algorithm_qubits + qubit_padding + } + + /// Calculates the number of multi-qubit Pauli measurements executed in + /// sequence according to Eq. (D3) in + /// [arXiv:2211.07629](https://arxiv.org/pdf/2211.07629) + fn logical_depth_overhead(&self, counter: &PSSPCCounts) -> u64 { + (counter.measurements + counter.t_like + counter.rotation_like) + * self.num_measurements_per_r + + counter.ccx_like * self.num_measurements_per_ccx + + counter.read_from_memory * self.num_measurements_per_rfm + + counter.write_to_memory * self.num_measurements_per_wtm + + (self.num_ts_per_rotation * counter.rotation_depth) * self.num_measurements_per_r + } + + /// Calculates the number of T and CCX magic states that are consumed by + /// multi-qubit Pauli measurements executed by PSSPC according to Eq. (D4) + /// in [arXiv:2211.07629](https://arxiv.org/pdf/2211.07629) + /// + /// CCX magic states are only counted if the hyper parameter + /// `ccx_magic_states` is set to true. 
+ fn num_magic_states(&self, counter: &PSSPCCounts) -> (u64, u64) { + let t_states = counter.t_like + self.num_ts_per_rotation * counter.rotation_like; + + if self.ccx_magic_states { + (t_states, counter.ccx_like) + } else { + (t_states + 4 * counter.ccx_like, 0) + } + } + + /// Calculates the synthesis error from the formula provided in Table 1 in + /// [arXiv:2203.10064](https://arxiv.org/pdf/2203.10064) for Clifford+T in + /// the mixed fallback approximation protocol. + #[allow(clippy::cast_precision_loss)] + fn synthesis_error(&self) -> f64 { + 2f64.powf((4.86 - self.num_ts_per_rotation as f64) / 0.53) + } +} + +impl TraceTransform for PSSPC { + fn transform(&self, trace: &Trace) -> Result { + let counts = Self::psspc_counts(trace)?; + + Ok(self.get_trace(trace, &counts)) + } +} + +#[derive(Default)] +struct PSSPCCounts { + measurements: u64, + t_like: u64, + ccx_like: u64, + rotation_like: u64, + write_to_memory: u64, + read_from_memory: u64, + rotation_depth: u64, +} diff --git a/source/qre/src/utils.rs b/source/qre/src/utils.rs new file mode 100644 index 0000000000..ffa82b04d4 --- /dev/null +++ b/source/qre/src/utils.rs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use probability::prelude::{Binomial, Inverse}; + +#[allow(clippy::doc_markdown)] +/// Faster implementation of SciPy's binom.ppf +#[must_use] +pub fn binom_ppf(q: f64, n: usize, p: f64) -> usize { + let dist = Binomial::with_failure(n, 1.0 - p); + dist.inverse(q) +} + +#[must_use] +pub fn float_to_bits(f: f64) -> u64 { + f.to_bits() +} + +#[must_use] +pub fn float_from_bits(b: u64) -> f64 { + f64::from_bits(b) +}